alts: use new ProtocolNegotiator style for ALTS

This change does a few core things, which result in a lot of churn in other parts.  It's not as bad as it seems.

Core things:

1.  AltsProtocolNegotiator is now a shim class, same as ProtocolNegotiators
2.  The protocol negotiators are now in the new style, where there is at most 1 negotiation handler in the pipe at a time.
3.  TsiHandshakeHandler is rewritten with respect to the above.  All errors and buffering are handled by the WBAEH.
4.  TsiFrameHandler is only installed once the negotiation is successful, eliminating the state handling.


The churn in mainly in GoogleDefaultChannel and the GCE channel, which now reuse the *handlers* rather than the negotiators.  This makes it significantly easier to reason about the pipeline state.  The tests are also a source of churn, which no long need to check for most buffering and error conditions.
This commit is contained in:
Carl Mastrangelo 2019-05-22 16:33:07 -07:00 committed by GitHub
parent f8fffeff12
commit 7834a50525
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 585 additions and 716 deletions

View File

@ -28,18 +28,12 @@ import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.alts.internal.AltsClientOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.AltsTsiHandshaker;
import io.grpc.alts.internal.HandshakerServiceGrpc;
import io.grpc.alts.internal.RpcProtocolVersionsUtil;
import io.grpc.alts.internal.TsiHandshaker;
import io.grpc.alts.internal.TsiHandshakerFactory;
import io.grpc.alts.internal.AltsProtocolNegotiator.ClientAltsProtocolNegotiatorFactory;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ObjectPool;
import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.grpc.netty.NettyChannelBuilder;
import java.util.logging.Level;
import java.util.logging.Logger;
@ -60,8 +54,6 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL);
private boolean enableUntrustedAlts;
private AltsProtocolNegotiator negotiatorForTest;
/** "Overrides" the static method in {@link ManagedChannelBuilder}. */
public static final AltsChannelBuilder forTarget(String target) {
return new AltsChannelBuilder(target);
@ -74,8 +66,6 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
private AltsChannelBuilder(String target) {
delegate = NettyChannelBuilder.forTarget(target);
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
delegate(), new ProtocolNegotiatorFactory());
}
/**
@ -125,41 +115,20 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
delegate().intercept(new FailingClientInterceptor(status));
}
}
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
delegate(),
new ClientAltsProtocolNegotiatorFactory(
targetServiceAccountsBuilder.build(), handshakerChannelPool));
return delegate().build();
}
@VisibleForTesting
@Nullable
AltsProtocolNegotiator getProtocolNegotiatorForTest() {
return negotiatorForTest;
}
private final class ProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
@Override
public AltsProtocolNegotiator buildProtocolNegotiator() {
final ImmutableList<String> targetServiceAccounts = targetServiceAccountsBuilder.build();
final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker(String authority) {
AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.setTargetServiceAccounts(targetServiceAccounts)
.setTargetName(authority)
.build();
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
}
};
return negotiatorForTest =
AltsProtocolNegotiator.createClientNegotiator(
altsHandshakerFactory, lazyHandshakerChannel);
}
ProtocolNegotiator getProtocolNegotiatorForTest() {
return new ClientAltsProtocolNegotiatorFactory(
targetServiceAccountsBuilder.build(), handshakerChannelPool)
.buildProtocolNegotiator();
}
/** An implementation of {@link ClientInterceptor} that fails each call. */

View File

@ -33,14 +33,7 @@ import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer.Factory;
import io.grpc.ServerTransportFilter;
import io.grpc.Status;
import io.grpc.alts.internal.AltsHandshakerOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.AltsTsiHandshaker;
import io.grpc.alts.internal.HandshakerServiceGrpc;
import io.grpc.alts.internal.RpcProtocolVersionsUtil;
import io.grpc.alts.internal.TsiHandshaker;
import io.grpc.alts.internal.TsiHandshakerFactory;
import io.grpc.internal.ObjectPool;
import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.NettyServerBuilder;
@ -192,18 +185,8 @@ public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
}
}
final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
delegate.protocolNegotiator(
AltsProtocolNegotiator.createServerNegotiator(
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker(String authority) {
return AltsTsiHandshaker.newServer(
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()),
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()));
}
},
lazyHandshakerChannel));
AltsProtocolNegotiator.serverAltsProtocolNegotiator(handshakerChannelPool));
return delegate.build();
}

View File

@ -18,23 +18,18 @@ package io.grpc.alts;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.grpc.CallCredentials;
import io.grpc.ForwardingChannelBuilder;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Status;
import io.grpc.alts.internal.AltsClientOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.AltsTsiHandshaker;
import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator;
import io.grpc.alts.internal.HandshakerServiceGrpc;
import io.grpc.alts.internal.RpcProtocolVersionsUtil;
import io.grpc.alts.internal.TsiHandshaker;
import io.grpc.alts.internal.TsiHandshakerFactory;
import io.grpc.alts.internal.AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.handler.ssl.SslContext;
import javax.net.ssl.SSLException;
@ -47,12 +42,21 @@ public final class ComputeEngineChannelBuilder
extends ForwardingChannelBuilder<GoogleDefaultChannelBuilder> {
private final NettyChannelBuilder delegate;
private GoogleDefaultProtocolNegotiator negotiatorForTest;
private ComputeEngineChannelBuilder(String target) {
delegate = NettyChannelBuilder.forTarget(target);
SslContext sslContext;
try {
sslContext = GrpcSslContexts.forClient().build();
} catch (SSLException e) {
throw new RuntimeException(e);
}
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
delegate(), new ProtocolNegotiatorFactory());
delegate(),
new GoogleDefaultProtocolNegotiatorFactory(
/* targetServiceAccounts= */ ImmutableList.<String>of(),
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL),
sslContext));
CallCredentials credentials = MoreCallCredentials.from(ComputeEngineCredentials.create());
Status status = Status.OK;
if (!CheckGcpEnvironment.isOnGcp()) {
@ -79,40 +83,17 @@ public final class ComputeEngineChannelBuilder
}
@VisibleForTesting
GoogleDefaultProtocolNegotiator getProtocolNegotiatorForTest() {
return negotiatorForTest;
}
private final class ProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
@Override
public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() {
final LazyChannel lazyHandshakerChannel =
new LazyChannel(
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL));
TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker(String authority) {
AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.setTargetName(authority)
.build();
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
}
};
SslContext sslContext;
try {
sslContext = GrpcSslContexts.forClient().build();
} catch (SSLException ex) {
throw new RuntimeException(ex);
}
return negotiatorForTest =
new GoogleDefaultProtocolNegotiator(
altsHandshakerFactory, lazyHandshakerChannel, sslContext);
ProtocolNegotiator getProtocolNegotiatorForTest() {
SslContext sslContext;
try {
sslContext = GrpcSslContexts.forClient().build();
} catch (SSLException e) {
throw new RuntimeException(e);
}
return new GoogleDefaultProtocolNegotiatorFactory(
/* targetServiceAccounts= */ ImmutableList.<String>of(),
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL),
sslContext)
.buildProtocolNegotiator();
}
}

