diff --git a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java index 8e61e2a7f5..5a774daaa6 100644 --- a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java @@ -28,18 +28,12 @@ import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.MethodDescriptor; import io.grpc.Status; -import io.grpc.alts.internal.AltsClientOptions; -import io.grpc.alts.internal.AltsProtocolNegotiator; -import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; -import io.grpc.alts.internal.AltsTsiHandshaker; -import io.grpc.alts.internal.HandshakerServiceGrpc; -import io.grpc.alts.internal.RpcProtocolVersionsUtil; -import io.grpc.alts.internal.TsiHandshaker; -import io.grpc.alts.internal.TsiHandshakerFactory; +import io.grpc.alts.internal.AltsProtocolNegotiator.ClientAltsProtocolNegotiatorFactory; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourcePool; import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.NettyChannelBuilder; import java.util.logging.Level; import java.util.logging.Logger; @@ -60,8 +54,6 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder targetServiceAccounts = targetServiceAccountsBuilder.build(); - final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool); - TsiHandshakerFactory altsHandshakerFactory = - new TsiHandshakerFactory() { - @Override - public TsiHandshaker newHandshaker(String authority) { - AltsClientOptions handshakerOptions = - new AltsClientOptions.Builder() - .setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()) - .setTargetServiceAccounts(targetServiceAccounts) - .setTargetName(authority) - .build(); - return AltsTsiHandshaker.newClient( - HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions); - } - }; - return negotiatorForTest = - AltsProtocolNegotiator.createClientNegotiator( - altsHandshakerFactory, lazyHandshakerChannel); - } + ProtocolNegotiator getProtocolNegotiatorForTest() { + return new ClientAltsProtocolNegotiatorFactory( + targetServiceAccountsBuilder.build(), handshakerChannelPool) + .buildProtocolNegotiator(); } /** An implementation of {@link ClientInterceptor} that fails each call. */ diff --git a/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java b/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java index 8c8d55725b..fdda8caac6 100644 --- a/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java +++ b/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java @@ -33,14 +33,7 @@ import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer.Factory; import io.grpc.ServerTransportFilter; import io.grpc.Status; -import io.grpc.alts.internal.AltsHandshakerOptions; import io.grpc.alts.internal.AltsProtocolNegotiator; -import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; -import io.grpc.alts.internal.AltsTsiHandshaker; -import io.grpc.alts.internal.HandshakerServiceGrpc; -import io.grpc.alts.internal.RpcProtocolVersionsUtil; -import io.grpc.alts.internal.TsiHandshaker; -import io.grpc.alts.internal.TsiHandshakerFactory; import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourcePool; import io.grpc.netty.NettyServerBuilder; @@ -192,18 +185,8 @@ public final class AltsServerBuilder extends ServerBuilder { } } - final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool); delegate.protocolNegotiator( - AltsProtocolNegotiator.createServerNegotiator( - new TsiHandshakerFactory() { - @Override - public TsiHandshaker newHandshaker(String authority) { - return AltsTsiHandshaker.newServer( - HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), - new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions())); - } - }, - lazyHandshakerChannel)); + AltsProtocolNegotiator.serverAltsProtocolNegotiator(handshakerChannelPool)); return delegate.build(); } diff --git a/alts/src/main/java/io/grpc/alts/ComputeEngineChannelBuilder.java b/alts/src/main/java/io/grpc/alts/ComputeEngineChannelBuilder.java index af3376f14e..f12de57a8a 100644 --- a/alts/src/main/java/io/grpc/alts/ComputeEngineChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/ComputeEngineChannelBuilder.java @@ -18,23 +18,18 @@ package io.grpc.alts; import com.google.auth.oauth2.ComputeEngineCredentials; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import io.grpc.CallCredentials; import io.grpc.ForwardingChannelBuilder; import io.grpc.ManagedChannelBuilder; import io.grpc.Status; -import io.grpc.alts.internal.AltsClientOptions; -import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; -import io.grpc.alts.internal.AltsTsiHandshaker; -import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator; -import io.grpc.alts.internal.HandshakerServiceGrpc; -import io.grpc.alts.internal.RpcProtocolVersionsUtil; -import io.grpc.alts.internal.TsiHandshaker; -import io.grpc.alts.internal.TsiHandshakerFactory; +import io.grpc.alts.internal.AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory; import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourcePool; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.NettyChannelBuilder; import io.netty.handler.ssl.SslContext; import javax.net.ssl.SSLException; @@ -47,12 +42,21 @@ public final class ComputeEngineChannelBuilder extends ForwardingChannelBuilder { private final NettyChannelBuilder delegate; - private GoogleDefaultProtocolNegotiator negotiatorForTest; private ComputeEngineChannelBuilder(String target) { delegate = NettyChannelBuilder.forTarget(target); + SslContext sslContext; + try { + sslContext = GrpcSslContexts.forClient().build(); + } catch (SSLException e) { + throw new RuntimeException(e); + } InternalNettyChannelBuilder.setProtocolNegotiatorFactory( - delegate(), new ProtocolNegotiatorFactory()); + delegate(), + new GoogleDefaultProtocolNegotiatorFactory( + /* targetServiceAccounts= */ ImmutableList.of(), + SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL), + sslContext)); CallCredentials credentials = MoreCallCredentials.from(ComputeEngineCredentials.create()); Status status = Status.OK; if (!CheckGcpEnvironment.isOnGcp()) { @@ -79,40 +83,17 @@ public final class ComputeEngineChannelBuilder } @VisibleForTesting - GoogleDefaultProtocolNegotiator getProtocolNegotiatorForTest() { - return negotiatorForTest; - } - - private final class ProtocolNegotiatorFactory - implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory { - - @Override - public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() { - final LazyChannel lazyHandshakerChannel = - new LazyChannel( - SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL)); - TsiHandshakerFactory altsHandshakerFactory = - new TsiHandshakerFactory() { - @Override - public TsiHandshaker newHandshaker(String authority) { - AltsClientOptions handshakerOptions = - new AltsClientOptions.Builder() - .setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()) - .setTargetName(authority) - .build(); - return AltsTsiHandshaker.newClient( - HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions); - } - }; - SslContext sslContext; - try { - sslContext = GrpcSslContexts.forClient().build(); - } catch (SSLException ex) { - throw new RuntimeException(ex); - } - return negotiatorForTest = - new GoogleDefaultProtocolNegotiator( - altsHandshakerFactory, lazyHandshakerChannel, sslContext); + ProtocolNegotiator getProtocolNegotiatorForTest() { + SslContext sslContext; + try { + sslContext = GrpcSslContexts.forClient().build(); + } catch (SSLException e) { + throw new RuntimeException(e); } + return new GoogleDefaultProtocolNegotiatorFactory( + /* targetServiceAccounts= */ ImmutableList.of(), + SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL), + sslContext) + .buildProtocolNegotiator(); } } diff --git a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java index a98380de55..39bd416372 100644 --- a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java @@ -18,23 +18,18 @@ package io.grpc.alts; import com.google.auth.oauth2.GoogleCredentials; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import io.grpc.CallCredentials; import io.grpc.ForwardingChannelBuilder; import io.grpc.ManagedChannelBuilder; import io.grpc.Status; -import io.grpc.alts.internal.AltsClientOptions; -import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; -import io.grpc.alts.internal.AltsTsiHandshaker; -import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator; -import io.grpc.alts.internal.HandshakerServiceGrpc; -import io.grpc.alts.internal.RpcProtocolVersionsUtil; -import io.grpc.alts.internal.TsiHandshaker; -import io.grpc.alts.internal.TsiHandshakerFactory; +import io.grpc.alts.internal.AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory; import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourcePool; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.NettyChannelBuilder; import io.netty.handler.ssl.SslContext; import java.io.IOException; @@ -49,12 +44,21 @@ public final class GoogleDefaultChannelBuilder extends ForwardingChannelBuilder { private final NettyChannelBuilder delegate; - private GoogleDefaultProtocolNegotiator negotiatorForTest; private GoogleDefaultChannelBuilder(String target) { delegate = NettyChannelBuilder.forTarget(target); + SslContext sslContext; + try { + sslContext = GrpcSslContexts.forClient().build(); + } catch (SSLException e) { + throw new RuntimeException(e); + } InternalNettyChannelBuilder.setProtocolNegotiatorFactory( - delegate(), new ProtocolNegotiatorFactory()); + delegate(), + new GoogleDefaultProtocolNegotiatorFactory( + /* targetServiceAccounts= */ ImmutableList.of(), + SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL), + sslContext)); @Nullable CallCredentials credentials = null; Status status = Status.OK; try { @@ -84,40 +88,17 @@ public final class GoogleDefaultChannelBuilder } @VisibleForTesting - GoogleDefaultProtocolNegotiator getProtocolNegotiatorForTest() { - return negotiatorForTest; - } - - private final class ProtocolNegotiatorFactory - implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory { - - @Override - public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() { - final LazyChannel lazyHandshakerChannel = - new LazyChannel( - SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL)); - TsiHandshakerFactory altsHandshakerFactory = - new TsiHandshakerFactory() { - @Override - public TsiHandshaker newHandshaker(String authority) { - AltsClientOptions handshakerOptions = - new AltsClientOptions.Builder() - .setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()) - .setTargetName(authority) - .build(); - return AltsTsiHandshaker.newClient( - HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions); - } - }; - SslContext sslContext; - try { - sslContext = GrpcSslContexts.forClient().build(); - } catch (SSLException ex) { - throw new RuntimeException(ex); - } - return negotiatorForTest = - new GoogleDefaultProtocolNegotiator( - altsHandshakerFactory, lazyHandshakerChannel, sslContext); + ProtocolNegotiator getProtocolNegotiatorForTest() { + SslContext sslContext; + try { + sslContext = GrpcSslContexts.forClient().build(); + } catch (SSLException e) { + throw new RuntimeException(e); } + return new GoogleDefaultProtocolNegotiatorFactory( + /* targetServiceAccounts= */ ImmutableList.of(), + SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL), + sslContext) + .buildProtocolNegotiator(); } } diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index 3845d8a697..e1a251c2a1 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -16,8 +16,10 @@ package io.grpc.alts.internal; +import static com.google.common.base.Preconditions.checkNotNull; + import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.protobuf.Any; import io.grpc.Attributes; import io.grpc.Channel; @@ -27,102 +29,266 @@ import io.grpc.InternalChannelz.Security; import io.grpc.SecurityLevel; import io.grpc.Status; import io.grpc.alts.internal.RpcProtocolVersionsUtil.RpcVersionsCheckResult; -import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; -import io.grpc.netty.InternalProtocolNegotiators.AbstractBufferingHandler; +import io.grpc.netty.InternalProtocolNegotiators; import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; -import java.util.logging.Level; +import java.security.GeneralSecurityException; +import java.util.List; import java.util.logging.Logger; +import javax.annotation.Nullable; /** - * A GRPC {@link ProtocolNegotiator} for ALTS. This class creates a Netty handler that provides ALTS + * A gRPC {@link ProtocolNegotiator} for ALTS. This class creates a Netty handler that provides ALTS * security on the wire, similar to Netty's {@code SslHandler}. */ -public abstract class AltsProtocolNegotiator implements ProtocolNegotiator { - +// TODO(carl-mastrangelo): rename this AltsProtocolNegotiators. +public final class AltsProtocolNegotiator { private static final Logger logger = Logger.getLogger(AltsProtocolNegotiator.class.getName()); @Grpc.TransportAttr public static final Attributes.Key TSI_PEER_KEY = Attributes.Key.create("TSI_PEER"); - @Grpc.TransportAttr - public static final Attributes.Key ALTS_CONTEXT_KEY = - Attributes.Key.create("ALTS_CONTEXT_KEY"); + public static final Attributes.Key AUTH_CONTEXT_KEY = + Attributes.Key.create("AUTH_CONTEXT_KEY"); - private static final AsciiString scheme = AsciiString.of("https"); + private static final AsciiString SCHEME = AsciiString.of("https"); - /** Creates a negotiator used for ALTS client. */ - public static AltsProtocolNegotiator createClientNegotiator( - final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) { - final class ClientAltsProtocolNegotiator extends AltsProtocolNegotiator { + /** + * ClientAltsProtocolNegotiatorFactory is a factory for doing client side negotiation of an ALTS + * channel. + */ + public static final class ClientAltsProtocolNegotiatorFactory + implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory { + + private final ImmutableList targetServiceAccounts; + private final LazyChannel lazyHandshakerChannel; + + public ClientAltsProtocolNegotiatorFactory( + List targetServiceAccounts, + ObjectPool handshakerChannelPool) { + this.targetServiceAccounts = ImmutableList.copyOf(targetServiceAccounts); + this.lazyHandshakerChannel = new LazyChannel(handshakerChannelPool); + } + + @Override + public ProtocolNegotiator buildProtocolNegotiator() { + return new ClientAltsProtocolNegotiator( + new ClientTsiHandshakerFactory(targetServiceAccounts, lazyHandshakerChannel), + lazyHandshakerChannel); + } + } + + @VisibleForTesting + private static final class ClientAltsProtocolNegotiator implements ProtocolNegotiator { + private final TsiHandshakerFactory handshakerFactory; + private final LazyChannel lazyHandshakerChannel; + + ClientAltsProtocolNegotiator( + TsiHandshakerFactory handshakerFactory, LazyChannel lazyHandshakerChannel) { + this.handshakerFactory = checkNotNull(handshakerFactory, "handshakerFactory"); + this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel"); + } + + @Override + public AsciiString scheme() { + return SCHEME; + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority()); + NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker); + ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler); + ChannelHandler thh = + new TsiHandshakeHandler(gnh, nettyHandshaker, new AltsHandshakeValidator()); + ChannelHandler wuah = InternalProtocolNegotiators.waitUntilActiveHandler(thh); + return wuah; + } + + @Override + public void close() { + lazyHandshakerChannel.close(); + } + } + + /** + * Creates a protocol negotiator for ALTS on the server side. + */ + public static ProtocolNegotiator serverAltsProtocolNegotiator( + ObjectPool handshakerChannelPool) { + final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool); + final class ServerTsiHandshakerFactory implements TsiHandshakerFactory { @Override - public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + public TsiHandshaker newHandshaker(@Nullable String authority) { + assert authority == null; + return AltsTsiHandshaker.newServer( + HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), + new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions())); + } + } + + return new ServerAltsProtocolNegotiator( + new ServerTsiHandshakerFactory(), lazyHandshakerChannel); + } + + @VisibleForTesting + static final class ServerAltsProtocolNegotiator implements ProtocolNegotiator { + private final TsiHandshakerFactory handshakerFactory; + private final LazyChannel lazyHandshakerChannel; + + @VisibleForTesting + ServerAltsProtocolNegotiator( + TsiHandshakerFactory handshakerFactory, LazyChannel lazyHandshakerChannel) { + this.handshakerFactory = checkNotNull(handshakerFactory, "handshakerFactory"); + this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel"); + } + + @Override + public AsciiString scheme() { + return SCHEME; + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + TsiHandshaker handshaker = handshakerFactory.newHandshaker(/* authority= */ null); + NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker); + ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler); + ChannelHandler thh = + new TsiHandshakeHandler(gnh, nettyHandshaker, new AltsHandshakeValidator()); + ChannelHandler wuah = InternalProtocolNegotiators.waitUntilActiveHandler(thh); + return wuah; + } + + @Override + public void close() { + logger.finest("ALTS Server ProtocolNegotiator Closed"); + lazyHandshakerChannel.close(); + } + } + + /** + * A Protocol Negotiator factory which can switch between ALTS and TLS based on EAG Attrs. + */ + public static final class GoogleDefaultProtocolNegotiatorFactory + implements ProtocolNegotiatorFactory { + private final ImmutableList targetServiceAccounts; + private final LazyChannel lazyHandshakerChannel; + private final SslContext sslContext; + + /** + * Creates Negotiator Factory, which will either use the targetServiceAccounts and + * handshakerChannelPool, or the sslContext. + */ + public GoogleDefaultProtocolNegotiatorFactory( + List targetServiceAccounts, + ObjectPool handshakerChannelPool, + SslContext sslContext) { + this.targetServiceAccounts = ImmutableList.copyOf(targetServiceAccounts); + this.lazyHandshakerChannel = new LazyChannel(handshakerChannelPool); + this.sslContext = checkNotNull(sslContext, "sslContext"); + } + + @Override + public ProtocolNegotiator buildProtocolNegotiator() { + return new GoogleDefaultProtocolNegotiator( + new ClientTsiHandshakerFactory(targetServiceAccounts, lazyHandshakerChannel), + lazyHandshakerChannel, + sslContext); + } + } + + private static final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator { + private final TsiHandshakerFactory handshakerFactory; + private final LazyChannel lazyHandshakerChannel; + private final SslContext sslContext; + + GoogleDefaultProtocolNegotiator( + TsiHandshakerFactory handshakerFactory, + LazyChannel lazyHandshakerChannel, + SslContext sslContext) { + this.handshakerFactory = checkNotNull(handshakerFactory, "handshakerFactory"); + this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel"); + this.sslContext = checkNotNull(sslContext, "checkNotNull"); + } + + @Override + public AsciiString scheme() { + return SCHEME; + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler); + ChannelHandler securityHandler; + if (grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_ADDR_AUTHORITY) != null + || grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_PROVIDED_BACKEND) != null) { TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority()); - return new BufferUntilAltsNegotiatedHandler( - grpcHandler, - new TsiHandshakeHandler(new NettyTsiHandshaker(handshaker)), - new TsiFrameHandler()); - } - - @Override - public void close() { - logger.finest("ALTS Client ProtocolNegotiator Closed"); - lazyHandshakerChannel.close(); + NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker); + securityHandler = + new TsiHandshakeHandler(gnh, nettyHandshaker, new AltsHandshakeValidator()); + } else { + securityHandler = InternalProtocolNegotiators.clientTlsHandler( + gnh, sslContext, grpcHandler.getAuthority()); } + ChannelHandler wuah = InternalProtocolNegotiators.waitUntilActiveHandler(securityHandler); + return wuah; } - return new ClientAltsProtocolNegotiator(); + @Override + public void close() { + logger.finest("ALTS Server ProtocolNegotiator Closed"); + lazyHandshakerChannel.close(); + } } - @Override - public final AsciiString scheme() { - return scheme; - } + private static final class ClientTsiHandshakerFactory implements TsiHandshakerFactory { - /** Creates a negotiator used for ALTS server. */ - public static AltsProtocolNegotiator createServerNegotiator( - final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) { - final class ServerAltsProtocolNegotiator extends AltsProtocolNegotiator { + private final ImmutableList targetServiceAccounts; + private final LazyChannel lazyHandshakerChannel; - @Override - public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - TsiHandshaker handshaker = handshakerFactory.newHandshaker(/*authority=*/ null); - return new BufferUntilAltsNegotiatedHandler( - grpcHandler, - new TsiHandshakeHandler(new NettyTsiHandshaker(handshaker)), - new TsiFrameHandler()); - } - - @Override - public void close() { - logger.finest("ALTS Server ProtocolNegotiator Closed"); - lazyHandshakerChannel.close(); - } + ClientTsiHandshakerFactory( + ImmutableList targetServiceAccounts, LazyChannel lazyHandshakerChannel) { + this.targetServiceAccounts = checkNotNull(targetServiceAccounts, "targetServiceAccounts"); + this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel"); } - return new ServerAltsProtocolNegotiator(); + @Override + public TsiHandshaker newHandshaker(@Nullable String authority) { + AltsClientOptions handshakerOptions = + new AltsClientOptions.Builder() + .setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()) + .setTargetServiceAccounts(targetServiceAccounts) + .setTargetName(authority) + .build(); + return AltsTsiHandshaker.newClient( + HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions); + } } /** Channel created from a channel pool lazily. */ - public static class LazyChannel { + @VisibleForTesting + static final class LazyChannel { private final ObjectPool channelPool; private Channel channel; - public LazyChannel(ObjectPool channelPool) { - this.channelPool = channelPool; + @VisibleForTesting + LazyChannel(ObjectPool channelPool) { + this.channelPool = checkNotNull(channelPool, "channelPool"); } /** * If channel is null, gets a channel from the channel pool, otherwise, returns the cached * channel. */ - public synchronized Channel get() { + synchronized Channel get() { if (channel == null) { channel = channelPool.getObject(); } @@ -130,87 +296,37 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator { } /** Returns the cached channel to the channel pool. */ - public synchronized void close() { + synchronized void close() { if (channel != null) { channelPool.returnObject(channel); } } } - /** Buffers all writes until the ALTS handshake is complete. */ - @VisibleForTesting - static final class BufferUntilAltsNegotiatedHandler extends AbstractBufferingHandler { - - private final GrpcHttp2ConnectionHandler grpcHandler; - - BufferUntilAltsNegotiatedHandler( - GrpcHttp2ConnectionHandler grpcHandler, ChannelHandler... negotiationhandlers) { - super(negotiationhandlers); - // Save the gRPC handler. The ALTS handler doesn't support buffering before the handshake - // completes, so we wait until the handshake was successful before adding the grpc handler. - this.grpcHandler = grpcHandler; - } - - // TODO: Remove this once https://github.com/grpc/grpc-java/pull/3715 is in. - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - logger.log(Level.FINEST, "Exception while buffering for ALTS Negotiation", cause); - fail(ctx, cause); - ctx.fireExceptionCaught(cause); - } + private static final class AltsHandshakeValidator extends TsiHandshakeHandler.HandshakeValidator { @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - if (logger.isLoggable(Level.FINEST)) { - logger.log(Level.FINEST, "User Event triggered while negotiating ALTS", new Object[] {evt}); + public SecurityDetails validatePeerObject(Object peerObject) throws GeneralSecurityException { + AltsAuthContext altsAuthContext = (AltsAuthContext) peerObject; + // Checks peer Rpc Protocol Versions in the ALTS auth context. Fails the connection if + // Rpc Protocol Versions mismatch. + RpcVersionsCheckResult checkResult = + RpcProtocolVersionsUtil.checkRpcProtocolVersions( + RpcProtocolVersionsUtil.getRpcProtocolVersions(), + altsAuthContext.getPeerRpcVersions()); + if (!checkResult.getResult()) { + String errorMessage = + "Local Rpc Protocol Versions " + + RpcProtocolVersionsUtil.getRpcProtocolVersions() + + " are not compatible with peer Rpc Protocol Versions " + + altsAuthContext.getPeerRpcVersions(); + throw Status.UNAVAILABLE.withDescription(errorMessage).asRuntimeException(); } - if (evt instanceof TsiHandshakeCompletionEvent) { - TsiHandshakeCompletionEvent altsEvt = (TsiHandshakeCompletionEvent) evt; - if (altsEvt.isSuccess()) { - // Add the gRPC handler just before this handler. We only allow the grpcHandler to be - // null to support testing. In production, a grpc handler will always be provided. - if (grpcHandler != null) { - ctx.pipeline().addBefore(ctx.name(), null, grpcHandler); - AltsAuthContext altsContext = (AltsAuthContext) altsEvt.context(); - Preconditions.checkNotNull(altsContext); - // Checks peer Rpc Protocol Versions in the ALTS auth context. Fails the connection if - // Rpc Protocol Versions mismatch. - RpcVersionsCheckResult checkResult = - RpcProtocolVersionsUtil.checkRpcProtocolVersions( - RpcProtocolVersionsUtil.getRpcProtocolVersions(), - altsContext.getPeerRpcVersions()); - if (!checkResult.getResult()) { - String errorMessage = - "Local Rpc Protocol Versions " - + RpcProtocolVersionsUtil.getRpcProtocolVersions().toString() - + "are not compatible with peer Rpc Protocol Versions " - + altsContext.getPeerRpcVersions().toString(); - logger.finest(errorMessage); - fail(ctx, Status.UNAVAILABLE.withDescription(errorMessage).asRuntimeException()); - } - grpcHandler.handleProtocolNegotiationCompleted( - Attributes.newBuilder() - .set(TSI_PEER_KEY, altsEvt.peer()) - .set(ALTS_CONTEXT_KEY, altsContext) - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress()) - .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress()) - .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY) - .build(), - new Security(new OtherSecurity("alts", Any.pack(altsContext.context)))); - } - logger.finest("Flushing ALTS buffered data"); - // Now write any buffered data and remove this handler. - writeBufferedAndRemove(ctx); - } else { - logger.log(Level.FINEST, "ALTS handshake failed", altsEvt.cause()); - fail(ctx, unavailableException("ALTS handshake failed", altsEvt.cause())); - } - } - super.userEventTriggered(ctx, evt); - } - - private static RuntimeException unavailableException(String msg, Throwable cause) { - return Status.UNAVAILABLE.withCause(cause).withDescription(msg).asRuntimeException(); + return new SecurityDetails( + SecurityLevel.PRIVACY_AND_INTEGRITY, + new Security(new OtherSecurity("alts", Any.pack(altsAuthContext.context)))); } } + + private AltsProtocolNegotiator() {} } diff --git a/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java deleted file mode 100644 index 2953afd20a..0000000000 --- a/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2018 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.alts.internal; - -import com.google.common.annotations.VisibleForTesting; -import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; -import io.grpc.internal.GrpcAttributes; -import io.grpc.netty.GrpcHttp2ConnectionHandler; -import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; -import io.grpc.netty.InternalProtocolNegotiators; -import io.netty.channel.ChannelHandler; -import io.netty.handler.ssl.SslContext; -import io.netty.util.AsciiString; - -/** A client-side GPRC {@link ProtocolNegotiator} for Google Default Channel. */ -public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator { - - private final ProtocolNegotiator altsProtocolNegotiator; - private final ProtocolNegotiator tlsProtocolNegotiator; - - /** Constructor for protocol negotiator of Google Default Channel. */ - public GoogleDefaultProtocolNegotiator( - TsiHandshakerFactory altsFactory, LazyChannel lazyHandshakerChannel, SslContext sslContext) { - altsProtocolNegotiator = - AltsProtocolNegotiator.createClientNegotiator(altsFactory, lazyHandshakerChannel); - tlsProtocolNegotiator = InternalProtocolNegotiators.tls(sslContext); - } - - @Override - public AsciiString scheme() { - assert tlsProtocolNegotiator.scheme().equals(altsProtocolNegotiator.scheme()); - return tlsProtocolNegotiator.scheme(); - } - - @VisibleForTesting - GoogleDefaultProtocolNegotiator( - ProtocolNegotiator altsProtocolNegotiator, ProtocolNegotiator tlsProtocolNegotiator) { - this.altsProtocolNegotiator = altsProtocolNegotiator; - this.tlsProtocolNegotiator = tlsProtocolNegotiator; - } - - @Override - public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - if (grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_ADDR_AUTHORITY) != null - || grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_PROVIDED_BACKEND) != null) { - return altsProtocolNegotiator.newHandler(grpcHandler); - } else { - return tlsProtocolNegotiator.newHandler(grpcHandler); - } - } - - @Override - public void close() { - altsProtocolNegotiator.close(); - tlsProtocolNegotiator.close(); - } -} diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java b/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java index 264541223b..55bd61b184 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java @@ -19,9 +19,7 @@ package io.grpc.alts.internal; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import com.google.common.annotations.VisibleForTesting; import io.grpc.alts.internal.TsiFrameProtector.Consumer; -import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelException; import io.netty.channel.ChannelHandlerContext; @@ -33,7 +31,6 @@ import java.net.SocketAddress; import java.security.GeneralSecurityException; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.Future; import java.util.logging.Level; import java.util.logging.Logger; @@ -47,72 +44,33 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann private TsiFrameProtector protector; private PendingWriteQueue pendingUnprotectedWrites; - private State state = State.HANDSHAKE_NOT_FINISHED; - private boolean closeInitiated = false; + private boolean closeInitiated; - @VisibleForTesting - enum State { - HANDSHAKE_NOT_FINISHED, - PROTECTED, - CLOSED, - HANDSHAKE_FAILED + public TsiFrameHandler(TsiFrameProtector protector) { + this.protector = checkNotNull(protector, "protector"); } - public TsiFrameHandler() {} - @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { - logger.finest("TsiFrameHandler added"); super.handlerAdded(ctx); assert pendingUnprotectedWrites == null; pendingUnprotectedWrites = new PendingWriteQueue(checkNotNull(ctx)); } - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object event) throws Exception { - if (logger.isLoggable(Level.FINEST)) { - logger.log(Level.FINEST, "TsiFrameHandler user event triggered", new Object[]{event}); - } - if (event instanceof TsiHandshakeCompletionEvent) { - TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event; - if (tsiEvent.isSuccess()) { - setProtector(tsiEvent.protector()); - } else { - state = State.HANDSHAKE_FAILED; - } - // Ignore errors. Another handler in the pipeline must handle TSI Errors. - } - // Keep propagating the message, as others may want to read it. - super.userEventTriggered(ctx, event); - } - - @VisibleForTesting - void setProtector(TsiFrameProtector protector) { - logger.finest("TsiFrameHandler protector set"); - checkState(this.protector == null); - this.protector = checkNotNull(protector); - this.state = State.PROTECTED; - } - @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - checkState( - state == State.PROTECTED, - "Cannot read frames while the TSI handshake is %s", state); + checkState(protector != null, "decode() called after close()"); protector.unprotect(in, out, ctx.alloc()); } @Override - public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) - throws Exception { - checkState( - state == State.PROTECTED, - "Cannot write frames while the TSI handshake state is %s", state); + @SuppressWarnings("FutureReturnValueIgnored") // for setSuccess + public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) { + checkState(protector != null, "write() called after close()"); ByteBuf msg = (ByteBuf) message; if (!msg.isReadable()) { // Nothing to encode. - @SuppressWarnings("unused") // go/futurereturn-lsc - Future possiblyIgnoredError = promise.setSuccess(); + promise.setSuccess(); return; } @@ -122,30 +80,11 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann @Override public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { - if (!pendingUnprotectedWrites.isEmpty()) { + if (pendingUnprotectedWrites != null && !pendingUnprotectedWrites.isEmpty()) { pendingUnprotectedWrites.removeAndFailAll( new ChannelException("Pending write on removal of TSI handler")); } - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - pendingUnprotectedWrites.removeAndFailAll(cause); - super.exceptionCaught(ctx, cause); - } - - @Override - public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { - ctx.bind(localAddress, promise); - } - - @Override - public void connect( - ChannelHandlerContext ctx, - SocketAddress remoteAddress, - SocketAddress localAddress, - ChannelPromise promise) { - ctx.connect(remoteAddress, localAddress, promise); + destroyProtector(); } @Override @@ -154,6 +93,12 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann ctx.disconnect(promise); } + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) { + doClose(ctx); + ctx.close(promise); + } + private void doClose(ChannelHandlerContext ctx) { if (closeInitiated) { return; @@ -165,51 +110,34 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann flush(ctx); } } catch (GeneralSecurityException e) { - logger.log(Level.FINE, "Ignoring error on flush before close", e); + logger.log(Level.FINE, "Ignored error on flush before close", e); } finally { - state = State.CLOSED; - release(); + pendingUnprotectedWrites = null; + destroyProtector(); } } @Override - public void close(ChannelHandlerContext ctx, ChannelPromise promise) { - doClose(ctx); - ctx.close(promise); - } - - @Override - public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { - doClose(ctx); - ctx.deregister(promise); - } - - @Override - public void read(ChannelHandlerContext ctx) { - ctx.read(); - } - - @Override + @SuppressWarnings("FutureReturnValueIgnored") // for aggregatePromise.doneAllocatingPromises public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException { - if (state == State.CLOSED || state == State.HANDSHAKE_FAILED) { - logger.fine( - String.format("FrameHandler is inactive(%s), channel id: %s", - state, ctx.channel().id().asShortText())); + if (protector == null) { + // TODO(carl-mastrangelo): this should be a checkState. AbstractNettyHandler.exceptionCaught + // transitively calls flush even after closed, for some reason. + pendingUnprotectedWrites.removeAndFailAll( + new ChannelException("Pending write on removal of TSI handler")); + logger.fine("flush() called after close()"); return; } - checkState( - state == State.PROTECTED, "Cannot write frames while the TSI handshake state is %s", state); - final ProtectedPromise aggregatePromise = - new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size()); - - List bufs = new ArrayList<>(pendingUnprotectedWrites.size()); - if (pendingUnprotectedWrites.isEmpty()) { // Return early if there's nothing to write. Otherwise protector.protectFlush() below may // not check for "no-data" and go on writing the 0-byte "data" to the socket with the // protection framing. return; } + final ProtectedPromise aggregatePromise = + new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size()); + List bufs = new ArrayList<>(pendingUnprotectedWrites.size()); + // Drain the unprotected writes. while (!pendingUnprotectedWrites.isEmpty()) { ByteBuf in = (ByteBuf) pendingUnprotectedWrites.current(); @@ -218,25 +146,54 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann aggregatePromise.addUnprotectedPromise(pendingUnprotectedWrites.remove()); } - protector.protectFlush( - bufs, - new Consumer() { - @Override - public void accept(ByteBuf b) { - ctx.writeAndFlush(b, aggregatePromise.newPromise()); - } - }, - ctx.alloc()); + final class ProtectedFrameWriteFlusher implements Consumer { + @Override + public void accept(ByteBuf byteBuf) { + ctx.writeAndFlush(byteBuf, aggregatePromise.newPromise()); + } + } + + protector.protectFlush(bufs, new ProtectedFrameWriteFlusher(), ctx.alloc()); // We're done writing, start the flow of promise events. - @SuppressWarnings("unused") // go/futurereturn-lsc - Future possiblyIgnoredError = aggregatePromise.doneAllocatingPromises(); + aggregatePromise.doneAllocatingPromises(); } - private void release() { + // Only here to fulfill ChannelOutboundHandler + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { + ctx.bind(localAddress, promise); + } + + // Only here to fulfill ChannelOutboundHandler + @Override + public void connect( + ChannelHandlerContext ctx, + SocketAddress remoteAddress, + SocketAddress localAddress, + ChannelPromise promise) { + ctx.connect(remoteAddress, localAddress, promise); + } + + // Only here to fulfill ChannelOutboundHandler + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { + ctx.deregister(promise); + } + + // Only here to fulfill ChannelOutboundHandler + @Override + public void read(ChannelHandlerContext ctx) { + ctx.read(); + } + + private void destroyProtector() { if (protector != null) { - protector.destroy(); - protector = null; + try { + protector.destroy(); + } finally { + protector = null; + } } } } diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java index 98dd1f9090..76d9214053 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java @@ -17,19 +17,25 @@ package io.grpc.alts.internal; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.alts.internal.AltsProtocolNegotiator.AUTH_CONTEXT_KEY; +import static io.grpc.alts.internal.AltsProtocolNegotiator.TSI_PEER_KEY; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.MoreObjects; +import io.grpc.Attributes; +import io.grpc.ChannelLogger.ChannelLogLevel; +import io.grpc.InternalChannelz.Security; +import io.grpc.SecurityLevel; +import io.grpc.alts.internal.TsiHandshakeHandler.HandshakeValidator.SecurityDetails; +import io.grpc.internal.GrpcAttributes; +import io.grpc.netty.InternalProtocolNegotiationEvent; +import io.grpc.netty.InternalProtocolNegotiators; +import io.grpc.netty.ProtocolNegotiationEvent; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; -import io.netty.util.ReferenceCountUtil; import java.security.GeneralSecurityException; import java.util.List; -import java.util.concurrent.Future; -import java.util.logging.Level; -import java.util.logging.Logger; import javax.annotation.Nullable; /** @@ -38,118 +44,63 @@ import javax.annotation.Nullable; */ public final class TsiHandshakeHandler extends ByteToMessageDecoder { - private static final Logger logger = Logger.getLogger(TsiHandshakeHandler.class.getName()); + /** + * Validates a Tsi Peer object. + */ + public abstract static class HandshakeValidator { + + public static final class SecurityDetails { + + private final SecurityLevel securityLevel; + private final Security security; + + /** + * Constructs SecurityDetails. + */ + public SecurityDetails(io.grpc.SecurityLevel securityLevel, @Nullable Security security) { + this.securityLevel = checkNotNull(securityLevel, "securityLevel"); + this.security = security; + } + + public Security getSecurity() { + return security; + } + + public SecurityLevel getSecurityLevel() { + return securityLevel; + } + } + + /** + * Validates a Tsi Peer object. + */ + public abstract SecurityDetails validatePeerObject(Object peerObject) + throws GeneralSecurityException; + } private static final int HANDSHAKE_FRAME_SIZE = 1024; private final NettyTsiHandshaker handshaker; - private boolean started; + private final HandshakeValidator handshakeValidator; + private final ChannelHandler next; + + private ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault(); /** - * This buffer doesn't store any state. We just hold onto it in case we end up allocating a buffer - * that ends up being unused. + * Constructs a TsiHandshakeHandler. */ - private ByteBuf buffer; - - public TsiHandshakeHandler(NettyTsiHandshaker handshaker) { - this.handshaker = checkNotNull(handshaker); - } - - /** - * Event that is fired once the TSI handshake is complete, which may be because it was successful - * or there was an error. - */ - public static final class TsiHandshakeCompletionEvent { - - private final Throwable cause; - private final TsiPeer peer; - private final Object context; - private final TsiFrameProtector protector; - - /** Creates a new event that indicates a successful handshake. */ - @VisibleForTesting - TsiHandshakeCompletionEvent( - TsiFrameProtector protector, TsiPeer peer, @Nullable Object peerObject) { - this.cause = null; - this.peer = checkNotNull(peer); - this.protector = checkNotNull(protector); - this.context = peerObject; - } - - /** Creates a new event that indicates an unsuccessful handshake/. */ - TsiHandshakeCompletionEvent(Throwable cause) { - this.cause = checkNotNull(cause); - this.peer = null; - this.protector = null; - this.context = null; - } - - /** Return {@code true} if the handshake was successful. */ - public boolean isSuccess() { - return cause == null; - } - - /** - * Return the {@link Throwable} if {@link #isSuccess()} returns {@code false} and so the - * handshake failed. - */ - @Nullable - public Throwable cause() { - return cause; - } - - @Nullable - public TsiPeer peer() { - return peer; - } - - @Nullable - public Object context() { - return context; - } - - @Nullable - TsiFrameProtector protector() { - return protector; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("peer", peer) - .add("protector", protector) - .add("context", context) - .add("cause", cause) - .toString(); - } + public TsiHandshakeHandler( + ChannelHandler next, NettyTsiHandshaker handshaker, HandshakeValidator handshakeValidator) { + this.handshaker = checkNotNull(handshaker, "handshaker"); + this.handshakeValidator = checkNotNull(handshakeValidator, "handshakeValidator"); + this.next = checkNotNull(next, "next"); } @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { - logger.finest("TsiHandshakeHandler added"); - maybeStart(ctx); - super.handlerAdded(ctx); - } - - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - logger.finest("TsiHandshakeHandler channel active"); - maybeStart(ctx); - super.channelActive(ctx); - } - - @Override - public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { - logger.finest("TsiHandshakeHandler handler removed"); - close(); - super.handlerRemoved0(ctx); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - logger.log(Level.FINEST, "Exception in TsiHandshakeHandler", cause); - ctx.fireUserEventTriggered(new TsiHandshakeCompletionEvent(cause)); - super.exceptionCaught(ctx, cause); + InternalProtocolNegotiators.negotiationLogger(ctx) + .log(ChannelLogLevel.INFO, "TsiHandshake started"); + sendHandshake(ctx); } @Override @@ -168,71 +119,72 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder { // If the handshake is complete, transition to the framing state. if (!handshaker.isInProgress()) { - TsiFrameProtector protector = null; + TsiPeer peer = handshaker.extractPeer(); + Object authContext = handshaker.extractPeerObject(); + SecurityDetails details = handshakeValidator.validatePeerObject(authContext); + // createFrameProtector must be called last. + TsiFrameProtector protector = handshaker.createFrameProtector(ctx.alloc()); + TsiFrameHandler framer; + boolean success = false; try { - ctx.pipeline().remove(this); - protector = handshaker.createFrameProtector(ctx.alloc()); - TsiHandshakeCompletionEvent evt = new TsiHandshakeCompletionEvent( - protector, - handshaker.extractPeer(), - handshaker.extractPeerObject()); - protector = null; - ctx.fireUserEventTriggered(evt); - // No need to do anything with the in buffer, it will be re added to the pipeline when this - // handler is removed. + framer = new TsiFrameHandler(protector); + // replace the current handler with the framer (instead of adding before) since there may + // be pending data after the handshake frame. The data will need to be decoded before + // being passed to the `next` handler. + ctx.pipeline().replace(ctx.name(), null, framer); + // Once the framer is in the pipeline, it will be cleaned up when the handler is removed. + success = true; } finally { - if (protector != null) { + if (!success && protector != null) { protector.destroy(); } - close(); } + // Add the `next` handler as late as possible, as it will issue writes on being added. + ctx.pipeline().addAfter(ctx.pipeline().context(framer).name(), null, next); + fireProtocolNegotiationEvent(ctx, peer, authContext, details); } } - private void maybeStart(ChannelHandlerContext ctx) { - if (!started && ctx.channel().isActive()) { - started = true; - sendHandshake(ctx); + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof ProtocolNegotiationEvent) { + pne = (ProtocolNegotiationEvent) evt; + } else { + super.userEventTriggered(ctx, evt); } } + private void fireProtocolNegotiationEvent( + ChannelHandlerContext ctx, TsiPeer peer, Object authContext, SecurityDetails details) { + InternalProtocolNegotiators.negotiationLogger(ctx) + .log(ChannelLogLevel.INFO, "TsiHandshake finished"); + ProtocolNegotiationEvent localPne = pne; + Attributes.Builder attrs = InternalProtocolNegotiationEvent.getAttributes(localPne).toBuilder() + .set(TSI_PEER_KEY, peer) + .set(AUTH_CONTEXT_KEY, authContext) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, details.getSecurityLevel()); + localPne = InternalProtocolNegotiationEvent.withAttributes(localPne, attrs.build()); + localPne = InternalProtocolNegotiationEvent.withSecurity(localPne, details.getSecurity()); + ctx.fireUserEventTriggered(localPne); + } + /** Sends as many bytes as are available from the handshaker to the remote peer. */ - private void sendHandshake(ChannelHandlerContext ctx) { - boolean needToFlush = false; - - // Iterate until there is nothing left to write. + @SuppressWarnings("FutureReturnValueIgnored") // for addListener + private void sendHandshake(ChannelHandlerContext ctx) throws GeneralSecurityException { while (true) { - buffer = getOrCreateBuffer(ctx.alloc()); + boolean written = false; + ByteBuf buf = ctx.alloc().buffer(HANDSHAKE_FRAME_SIZE).retain(); // refcnt = 2 try { - handshaker.getBytesToSendToPeer(buffer); - } catch (GeneralSecurityException e) { - throw new RuntimeException(e); + handshaker.getBytesToSendToPeer(buf); + if (buf.isReadable()) { + ctx.writeAndFlush(buf).addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + written = true; + } else { + break; + } + } finally { + buf.release(written ? 1 : 2); } - if (!buffer.isReadable()) { - break; - } - - needToFlush = true; - @SuppressWarnings("unused") // go/futurereturn-lsc - Future possiblyIgnoredError = ctx.write(buffer); - buffer = null; } - - // If something was written, flush. - if (needToFlush) { - ctx.flush(); - } - } - - private ByteBuf getOrCreateBuffer(ByteBufAllocator alloc) { - if (buffer == null) { - buffer = alloc.buffer(HANDSHAKE_FRAME_SIZE); - } - return buffer; - } - - private void close() { - ReferenceCountUtil.safeRelease(buffer); - buffer = null; } } diff --git a/alts/src/test/java/io/grpc/alts/AltsChannelBuilderTest.java b/alts/src/test/java/io/grpc/alts/AltsChannelBuilderTest.java index da202c54a5..a44de19b91 100644 --- a/alts/src/test/java/io/grpc/alts/AltsChannelBuilderTest.java +++ b/alts/src/test/java/io/grpc/alts/AltsChannelBuilderTest.java @@ -18,7 +18,6 @@ package io.grpc.alts; import static com.google.common.truth.Truth.assertThat; -import io.grpc.alts.internal.AltsProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import org.junit.Test; import org.junit.runner.RunWith; @@ -28,17 +27,14 @@ import org.junit.runners.JUnit4; public final class AltsChannelBuilderTest { @Test - public void buildsNettyChannel() throws Exception { + public void buildsNettyChannel() { AltsChannelBuilder builder = AltsChannelBuilder.forTarget("localhost:8080").enableUntrustedAltsForTesting(); ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest(); - assertThat(protocolNegotiator).isNull(); - - builder.build(); - - protocolNegotiator = builder.getProtocolNegotiatorForTest(); assertThat(protocolNegotiator).isNotNull(); - assertThat(protocolNegotiator).isInstanceOf(AltsProtocolNegotiator.class); + // Avoids exposing this class + assertThat(protocolNegotiator.getClass().getSimpleName()) + .isEqualTo("ClientAltsProtocolNegotiator"); } } diff --git a/alts/src/test/java/io/grpc/alts/ComputeEngineChannelBuilderTest.java b/alts/src/test/java/io/grpc/alts/ComputeEngineChannelBuilderTest.java index 976ba611d1..f7752a2873 100644 --- a/alts/src/test/java/io/grpc/alts/ComputeEngineChannelBuilderTest.java +++ b/alts/src/test/java/io/grpc/alts/ComputeEngineChannelBuilderTest.java @@ -18,7 +18,6 @@ package io.grpc.alts; import static com.google.common.truth.Truth.assertThat; -import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import org.junit.Test; import org.junit.runner.RunWith; @@ -33,6 +32,7 @@ public final class ComputeEngineChannelBuilderTest { builder.build(); ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest(); - assertThat(protocolNegotiator).isInstanceOf(GoogleDefaultProtocolNegotiator.class); + assertThat(protocolNegotiator.getClass().getSimpleName()) + .isEqualTo("GoogleDefaultProtocolNegotiator"); } } diff --git a/alts/src/test/java/io/grpc/alts/GoogleDefaultChannelBuilderTest.java b/alts/src/test/java/io/grpc/alts/GoogleDefaultChannelBuilderTest.java index 9c84ee24cb..4336aff745 100644 --- a/alts/src/test/java/io/grpc/alts/GoogleDefaultChannelBuilderTest.java +++ b/alts/src/test/java/io/grpc/alts/GoogleDefaultChannelBuilderTest.java @@ -18,7 +18,6 @@ package io.grpc.alts; import static com.google.common.truth.Truth.assertThat; -import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import org.junit.Test; import org.junit.runner.RunWith; @@ -33,6 +32,7 @@ public final class GoogleDefaultChannelBuilderTest { builder.build(); ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest(); - assertThat(protocolNegotiator).isInstanceOf(GoogleDefaultProtocolNegotiator.class); + assertThat(protocolNegotiator.getClass().getSimpleName()) + .isEqualTo("GoogleDefaultProtocolNegotiator"); } } diff --git a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java index 7dbae59cf3..7d8e637ba8 100644 --- a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java @@ -27,6 +27,7 @@ import io.grpc.Attributes; import io.grpc.Channel; import io.grpc.Grpc; import io.grpc.InternalChannelz; +import io.grpc.InternalChannelz.Security; import io.grpc.ManagedChannel; import io.grpc.SecurityLevel; import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; @@ -80,6 +81,7 @@ import org.junit.runners.JUnit4; /** Tests for {@link AltsProtocolNegotiator}. */ @RunWith(JUnit4.class) +@SuppressWarnings("FutureReturnValueIgnored") public class AltsProtocolNegotiatorTest { private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler(); @@ -90,7 +92,6 @@ public class AltsProtocolNegotiatorTest { private EmbeddedChannel channel; private Throwable caughtException; - private volatile TsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent; private ChannelHandler handler; private TsiPeer mockedTsiPeer = new TsiPeer(Collections.>emptyList()); @@ -102,12 +103,12 @@ public class AltsProtocolNegotiatorTest { private final TsiHandshaker mockHandshaker = new DelegatingTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer()) { @Override - public TsiPeer extractPeer() throws GeneralSecurityException { + public TsiPeer extractPeer() { return mockedTsiPeer; } @Override - public Object extractPeerObject() throws GeneralSecurityException { + public Object extractPeerObject() { return mockedAltsContext; } }; @@ -115,24 +116,13 @@ public class AltsProtocolNegotiatorTest { @Before public void setup() throws Exception { - ChannelHandler userEventHandler = - new ChannelDuplexHandler() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - if (evt instanceof TsiHandshakeHandler.TsiHandshakeCompletionEvent) { - tsiEvent = (TsiHandshakeHandler.TsiHandshakeCompletionEvent) evt; - } else { - super.userEventTriggered(ctx, evt); - } - } - }; - ChannelHandler uncaughtExceptionHandler = new ChannelDuplexHandler() { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { caughtException = cause; super.exceptionCaught(ctx, cause); + ctx.close(); } }; @@ -157,9 +147,9 @@ public class AltsProtocolNegotiatorTest { ObjectPool fakeChannelPool = new FixedObjectPool(fakeChannel); LazyChannel lazyFakeChannel = new LazyChannel(fakeChannelPool); handler = - AltsProtocolNegotiator.createServerNegotiator(handshakerFactory, lazyFakeChannel) + new AltsProtocolNegotiator.ServerAltsProtocolNegotiator(handshakerFactory, lazyFakeChannel) .newHandler(grpcHandler); - channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler); + channel = new EmbeddedChannel(uncaughtExceptionHandler, handler); } @After @@ -182,6 +172,8 @@ public class AltsProtocolNegotiatorTest { @Test @SuppressWarnings("unchecked") // List cast public void protectShouldRoundtrip() throws Exception { + doHandshake(); + // Write the message 1 character at a time. The message should be buffered // and not interfere with the handshake. final AtomicInteger writeCount = new AtomicInteger(); @@ -204,10 +196,6 @@ public class AltsProtocolNegotiatorTest { } channel.flush(); - // Now do the handshake. The buffered message will automatically be protected - // and sent. - doHandshake(); - // Capture the protected data written to the wire. assertEquals(1, channel.outboundMessages().size()); ByteBuf protectedData = channel.readOutbound(); @@ -351,7 +339,7 @@ public class AltsProtocolNegotiatorTest { doHandshake(); assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.TSI_PEER_KEY)).isEqualTo(mockedTsiPeer); - assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.ALTS_CONTEXT_KEY)) + assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY)) .isEqualTo(mockedAltsContext); assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString()) .isEqualTo("embedded"); @@ -388,7 +376,7 @@ public class AltsProtocolNegotiatorTest { if (caughtException != null) { throw new RuntimeException(caughtException); } - assertNotNull(tsiEvent); + assertNotNull(grpcHandler.attrs); } private CapturingGrpcHttp2ConnectionHandler capturingGrpcHandler() { @@ -408,6 +396,7 @@ public class AltsProtocolNegotiatorTest { private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler { private Attributes attrs; + private Security securityInfo; private CapturingGrpcHttp2ConnectionHandler( Http2ConnectionDecoder decoder, @@ -422,6 +411,7 @@ public class AltsProtocolNegotiatorTest { // If we are added to the pipeline, we need to remove ourselves. The HTTP2 handler channel.pipeline().remove(this); this.attrs = attrs; + this.securityInfo = securityInfo; } } diff --git a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java index 50b674f829..a9784b2a21 100644 --- a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java @@ -16,16 +16,27 @@ package io.grpc.alts.internal; +import static com.google.common.truth.Truth.assertThat; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import io.grpc.Attributes; +import io.grpc.Channel; +import io.grpc.ManagedChannel; +import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.internal.GrpcAttributes; +import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.ssl.SslContext; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -33,16 +44,36 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public final class GoogleDefaultProtocolNegotiatorTest { - private ProtocolNegotiator altsProtocolNegotiator; - private ProtocolNegotiator tlsProtocolNegotiator; - private GoogleDefaultProtocolNegotiator googleProtocolNegotiator; + private ProtocolNegotiator googleProtocolNegotiator; + + private final ObjectPool handshakerChannelPool = new ObjectPool() { + + @Override + public Channel getObject() { + return InProcessChannelBuilder.forName("test").build(); + } + + @Override + public Channel returnObject(Object object) { + ((ManagedChannel) object).shutdownNow(); + return null; + } + }; @Before - public void setUp() { - altsProtocolNegotiator = mock(ProtocolNegotiator.class); - tlsProtocolNegotiator = mock(ProtocolNegotiator.class); - googleProtocolNegotiator = - new GoogleDefaultProtocolNegotiator(altsProtocolNegotiator, tlsProtocolNegotiator); + public void setUp() throws Exception { + SslContext sslContext = GrpcSslContexts.forClient().build(); + + googleProtocolNegotiator = new AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory( + ImmutableList.of(), + handshakerChannelPool, + sslContext) + .buildProtocolNegotiator(); + } + + @After + public void tearDown() { + googleProtocolNegotiator.close(); } @Test @@ -51,9 +82,24 @@ public final class GoogleDefaultProtocolNegotiatorTest { Attributes.newBuilder().set(GrpcAttributes.ATTR_LB_PROVIDED_BACKEND, true).build(); GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); when(mockHandler.getEagAttributes()).thenReturn(eagAttributes); - googleProtocolNegotiator.newHandler(mockHandler); - verify(altsProtocolNegotiator, times(1)).newHandler(mockHandler); - verify(tlsProtocolNegotiator, never()).newHandler(mockHandler); + + final AtomicReference failure = new AtomicReference<>(); + ChannelHandler exceptionCaught = new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + failure.set(cause); + super.exceptionCaught(ctx, cause); + } + }; + ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler); + EmbeddedChannel chan = new EmbeddedChannel(exceptionCaught); + // Add the negotiator handler last, but to the front. Putting this in ctor above would make it + // throw early. + chan.pipeline().addFirst(h); + + // Check that the message complained about the ALTS code, rather than SSL. ALTS throws on + // being added, so it's hard to catch it at the right time to make this assertion. + assertThat(failure.get()).hasMessageThat().contains("TsiHandshakeHandler"); } @Test @@ -61,8 +107,11 @@ public final class GoogleDefaultProtocolNegotiatorTest { Attributes eagAttributes = Attributes.EMPTY; GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); when(mockHandler.getEagAttributes()).thenReturn(eagAttributes); - googleProtocolNegotiator.newHandler(mockHandler); - verify(altsProtocolNegotiator, never()).newHandler(mockHandler); - verify(tlsProtocolNegotiator, times(1)).newHandler(mockHandler); + when(mockHandler.getAuthority()).thenReturn("authority"); + + ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler); + EmbeddedChannel chan = new EmbeddedChannel(h); + + assertThat(chan.pipeline().first().getClass().getSimpleName()).isEqualTo("SslHandler"); } } diff --git a/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java b/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java index efc1f57ba3..df2314b739 100644 --- a/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java @@ -18,18 +18,13 @@ package io.grpc.alts.internal; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static org.junit.Assert.fail; -import io.grpc.alts.internal.TsiFrameHandler.State; -import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent; -import io.grpc.alts.internal.TsiPeer.Property; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.CharsetUtil; import java.security.GeneralSecurityException; -import java.util.ArrayList; import java.util.List; import org.junit.Rule; import org.junit.Test; @@ -46,32 +41,17 @@ public class TsiFrameHandlerTest { @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(5)); - private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler(); + private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler(new IdentityFrameProtector()); private final EmbeddedChannel channel = new EmbeddedChannel(tsiFrameHandler); - @Test - public void writeAndFlush_beforeHandshakeEventShouldBeIgnored() { - ByteBuf msg = Unpooled.copiedBuffer("message before handshake finished", CharsetUtil.UTF_8); - - channel.writeAndFlush(msg); - - assertThat(channel.outboundMessages()).isEmpty(); - try { - channel.checkException(); - fail(); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains(State.HANDSHAKE_NOT_FINISHED.name()); - } - } - @Test public void writeAndFlush_handshakeSucceed() throws InterruptedException { - channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent()); ByteBuf msg = Unpooled.copiedBuffer("message after handshake finished", CharsetUtil.UTF_8); channel.writeAndFlush(msg); + Object actual = channel.readOutbound(); - assertThat((Object) channel.readOutbound()).isEqualTo(msg); + assertThat(actual).isEqualTo(msg); channel.close().sync(); channel.checkException(); } @@ -92,40 +72,20 @@ public class TsiFrameHandlerTest { } } - @Test - public void writeAndFlush_handshakeFailed() throws InterruptedException { - channel.pipeline().fireUserEventTriggered(new TsiHandshakeCompletionEvent(new Exception())); - ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8); - - channel.writeAndFlush(msg); - - assertThat(channel.outboundMessages()).isEmpty(); - channel.close().sync(); - channel.checkException(); - } - @Test public void close_shouldFlushRemainingMessage() throws InterruptedException { - channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent()); - ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8); channel.write(msg); assertThat(channel.outboundMessages()).isEmpty(); channel.close().sync(); + Object actual = channel.readOutbound(); - assertWithMessage("pending write should be flushed on close") - .that((Object) channel.readOutbound()).isEqualTo(msg); + assertWithMessage("pending write should be flushed on close").that(actual).isEqualTo(msg); channel.checkException(); } - private TsiHandshakeCompletionEvent getHandshakeSuccessEvent() { - TsiFrameProtector protector = new IdentityFrameProtector(); - TsiPeer peer = new TsiPeer(new ArrayList>()); - return new TsiHandshakeCompletionEvent(protector, peer, new Object()); - } - private static final class IdentityFrameProtector implements TsiFrameProtector { @Override diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index b85ce315c7..60f2a3d49f 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -17,6 +17,7 @@ package io.grpc.netty; import io.grpc.ChannelLogger; +import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler; import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler; import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler; import io.netty.channel.ChannelHandler; @@ -93,4 +94,9 @@ public final class InternalProtocolNegotiators { public static ChannelHandler grpcNegotiationHandler(GrpcHttp2ConnectionHandler next) { return new GrpcNegotiationHandler(next); } + + public static ChannelHandler clientTlsHandler( + ChannelHandler next, SslContext sslContext, String authority) { + return new ClientTlsHandler(next, sslContext, authority); + } }