View File

@ -18,23 +18,18 @@ package io.grpc.alts;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.grpc.CallCredentials;
import io.grpc.ForwardingChannelBuilder;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Status;
import io.grpc.alts.internal.AltsClientOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.AltsTsiHandshaker;
import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator;
import io.grpc.alts.internal.HandshakerServiceGrpc;
import io.grpc.alts.internal.RpcProtocolVersionsUtil;
import io.grpc.alts.internal.TsiHandshaker;
import io.grpc.alts.internal.TsiHandshakerFactory;
import io.grpc.alts.internal.AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.handler.ssl.SslContext;
import java.io.IOException;
@ -49,12 +44,21 @@ public final class GoogleDefaultChannelBuilder
extends ForwardingChannelBuilder<GoogleDefaultChannelBuilder> {
private final NettyChannelBuilder delegate;
private GoogleDefaultProtocolNegotiator negotiatorForTest;
private GoogleDefaultChannelBuilder(String target) {
delegate = NettyChannelBuilder.forTarget(target);
SslContext sslContext;
try {
sslContext = GrpcSslContexts.forClient().build();
} catch (SSLException e) {
throw new RuntimeException(e);
}
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
delegate(), new ProtocolNegotiatorFactory());
delegate(),
new GoogleDefaultProtocolNegotiatorFactory(
/* targetServiceAccounts= */ ImmutableList.<String>of(),
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL),
sslContext));
@Nullable CallCredentials credentials = null;
Status status = Status.OK;
try {
@ -84,40 +88,17 @@ public final class GoogleDefaultChannelBuilder
}
@VisibleForTesting
GoogleDefaultProtocolNegotiator getProtocolNegotiatorForTest() {
return negotiatorForTest;
}
private final class ProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
@Override
public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() {
final LazyChannel lazyHandshakerChannel =
new LazyChannel(
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL));
TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker(String authority) {
AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.setTargetName(authority)
.build();
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
}
};
SslContext sslContext;
try {
sslContext = GrpcSslContexts.forClient().build();
} catch (SSLException ex) {
throw new RuntimeException(ex);
}
return negotiatorForTest =
new GoogleDefaultProtocolNegotiator(
altsHandshakerFactory, lazyHandshakerChannel, sslContext);
ProtocolNegotiator getProtocolNegotiatorForTest() {
SslContext sslContext;
try {
sslContext = GrpcSslContexts.forClient().build();
} catch (SSLException e) {
throw new RuntimeException(e);
}
return new GoogleDefaultProtocolNegotiatorFactory(
/* targetServiceAccounts= */ ImmutableList.<String>of(),
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL),
sslContext)
.buildProtocolNegotiator();
}
}

View File

@ -16,8 +16,10 @@
package io.grpc.alts.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.Any;
import io.grpc.Attributes;
import io.grpc.Channel;
@ -27,102 +29,266 @@ import io.grpc.InternalChannelz.Security;
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.alts.internal.RpcProtocolVersionsUtil.RpcVersionsCheckResult;
import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiators.AbstractBufferingHandler;
import io.grpc.netty.InternalProtocolNegotiators;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString;
import java.util.logging.Level;
import java.security.GeneralSecurityException;
import java.util.List;
import java.util.logging.Logger;
import javax.annotation.Nullable;
/**
* A GRPC {@link ProtocolNegotiator} for ALTS. This class creates a Netty handler that provides ALTS
* A gRPC {@link ProtocolNegotiator} for ALTS. This class creates a Netty handler that provides ALTS
* security on the wire, similar to Netty's {@code SslHandler}.
*/
public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
// TODO(carl-mastrangelo): rename this AltsProtocolNegotiators.
public final class AltsProtocolNegotiator {
private static final Logger logger = Logger.getLogger(AltsProtocolNegotiator.class.getName());
@Grpc.TransportAttr
public static final Attributes.Key<TsiPeer> TSI_PEER_KEY = Attributes.Key.create("TSI_PEER");
@Grpc.TransportAttr
public static final Attributes.Key<AltsAuthContext> ALTS_CONTEXT_KEY =
Attributes.Key.create("ALTS_CONTEXT_KEY");
public static final Attributes.Key<Object> AUTH_CONTEXT_KEY =
Attributes.Key.create("AUTH_CONTEXT_KEY");
private static final AsciiString scheme = AsciiString.of("https");
private static final AsciiString SCHEME = AsciiString.of("https");
/** Creates a negotiator used for ALTS client. */
public static AltsProtocolNegotiator createClientNegotiator(
final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) {
final class ClientAltsProtocolNegotiator extends AltsProtocolNegotiator {
/**
* ClientAltsProtocolNegotiatorFactory is a factory for doing client side negotiation of an ALTS
* channel.
*/
public static final class ClientAltsProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
private final ImmutableList<String> targetServiceAccounts;
private final LazyChannel lazyHandshakerChannel;
public ClientAltsProtocolNegotiatorFactory(
List<String> targetServiceAccounts,
ObjectPool<Channel> handshakerChannelPool) {
this.targetServiceAccounts = ImmutableList.copyOf(targetServiceAccounts);
this.lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
}
@Override
public ProtocolNegotiator buildProtocolNegotiator() {
return new ClientAltsProtocolNegotiator(
new ClientTsiHandshakerFactory(targetServiceAccounts, lazyHandshakerChannel),
lazyHandshakerChannel);
}
}
@VisibleForTesting
private static final class ClientAltsProtocolNegotiator implements ProtocolNegotiator {
private final TsiHandshakerFactory handshakerFactory;
private final LazyChannel lazyHandshakerChannel;
ClientAltsProtocolNegotiator(
TsiHandshakerFactory handshakerFactory, LazyChannel lazyHandshakerChannel) {
this.handshakerFactory = checkNotNull(handshakerFactory, "handshakerFactory");
this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel");
}
@Override
public AsciiString scheme() {
return SCHEME;
}
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority());
NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker);
ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler);
ChannelHandler thh =
new TsiHandshakeHandler(gnh, nettyHandshaker, new AltsHandshakeValidator());
ChannelHandler wuah = InternalProtocolNegotiators.waitUntilActiveHandler(thh);
return wuah;
}
@Override
public void close() {
lazyHandshakerChannel.close();
}
}
/**
* Creates a protocol negotiator for ALTS on the server side.
*/
public static ProtocolNegotiator serverAltsProtocolNegotiator(
ObjectPool<Channel> handshakerChannelPool) {
final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
final class ServerTsiHandshakerFactory implements TsiHandshakerFactory {
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
public TsiHandshaker newHandshaker(@Nullable String authority) {
assert authority == null;
return AltsTsiHandshaker.newServer(
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()),
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()));
}
}
return new ServerAltsProtocolNegotiator(
new ServerTsiHandshakerFactory(), lazyHandshakerChannel);
}
@VisibleForTesting
static final class ServerAltsProtocolNegotiator implements ProtocolNegotiator {
private final TsiHandshakerFactory handshakerFactory;
private final LazyChannel lazyHandshakerChannel;
@VisibleForTesting
ServerAltsProtocolNegotiator(
TsiHandshakerFactory handshakerFactory, LazyChannel lazyHandshakerChannel) {
this.handshakerFactory = checkNotNull(handshakerFactory, "handshakerFactory");
this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel");
}
@Override
public AsciiString scheme() {
return SCHEME;
}
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
TsiHandshaker handshaker = handshakerFactory.newHandshaker(/* authority= */ null);
NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker);
ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler);
ChannelHandler thh =
new TsiHandshakeHandler(gnh, nettyHandshaker, new AltsHandshakeValidator());
ChannelHandler wuah = InternalProtocolNegotiators.waitUntilActiveHandler(thh);
return wuah;
}
@Override
public void close() {
logger.finest("ALTS Server ProtocolNegotiator Closed");
lazyHandshakerChannel.close();
}
}
/**
* A Protocol Negotiator factory which can switch between ALTS and TLS based on EAG Attrs.
*/
public static final class GoogleDefaultProtocolNegotiatorFactory
implements ProtocolNegotiatorFactory {
private final ImmutableList<String> targetServiceAccounts;
private final LazyChannel lazyHandshakerChannel;
private final SslContext sslContext;
/**
* Creates Negotiator Factory, which will either use the targetServiceAccounts and
* handshakerChannelPool, or the sslContext.
*/
public GoogleDefaultProtocolNegotiatorFactory(
List<String> targetServiceAccounts,
ObjectPool<Channel> handshakerChannelPool,
SslContext sslContext) {
this.targetServiceAccounts = ImmutableList.copyOf(targetServiceAccounts);
this.lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
this.sslContext = checkNotNull(sslContext, "sslContext");
}
@Override
public ProtocolNegotiator buildProtocolNegotiator() {
return new GoogleDefaultProtocolNegotiator(
new ClientTsiHandshakerFactory(targetServiceAccounts, lazyHandshakerChannel),
lazyHandshakerChannel,
sslContext);
}
}
private static final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator {
private final TsiHandshakerFactory handshakerFactory;
private final LazyChannel lazyHandshakerChannel;
private final SslContext sslContext;
GoogleDefaultProtocolNegotiator(
TsiHandshakerFactory handshakerFactory,
LazyChannel lazyHandshakerChannel,
SslContext sslContext) {
this.handshakerFactory = checkNotNull(handshakerFactory, "handshakerFactory");
this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel");
this.sslContext = checkNotNull(sslContext, "checkNotNull");
}
@Override
public AsciiString scheme() {
return SCHEME;
}
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler);
ChannelHandler securityHandler;
if (grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_ADDR_AUTHORITY) != null
|| grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_PROVIDED_BACKEND) != null) {
TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority());
return new BufferUntilAltsNegotiatedHandler(
grpcHandler,
new TsiHandshakeHandler(new NettyTsiHandshaker(handshaker)),
new TsiFrameHandler());
}
@Override
public void close() {
logger.finest("ALTS Client ProtocolNegotiator Closed");
lazyHandshakerChannel.close();
NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker);
securityHandler =
new TsiHandshakeHandler(gnh, nettyHandshaker, new AltsHandshakeValidator());
} else {
securityHandler = InternalProtocolNegotiators.clientTlsHandler(
gnh, sslContext, grpcHandler.getAuthority());
}
ChannelHandler wuah = InternalProtocolNegotiators.waitUntilActiveHandler(securityHandler);
return wuah;
}
return new ClientAltsProtocolNegotiator();
@Override
public void close() {
logger.finest("ALTS Server ProtocolNegotiator Closed");
lazyHandshakerChannel.close();
}
}
@Override
public final AsciiString scheme() {
return scheme;
}
private static final class ClientTsiHandshakerFactory implements TsiHandshakerFactory {
/** Creates a negotiator used for ALTS server. */
public static AltsProtocolNegotiator createServerNegotiator(
final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) {
final class ServerAltsProtocolNegotiator extends AltsProtocolNegotiator {
private final ImmutableList<String> targetServiceAccounts;
private final LazyChannel lazyHandshakerChannel;
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
TsiHandshaker handshaker = handshakerFactory.newHandshaker(/*authority=*/ null);
return new BufferUntilAltsNegotiatedHandler(
grpcHandler,
new TsiHandshakeHandler(new NettyTsiHandshaker(handshaker)),
new TsiFrameHandler());
}
@Override
public void close() {
logger.finest("ALTS Server ProtocolNegotiator Closed");
lazyHandshakerChannel.close();
}
ClientTsiHandshakerFactory(
ImmutableList<String> targetServiceAccounts, LazyChannel lazyHandshakerChannel) {
this.targetServiceAccounts = checkNotNull(targetServiceAccounts, "targetServiceAccounts");
this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel");
}
return new ServerAltsProtocolNegotiator();
@Override
public TsiHandshaker newHandshaker(@Nullable String authority) {
AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.setTargetServiceAccounts(targetServiceAccounts)
.setTargetName(authority)
.build();
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
}
}
/** Channel created from a channel pool lazily. */
public static class LazyChannel {
@VisibleForTesting
static final class LazyChannel {
private final ObjectPool<Channel> channelPool;
private Channel channel;
public LazyChannel(ObjectPool<Channel> channelPool) {
this.channelPool = channelPool;
@VisibleForTesting
LazyChannel(ObjectPool<Channel> channelPool) {
this.channelPool = checkNotNull(channelPool, "channelPool");
}
/**
* If channel is null, gets a channel from the channel pool, otherwise, returns the cached
* channel.
*/
public synchronized Channel get() {
synchronized Channel get() {
if (channel == null) {
channel = channelPool.getObject();
}
@ -130,87 +296,37 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
}
/** Returns the cached channel to the channel pool. */
public synchronized void close() {
synchronized void close() {
if (channel != null) {
channelPool.returnObject(channel);
}
}
}
/** Buffers all writes until the ALTS handshake is complete. */
@VisibleForTesting
static final class BufferUntilAltsNegotiatedHandler extends AbstractBufferingHandler {
private final GrpcHttp2ConnectionHandler grpcHandler;
BufferUntilAltsNegotiatedHandler(
GrpcHttp2ConnectionHandler grpcHandler, ChannelHandler... negotiationhandlers) {
super(negotiationhandlers);
// Save the gRPC handler. The ALTS handler doesn't support buffering before the handshake
// completes, so we wait until the handshake was successful before adding the grpc handler.
this.grpcHandler = grpcHandler;
}
// TODO: Remove this once https://github.com/grpc/grpc-java/pull/3715 is in.
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
logger.log(Level.FINEST, "Exception while buffering for ALTS Negotiation", cause);
fail(ctx, cause);
ctx.fireExceptionCaught(cause);
}
private static final class AltsHandshakeValidator extends TsiHandshakeHandler.HandshakeValidator {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (logger.isLoggable(Level.FINEST)) {
logger.log(Level.FINEST, "User Event triggered while negotiating ALTS", new Object[] {evt});
public SecurityDetails validatePeerObject(Object peerObject) throws GeneralSecurityException {
AltsAuthContext altsAuthContext = (AltsAuthContext) peerObject;
// Checks peer Rpc Protocol Versions in the ALTS auth context. Fails the connection if
// Rpc Protocol Versions mismatch.
RpcVersionsCheckResult checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersionsUtil.getRpcProtocolVersions(),
altsAuthContext.getPeerRpcVersions());
if (!checkResult.getResult()) {
String errorMessage =
"Local Rpc Protocol Versions "
+ RpcProtocolVersionsUtil.getRpcProtocolVersions()
+ " are not compatible with peer Rpc Protocol Versions "
+ altsAuthContext.getPeerRpcVersions();
throw Status.UNAVAILABLE.withDescription(errorMessage).asRuntimeException();
}
if (evt instanceof TsiHandshakeCompletionEvent) {
TsiHandshakeCompletionEvent altsEvt = (TsiHandshakeCompletionEvent) evt;
if (altsEvt.isSuccess()) {
// Add the gRPC handler just before this handler. We only allow the grpcHandler to be
// null to support testing. In production, a grpc handler will always be provided.
if (grpcHandler != null) {
ctx.pipeline().addBefore(ctx.name(), null, grpcHandler);
AltsAuthContext altsContext = (AltsAuthContext) altsEvt.context();
Preconditions.checkNotNull(altsContext);
// Checks peer Rpc Protocol Versions in the ALTS auth context. Fails the connection if
// Rpc Protocol Versions mismatch.
RpcVersionsCheckResult checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersionsUtil.getRpcProtocolVersions(),
altsContext.getPeerRpcVersions());
if (!checkResult.getResult()) {
String errorMessage =
"Local Rpc Protocol Versions "
+ RpcProtocolVersionsUtil.getRpcProtocolVersions().toString()
+ "are not compatible with peer Rpc Protocol Versions "
+ altsContext.getPeerRpcVersions().toString();
logger.finest(errorMessage);
fail(ctx, Status.UNAVAILABLE.withDescription(errorMessage).asRuntimeException());
}
grpcHandler.handleProtocolNegotiationCompleted(
Attributes.newBuilder()
.set(TSI_PEER_KEY, altsEvt.peer())
.set(ALTS_CONTEXT_KEY, altsContext)
.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress())
.set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress())
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY)
.build(),
new Security(new OtherSecurity("alts", Any.pack(altsContext.context))));
}
logger.finest("Flushing ALTS buffered data");
// Now write any buffered data and remove this handler.
writeBufferedAndRemove(ctx);
} else {
logger.log(Level.FINEST, "ALTS handshake failed", altsEvt.cause());
fail(ctx, unavailableException("ALTS handshake failed", altsEvt.cause()));
}
}
super.userEventTriggered(ctx, evt);
}
private static RuntimeException unavailableException(String msg, Throwable cause) {
return Status.UNAVAILABLE.withCause(cause).withDescription(msg).asRuntimeException();
return new SecurityDetails(
SecurityLevel.PRIVACY_AND_INTEGRITY,
new Security(new OtherSecurity("alts", Any.pack(altsAuthContext.context))));
}
}
private AltsProtocolNegotiator() {}
}

View File

@ -1,71 +0,0 @@
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.internal;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.internal.GrpcAttributes;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiators;
import io.netty.channel.ChannelHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString;
/** A client-side GPRC {@link ProtocolNegotiator} for Google Default Channel. */
public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator {
private final ProtocolNegotiator altsProtocolNegotiator;
private final ProtocolNegotiator tlsProtocolNegotiator;
/** Constructor for protocol negotiator of Google Default Channel. */
public GoogleDefaultProtocolNegotiator(
TsiHandshakerFactory altsFactory, LazyChannel lazyHandshakerChannel, SslContext sslContext) {
altsProtocolNegotiator =
AltsProtocolNegotiator.createClientNegotiator(altsFactory, lazyHandshakerChannel);
tlsProtocolNegotiator = InternalProtocolNegotiators.tls(sslContext);
}
@Override
public AsciiString scheme() {
assert tlsProtocolNegotiator.scheme().equals(altsProtocolNegotiator.scheme());
return tlsProtocolNegotiator.scheme();
}
@VisibleForTesting
GoogleDefaultProtocolNegotiator(
ProtocolNegotiator altsProtocolNegotiator, ProtocolNegotiator tlsProtocolNegotiator) {
this.altsProtocolNegotiator = altsProtocolNegotiator;
this.tlsProtocolNegotiator = tlsProtocolNegotiator;
}
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
if (grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_ADDR_AUTHORITY) != null
|| grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_PROVIDED_BACKEND) != null) {
return altsProtocolNegotiator.newHandler(grpcHandler);
} else {
return tlsProtocolNegotiator.newHandler(grpcHandler);
}
}
@Override
public void close() {
altsProtocolNegotiator.close();
tlsProtocolNegotiator.close();
}
}

View File

@ -19,9 +19,7 @@ package io.grpc.alts.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.alts.internal.TsiFrameProtector.Consumer;
import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelHandlerContext;
@ -33,7 +31,6 @@ import java.net.SocketAddress;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
@ -47,72 +44,33 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
private TsiFrameProtector protector;
private PendingWriteQueue pendingUnprotectedWrites;
private State state = State.HANDSHAKE_NOT_FINISHED;
private boolean closeInitiated = false;
private boolean closeInitiated;
@VisibleForTesting
enum State {
HANDSHAKE_NOT_FINISHED,
PROTECTED,
CLOSED,
HANDSHAKE_FAILED
public TsiFrameHandler(TsiFrameProtector protector) {
this.protector = checkNotNull(protector, "protector");
}
public TsiFrameHandler() {}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
logger.finest("TsiFrameHandler added");
super.handlerAdded(ctx);
assert pendingUnprotectedWrites == null;
pendingUnprotectedWrites = new PendingWriteQueue(checkNotNull(ctx));
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object event) throws Exception {
if (logger.isLoggable(Level.FINEST)) {
logger.log(Level.FINEST, "TsiFrameHandler user event triggered", new Object[]{event});
}
if (event instanceof TsiHandshakeCompletionEvent) {
TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event;
if (tsiEvent.isSuccess()) {
setProtector(tsiEvent.protector());
} else {
state = State.HANDSHAKE_FAILED;
}
// Ignore errors. Another handler in the pipeline must handle TSI Errors.
}
// Keep propagating the message, as others may want to read it.
super.userEventTriggered(ctx, event);
}
@VisibleForTesting
void setProtector(TsiFrameProtector protector) {
logger.finest("TsiFrameHandler protector set");
checkState(this.protector == null);
this.protector = checkNotNull(protector);
this.state = State.PROTECTED;
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
checkState(
state == State.PROTECTED,
"Cannot read frames while the TSI handshake is %s", state);
checkState(protector != null, "decode() called after close()");
protector.unprotect(in, out, ctx.alloc());
}
@Override
public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise)
throws Exception {
checkState(
state == State.PROTECTED,
"Cannot write frames while the TSI handshake state is %s", state);
@SuppressWarnings("FutureReturnValueIgnored") // for setSuccess
public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) {
checkState(protector != null, "write() called after close()");
ByteBuf msg = (ByteBuf) message;
if (!msg.isReadable()) {
// Nothing to encode.
@SuppressWarnings("unused") // go/futurereturn-lsc
Future<?> possiblyIgnoredError = promise.setSuccess();
promise.setSuccess();
return;
}
@ -122,30 +80,11 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
@Override
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
if (!pendingUnprotectedWrites.isEmpty()) {
if (pendingUnprotectedWrites != null && !pendingUnprotectedWrites.isEmpty()) {
pendingUnprotectedWrites.removeAndFailAll(
new ChannelException("Pending write on removal of TSI handler"));
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
pendingUnprotectedWrites.removeAndFailAll(cause);
super.exceptionCaught(ctx, cause);
}
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
ctx.bind(localAddress, promise);
}
@Override
public void connect(
ChannelHandlerContext ctx,
SocketAddress remoteAddress,
SocketAddress localAddress,
ChannelPromise promise) {
ctx.connect(remoteAddress, localAddress, promise);
destroyProtector();
}
@Override
@ -154,6 +93,12 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
ctx.disconnect(promise);
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
doClose(ctx);
ctx.close(promise);
}
private void doClose(ChannelHandlerContext ctx) {
if (closeInitiated) {
return;
@ -165,51 +110,34 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
flush(ctx);
}
} catch (GeneralSecurityException e) {
logger.log(Level.FINE, "Ignoring error on flush before close", e);
logger.log(Level.FINE, "Ignored error on flush before close", e);
} finally {
state = State.CLOSED;
release();
pendingUnprotectedWrites = null;
destroyProtector();
}
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
doClose(ctx);
ctx.close(promise);
}
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
doClose(ctx);
ctx.deregister(promise);
}
@Override
public void read(ChannelHandlerContext ctx) {
ctx.read();
}
@Override
@SuppressWarnings("FutureReturnValueIgnored") // for aggregatePromise.doneAllocatingPromises
public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException {
if (state == State.CLOSED || state == State.HANDSHAKE_FAILED) {
logger.fine(
String.format("FrameHandler is inactive(%s), channel id: %s",
state, ctx.channel().id().asShortText()));
if (protector == null) {
// TODO(carl-mastrangelo): this should be a checkState. AbstractNettyHandler.exceptionCaught
// transitively calls flush even after closed, for some reason.
pendingUnprotectedWrites.removeAndFailAll(
new ChannelException("Pending write on removal of TSI handler"));
logger.fine("flush() called after close()");
return;
}
checkState(
state == State.PROTECTED, "Cannot write frames while the TSI handshake state is %s", state);
final ProtectedPromise aggregatePromise =
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
if (pendingUnprotectedWrites.isEmpty()) {
// Return early if there's nothing to write. Otherwise protector.protectFlush() below may
// not check for "no-data" and go on writing the 0-byte "data" to the socket with the
// protection framing.
return;
}
final ProtectedPromise aggregatePromise =
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
// Drain the unprotected writes.
while (!pendingUnprotectedWrites.isEmpty()) {
ByteBuf in = (ByteBuf) pendingUnprotectedWrites.current();
@ -218,25 +146,54 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
aggregatePromise.addUnprotectedPromise(pendingUnprotectedWrites.remove());
}
protector.protectFlush(
bufs,
new Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf b) {
ctx.writeAndFlush(b, aggregatePromise.newPromise());
}
},
ctx.alloc());
final class ProtectedFrameWriteFlusher implements Consumer<ByteBuf> {
@Override
public void accept(ByteBuf byteBuf) {
ctx.writeAndFlush(byteBuf, aggregatePromise.newPromise());
}
}
protector.protectFlush(bufs, new ProtectedFrameWriteFlusher(), ctx.alloc());
// We're done writing, start the flow of promise events.
@SuppressWarnings("unused") // go/futurereturn-lsc
Future<?> possiblyIgnoredError = aggregatePromise.doneAllocatingPromises();
aggregatePromise.doneAllocatingPromises();
}
private void release() {
// Only here to fulfill ChannelOutboundHandler
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
ctx.bind(localAddress, promise);
}
// Only here to fulfill ChannelOutboundHandler
@Override
public void connect(
ChannelHandlerContext ctx,
SocketAddress remoteAddress,
SocketAddress localAddress,
ChannelPromise promise) {
ctx.connect(remoteAddress, localAddress, promise);
}
// Only here to fulfill ChannelOutboundHandler
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
ctx.deregister(promise);
}
// Only here to fulfill ChannelOutboundHandler
@Override
public void read(ChannelHandlerContext ctx) {
ctx.read();
}
private void destroyProtector() {
if (protector != null) {
protector.destroy();
protector = null;
try {
protector.destroy();
} finally {
protector = null;
}
}
}
}

View File

@ -17,19 +17,25 @@
package io.grpc.alts.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.alts.internal.AltsProtocolNegotiator.AUTH_CONTEXT_KEY;
import static io.grpc.alts.internal.AltsProtocolNegotiator.TSI_PEER_KEY;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import io.grpc.Attributes;
import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.InternalChannelz.Security;
import io.grpc.SecurityLevel;
import io.grpc.alts.internal.TsiHandshakeHandler.HandshakeValidator.SecurityDetails;
import io.grpc.internal.GrpcAttributes;
import io.grpc.netty.InternalProtocolNegotiationEvent;
import io.grpc.netty.InternalProtocolNegotiators;
import io.grpc.netty.ProtocolNegotiationEvent;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.ReferenceCountUtil;
import java.security.GeneralSecurityException;
import java.util.List;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
/**
@ -38,118 +44,63 @@ import javax.annotation.Nullable;
*/
public final class TsiHandshakeHandler extends ByteToMessageDecoder {
private static final Logger logger = Logger.getLogger(TsiHandshakeHandler.class.getName());
/**
* Validates a Tsi Peer object.
*/
public abstract static class HandshakeValidator {
public static final class SecurityDetails {
private final SecurityLevel securityLevel;
private final Security security;
/**
* Constructs SecurityDetails.
*/
public SecurityDetails(io.grpc.SecurityLevel securityLevel, @Nullable Security security) {
this.securityLevel = checkNotNull(securityLevel, "securityLevel");
this.security = security;
}
public Security getSecurity() {
return security;
}
public SecurityLevel getSecurityLevel() {
return securityLevel;
}
}
/**
* Validates a Tsi Peer object.
*/
public abstract SecurityDetails validatePeerObject(Object peerObject)
throws GeneralSecurityException;
}
private static final int HANDSHAKE_FRAME_SIZE = 1024;
private final NettyTsiHandshaker handshaker;
private boolean started;
private final HandshakeValidator handshakeValidator;
private final ChannelHandler next;
private ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault();
/**
* This buffer doesn't store any state. We just hold onto it in case we end up allocating a buffer
* that ends up being unused.
* Constructs a TsiHandshakeHandler.
*/
private ByteBuf buffer;
public TsiHandshakeHandler(NettyTsiHandshaker handshaker) {
this.handshaker = checkNotNull(handshaker);
}
/**
* Event that is fired once the TSI handshake is complete, which may be because it was successful
* or there was an error.
*/
public static final class TsiHandshakeCompletionEvent {
private final Throwable cause;
private final TsiPeer peer;
private final Object context;
private final TsiFrameProtector protector;
/** Creates a new event that indicates a successful handshake. */
@VisibleForTesting
TsiHandshakeCompletionEvent(
TsiFrameProtector protector, TsiPeer peer, @Nullable Object peerObject) {
this.cause = null;
this.peer = checkNotNull(peer);
this.protector = checkNotNull(protector);
this.context = peerObject;
}
/** Creates a new event that indicates an unsuccessful handshake/. */
TsiHandshakeCompletionEvent(Throwable cause) {
this.cause = checkNotNull(cause);
this.peer = null;
this.protector = null;
this.context = null;
}
/** Return {@code true} if the handshake was successful. */
public boolean isSuccess() {
return cause == null;
}
/**
* Return the {@link Throwable} if {@link #isSuccess()} returns {@code false} and so the
* handshake failed.
*/
@Nullable
public Throwable cause() {
return cause;
}
@Nullable
public TsiPeer peer() {
return peer;
}
@Nullable
public Object context() {
return context;
}
@Nullable
TsiFrameProtector protector() {
return protector;
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("peer", peer)
.add("protector", protector)
.add("context", context)
.add("cause", cause)
.toString();
}
public TsiHandshakeHandler(
ChannelHandler next, NettyTsiHandshaker handshaker, HandshakeValidator handshakeValidator) {
this.handshaker = checkNotNull(handshaker, "handshaker");
this.handshakeValidator = checkNotNull(handshakeValidator, "handshakeValidator");
this.next = checkNotNull(next, "next");
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
logger.finest("TsiHandshakeHandler added");
maybeStart(ctx);
super.handlerAdded(ctx);
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
logger.finest("TsiHandshakeHandler channel active");
maybeStart(ctx);
super.channelActive(ctx);
}
@Override
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
logger.finest("TsiHandshakeHandler handler removed");
close();
super.handlerRemoved0(ctx);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
logger.log(Level.FINEST, "Exception in TsiHandshakeHandler", cause);
ctx.fireUserEventTriggered(new TsiHandshakeCompletionEvent(cause));
super.exceptionCaught(ctx, cause);
InternalProtocolNegotiators.negotiationLogger(ctx)
.log(ChannelLogLevel.INFO, "TsiHandshake started");
sendHandshake(ctx);
}
@Override
@ -168,71 +119,72 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder {
// If the handshake is complete, transition to the framing state.
if (!handshaker.isInProgress()) {
TsiFrameProtector protector = null;
TsiPeer peer = handshaker.extractPeer();
Object authContext = handshaker.extractPeerObject();
SecurityDetails details = handshakeValidator.validatePeerObject(authContext);
// createFrameProtector must be called last.
TsiFrameProtector protector = handshaker.createFrameProtector(ctx.alloc());
TsiFrameHandler framer;
boolean success = false;
try {
ctx.pipeline().remove(this);
protector = handshaker.createFrameProtector(ctx.alloc());
TsiHandshakeCompletionEvent evt = new TsiHandshakeCompletionEvent(
protector,
handshaker.extractPeer(),
handshaker.extractPeerObject());
protector = null;
ctx.fireUserEventTriggered(evt);
// No need to do anything with the in buffer, it will be re added to the pipeline when this
// handler is removed.
framer = new TsiFrameHandler(protector);
// replace the current handler with the framer (instead of adding before) since there may
// be pending data after the handshake frame. The data will need to be decoded before
// being passed to the `next` handler.
ctx.pipeline().replace(ctx.name(), null, framer);
// Once the framer is in the pipeline, it will be cleaned up when the handler is removed.
success = true;
} finally {
if (protector != null) {
if (!success && protector != null) {
protector.destroy();
}
close();
}
// Add the `next` handler as late as possible, as it will issue writes on being added.
ctx.pipeline().addAfter(ctx.pipeline().context(framer).name(), null, next);
fireProtocolNegotiationEvent(ctx, peer, authContext, details);
}
}
private void maybeStart(ChannelHandlerContext ctx) {
if (!started && ctx.channel().isActive()) {
started = true;
sendHandshake(ctx);
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProtocolNegotiationEvent) {
pne = (ProtocolNegotiationEvent) evt;
} else {
super.userEventTriggered(ctx, evt);
}
}
private void fireProtocolNegotiationEvent(
ChannelHandlerContext ctx, TsiPeer peer, Object authContext, SecurityDetails details) {
InternalProtocolNegotiators.negotiationLogger(ctx)
.log(ChannelLogLevel.INFO, "TsiHandshake finished");
ProtocolNegotiationEvent localPne = pne;
Attributes.Builder attrs = InternalProtocolNegotiationEvent.getAttributes(localPne).toBuilder()
.set(TSI_PEER_KEY, peer)
.set(AUTH_CONTEXT_KEY, authContext)
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, details.getSecurityLevel());
localPne = InternalProtocolNegotiationEvent.withAttributes(localPne, attrs.build());
localPne = InternalProtocolNegotiationEvent.withSecurity(localPne, details.getSecurity());
ctx.fireUserEventTriggered(localPne);
}
/** Sends as many bytes as are available from the handshaker to the remote peer. */
private void sendHandshake(ChannelHandlerContext ctx) {
boolean needToFlush = false;
// Iterate until there is nothing left to write.
@SuppressWarnings("FutureReturnValueIgnored") // for addListener
private void sendHandshake(ChannelHandlerContext ctx) throws GeneralSecurityException {
while (true) {
buffer = getOrCreateBuffer(ctx.alloc());
boolean written = false;
ByteBuf buf = ctx.alloc().buffer(HANDSHAKE_FRAME_SIZE).retain(); // refcnt = 2
try {
handshaker.getBytesToSendToPeer(buffer);
} catch (GeneralSecurityException e) {
throw new RuntimeException(e);
handshaker.getBytesToSendToPeer(buf);
if (buf.isReadable()) {
ctx.writeAndFlush(buf).addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
written = true;
} else {
break;
}
} finally {
buf.release(written ? 1 : 2);
}
if (!buffer.isReadable()) {
break;
}
needToFlush = true;
@SuppressWarnings("unused") // go/futurereturn-lsc
Future<?> possiblyIgnoredError = ctx.write(buffer);
buffer = null;
}
// If something was written, flush.
if (needToFlush) {
ctx.flush();
}
}
private ByteBuf getOrCreateBuffer(ByteBufAllocator alloc) {
if (buffer == null) {
buffer = alloc.buffer(HANDSHAKE_FRAME_SIZE);
}
return buffer;
}
private void close() {
ReferenceCountUtil.safeRelease(buffer);
buffer = null;
}
}

View File

@ -18,7 +18,6 @@ package io.grpc.alts;
import static com.google.common.truth.Truth.assertThat;
import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -28,17 +27,14 @@ import org.junit.runners.JUnit4;
public final class AltsChannelBuilderTest {
@Test
public void buildsNettyChannel() throws Exception {
public void buildsNettyChannel() {
AltsChannelBuilder builder =
AltsChannelBuilder.forTarget("localhost:8080").enableUntrustedAltsForTesting();
ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest();
assertThat(protocolNegotiator).isNull();
builder.build();
protocolNegotiator = builder.getProtocolNegotiatorForTest();
assertThat(protocolNegotiator).isNotNull();
assertThat(protocolNegotiator).isInstanceOf(AltsProtocolNegotiator.class);
// Avoids exposing this class
assertThat(protocolNegotiator.getClass().getSimpleName())
.isEqualTo("ClientAltsProtocolNegotiator");
}
}

View File

@ -18,7 +18,6 @@ package io.grpc.alts;
import static com.google.common.truth.Truth.assertThat;
import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -33,6 +32,7 @@ public final class ComputeEngineChannelBuilderTest {
builder.build();
ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest();
assertThat(protocolNegotiator).isInstanceOf(GoogleDefaultProtocolNegotiator.class);
assertThat(protocolNegotiator.getClass().getSimpleName())
.isEqualTo("GoogleDefaultProtocolNegotiator");
}
}

View File

@ -18,7 +18,6 @@ package io.grpc.alts;
import static com.google.common.truth.Truth.assertThat;
import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -33,6 +32,7 @@ public final class GoogleDefaultChannelBuilderTest {
builder.build();
ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest();
assertThat(protocolNegotiator).isInstanceOf(GoogleDefaultProtocolNegotiator.class);
assertThat(protocolNegotiator.getClass().getSimpleName())
.isEqualTo("GoogleDefaultProtocolNegotiator");
}
}

View File

@ -27,6 +27,7 @@ import io.grpc.Attributes;
import io.grpc.Channel;
import io.grpc.Grpc;
import io.grpc.InternalChannelz;
import io.grpc.InternalChannelz.Security;
import io.grpc.ManagedChannel;
import io.grpc.SecurityLevel;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
@ -80,6 +81,7 @@ import org.junit.runners.JUnit4;
/** Tests for {@link AltsProtocolNegotiator}. */
@RunWith(JUnit4.class)
@SuppressWarnings("FutureReturnValueIgnored")
public class AltsProtocolNegotiatorTest {
private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler();
@ -90,7 +92,6 @@ public class AltsProtocolNegotiatorTest {
private EmbeddedChannel channel;
private Throwable caughtException;
private volatile TsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent;
private ChannelHandler handler;
private TsiPeer mockedTsiPeer = new TsiPeer(Collections.<Property<?>>emptyList());
@ -102,12 +103,12 @@ public class AltsProtocolNegotiatorTest {
private final TsiHandshaker mockHandshaker =
new DelegatingTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer()) {
@Override
public TsiPeer extractPeer() throws GeneralSecurityException {
public TsiPeer extractPeer() {
return mockedTsiPeer;
}
@Override
public Object extractPeerObject() throws GeneralSecurityException {
public Object extractPeerObject() {
return mockedAltsContext;
}
};
@ -115,24 +116,13 @@ public class AltsProtocolNegotiatorTest {
@Before
public void setup() throws Exception {
ChannelHandler userEventHandler =
new ChannelDuplexHandler() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof TsiHandshakeHandler.TsiHandshakeCompletionEvent) {
tsiEvent = (TsiHandshakeHandler.TsiHandshakeCompletionEvent) evt;
} else {
super.userEventTriggered(ctx, evt);
}
}
};
ChannelHandler uncaughtExceptionHandler =
new ChannelDuplexHandler() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
caughtException = cause;
super.exceptionCaught(ctx, cause);
ctx.close();
}
};
@ -157,9 +147,9 @@ public class AltsProtocolNegotiatorTest {
ObjectPool<Channel> fakeChannelPool = new FixedObjectPool<Channel>(fakeChannel);
LazyChannel lazyFakeChannel = new LazyChannel(fakeChannelPool);
handler =
AltsProtocolNegotiator.createServerNegotiator(handshakerFactory, lazyFakeChannel)
new AltsProtocolNegotiator.ServerAltsProtocolNegotiator(handshakerFactory, lazyFakeChannel)
.newHandler(grpcHandler);
channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler);
channel = new EmbeddedChannel(uncaughtExceptionHandler, handler);
}
@After
@ -182,6 +172,8 @@ public class AltsProtocolNegotiatorTest {
@Test
@SuppressWarnings("unchecked") // List cast
public void protectShouldRoundtrip() throws Exception {
doHandshake();
// Write the message 1 character at a time. The message should be buffered
// and not interfere with the handshake.
final AtomicInteger writeCount = new AtomicInteger();
@ -204,10 +196,6 @@ public class AltsProtocolNegotiatorTest {
}
channel.flush();
// Now do the handshake. The buffered message will automatically be protected
// and sent.
doHandshake();
// Capture the protected data written to the wire.
assertEquals(1, channel.outboundMessages().size());
ByteBuf protectedData = channel.readOutbound();
@ -351,7 +339,7 @@ public class AltsProtocolNegotiatorTest {
doHandshake();
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.TSI_PEER_KEY)).isEqualTo(mockedTsiPeer);
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.ALTS_CONTEXT_KEY))
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY))
.isEqualTo(mockedAltsContext);
assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString())
.isEqualTo("embedded");
@ -388,7 +376,7 @@ public class AltsProtocolNegotiatorTest {
if (caughtException != null) {
throw new RuntimeException(caughtException);
}
assertNotNull(tsiEvent);
assertNotNull(grpcHandler.attrs);
}
private CapturingGrpcHttp2ConnectionHandler capturingGrpcHandler() {
@ -408,6 +396,7 @@ public class AltsProtocolNegotiatorTest {
private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
private Attributes attrs;
private Security securityInfo;
private CapturingGrpcHttp2ConnectionHandler(
Http2ConnectionDecoder decoder,
@ -422,6 +411,7 @@ public class AltsProtocolNegotiatorTest {
// If we are added to the pipeline, we need to remove ourselves. The HTTP2 handler
channel.pipeline().remove(this);
this.attrs = attrs;
this.securityInfo = securityInfo;
}
}

View File

@ -16,16 +16,27 @@
package io.grpc.alts.internal;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import io.grpc.Attributes;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.ssl.SslContext;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -33,16 +44,36 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public final class GoogleDefaultProtocolNegotiatorTest {
private ProtocolNegotiator altsProtocolNegotiator;
private ProtocolNegotiator tlsProtocolNegotiator;
private GoogleDefaultProtocolNegotiator googleProtocolNegotiator;
private ProtocolNegotiator googleProtocolNegotiator;
private final ObjectPool<Channel> handshakerChannelPool = new ObjectPool<Channel>() {
@Override
public Channel getObject() {
return InProcessChannelBuilder.forName("test").build();
}
@Override
public Channel returnObject(Object object) {
((ManagedChannel) object).shutdownNow();
return null;
}
};
@Before
public void setUp() {
altsProtocolNegotiator = mock(ProtocolNegotiator.class);
tlsProtocolNegotiator = mock(ProtocolNegotiator.class);
googleProtocolNegotiator =
new GoogleDefaultProtocolNegotiator(altsProtocolNegotiator, tlsProtocolNegotiator);
public void setUp() throws Exception {
SslContext sslContext = GrpcSslContexts.forClient().build();
googleProtocolNegotiator = new AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory(
ImmutableList.<String>of(),
handshakerChannelPool,
sslContext)
.buildProtocolNegotiator();
}
@After
public void tearDown() {
googleProtocolNegotiator.close();
}
@Test
@ -51,9 +82,24 @@ public final class GoogleDefaultProtocolNegotiatorTest {
Attributes.newBuilder().set(GrpcAttributes.ATTR_LB_PROVIDED_BACKEND, true).build();
GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class);
when(mockHandler.getEagAttributes()).thenReturn(eagAttributes);
googleProtocolNegotiator.newHandler(mockHandler);
verify(altsProtocolNegotiator, times(1)).newHandler(mockHandler);
verify(tlsProtocolNegotiator, never()).newHandler(mockHandler);
final AtomicReference<Throwable> failure = new AtomicReference<>();
ChannelHandler exceptionCaught = new ChannelInboundHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
failure.set(cause);
super.exceptionCaught(ctx, cause);
}
};
ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler);
EmbeddedChannel chan = new EmbeddedChannel(exceptionCaught);
// Add the negotiator handler last, but to the front. Putting this in ctor above would make it
// throw early.
chan.pipeline().addFirst(h);
// Check that the message complained about the ALTS code, rather than SSL. ALTS throws on
// being added, so it's hard to catch it at the right time to make this assertion.
assertThat(failure.get()).hasMessageThat().contains("TsiHandshakeHandler");
}
@Test
@ -61,8 +107,11 @@ public final class GoogleDefaultProtocolNegotiatorTest {
Attributes eagAttributes = Attributes.EMPTY;
GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class);
when(mockHandler.getEagAttributes()).thenReturn(eagAttributes);
googleProtocolNegotiator.newHandler(mockHandler);
verify(altsProtocolNegotiator, never()).newHandler(mockHandler);
verify(tlsProtocolNegotiator, times(1)).newHandler(mockHandler);
when(mockHandler.getAuthority()).thenReturn("authority");
ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler);
EmbeddedChannel chan = new EmbeddedChannel(h);
assertThat(chan.pipeline().first().getClass().getSimpleName()).isEqualTo("SslHandler");
}
}

View File

@ -18,18 +18,13 @@ package io.grpc.alts.internal;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static org.junit.Assert.fail;
import io.grpc.alts.internal.TsiFrameHandler.State;
import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
import io.grpc.alts.internal.TsiPeer.Property;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import org.junit.Rule;
import org.junit.Test;
@ -46,32 +41,17 @@ public class TsiFrameHandlerTest {
@Rule
public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(5));
private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler();
private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler(new IdentityFrameProtector());
private final EmbeddedChannel channel = new EmbeddedChannel(tsiFrameHandler);
@Test
public void writeAndFlush_beforeHandshakeEventShouldBeIgnored() {
ByteBuf msg = Unpooled.copiedBuffer("message before handshake finished", CharsetUtil.UTF_8);
channel.writeAndFlush(msg);
assertThat(channel.outboundMessages()).isEmpty();
try {
channel.checkException();
fail();
} catch (IllegalStateException e) {
assertThat(e).hasMessageThat().contains(State.HANDSHAKE_NOT_FINISHED.name());
}
}
@Test
public void writeAndFlush_handshakeSucceed() throws InterruptedException {
channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent());
ByteBuf msg = Unpooled.copiedBuffer("message after handshake finished", CharsetUtil.UTF_8);
channel.writeAndFlush(msg);
Object actual = channel.readOutbound();
assertThat((Object) channel.readOutbound()).isEqualTo(msg);
assertThat(actual).isEqualTo(msg);
channel.close().sync();
channel.checkException();
}
@ -92,40 +72,20 @@ public class TsiFrameHandlerTest {
}
}
@Test
public void writeAndFlush_handshakeFailed() throws InterruptedException {
channel.pipeline().fireUserEventTriggered(new TsiHandshakeCompletionEvent(new Exception()));
ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8);
channel.writeAndFlush(msg);
assertThat(channel.outboundMessages()).isEmpty();
channel.close().sync();
channel.checkException();
}
@Test
public void close_shouldFlushRemainingMessage() throws InterruptedException {
channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent());
ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8);
channel.write(msg);
assertThat(channel.outboundMessages()).isEmpty();
channel.close().sync();
Object actual = channel.readOutbound();
assertWithMessage("pending write should be flushed on close")
.that((Object) channel.readOutbound()).isEqualTo(msg);
assertWithMessage("pending write should be flushed on close").that(actual).isEqualTo(msg);
channel.checkException();
}
private TsiHandshakeCompletionEvent getHandshakeSuccessEvent() {
TsiFrameProtector protector = new IdentityFrameProtector();
TsiPeer peer = new TsiPeer(new ArrayList<Property<?>>());
return new TsiHandshakeCompletionEvent(protector, peer, new Object());
}
private static final class IdentityFrameProtector implements TsiFrameProtector {
@Override

View File

@ -17,6 +17,7 @@
package io.grpc.netty;
import io.grpc.ChannelLogger;
import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler;
import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler;
import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
import io.netty.channel.ChannelHandler;
@ -93,4 +94,9 @@ public final class InternalProtocolNegotiators {
public static ChannelHandler grpcNegotiationHandler(GrpcHttp2ConnectionHandler next) {
return new GrpcNegotiationHandler(next);
}
public static ChannelHandler clientTlsHandler(
ChannelHandler next, SslContext sslContext, String authority) {
return new ClientTlsHandler(next, sslContext, authority);
}
}