mirror of https://github.com/grpc/grpc-java.git
alts: use new ProtocolNegotiator style for ALTS
This change does a few core things, which result in a lot of churn in other parts. It's not as bad as it seems. Core things: 1. AltsProtocolNegotiator is now a shim class, same as ProtocolNegotiators 2. The protocol negotiators are now in the new style, where there is at most 1 negotiation handler in the pipe at a time. 3. TsiHandshakeHandler is rewritten with respect to the above. All errors and buffering are handled by the WBAEH. 4. TsiFrameHandler is only installed once the negotiation is successful, eliminating the state handling. The churn in mainly in GoogleDefaultChannel and the GCE channel, which now reuse the *handlers* rather than the negotiators. This makes it significantly easier to reason about the pipeline state. The tests are also a source of churn, which no long need to check for most buffering and error conditions.
This commit is contained in:
parent
f8fffeff12
commit
7834a50525
|
|
@ -28,18 +28,12 @@ import io.grpc.ManagedChannel;
|
|||
import io.grpc.ManagedChannelBuilder;
|
||||
import io.grpc.MethodDescriptor;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.alts.internal.AltsClientOptions;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
|
||||
import io.grpc.alts.internal.AltsTsiHandshaker;
|
||||
import io.grpc.alts.internal.HandshakerServiceGrpc;
|
||||
import io.grpc.alts.internal.RpcProtocolVersionsUtil;
|
||||
import io.grpc.alts.internal.TsiHandshaker;
|
||||
import io.grpc.alts.internal.TsiHandshakerFactory;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator.ClientAltsProtocolNegotiatorFactory;
|
||||
import io.grpc.internal.GrpcUtil;
|
||||
import io.grpc.internal.ObjectPool;
|
||||
import io.grpc.internal.SharedResourcePool;
|
||||
import io.grpc.netty.InternalNettyChannelBuilder;
|
||||
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
|
||||
import io.grpc.netty.NettyChannelBuilder;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
|
@ -60,8 +54,6 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
|
|||
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL);
|
||||
private boolean enableUntrustedAlts;
|
||||
|
||||
private AltsProtocolNegotiator negotiatorForTest;
|
||||
|
||||
/** "Overrides" the static method in {@link ManagedChannelBuilder}. */
|
||||
public static final AltsChannelBuilder forTarget(String target) {
|
||||
return new AltsChannelBuilder(target);
|
||||
|
|
@ -74,8 +66,6 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
|
|||
|
||||
private AltsChannelBuilder(String target) {
|
||||
delegate = NettyChannelBuilder.forTarget(target);
|
||||
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
|
||||
delegate(), new ProtocolNegotiatorFactory());
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -125,41 +115,20 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
|
|||
delegate().intercept(new FailingClientInterceptor(status));
|
||||
}
|
||||
}
|
||||
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
|
||||
delegate(),
|
||||
new ClientAltsProtocolNegotiatorFactory(
|
||||
targetServiceAccountsBuilder.build(), handshakerChannelPool));
|
||||
|
||||
return delegate().build();
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
@Nullable
|
||||
AltsProtocolNegotiator getProtocolNegotiatorForTest() {
|
||||
return negotiatorForTest;
|
||||
}
|
||||
|
||||
private final class ProtocolNegotiatorFactory
|
||||
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
|
||||
|
||||
@Override
|
||||
public AltsProtocolNegotiator buildProtocolNegotiator() {
|
||||
final ImmutableList<String> targetServiceAccounts = targetServiceAccountsBuilder.build();
|
||||
final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
|
||||
TsiHandshakerFactory altsHandshakerFactory =
|
||||
new TsiHandshakerFactory() {
|
||||
@Override
|
||||
public TsiHandshaker newHandshaker(String authority) {
|
||||
AltsClientOptions handshakerOptions =
|
||||
new AltsClientOptions.Builder()
|
||||
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
|
||||
.setTargetServiceAccounts(targetServiceAccounts)
|
||||
.setTargetName(authority)
|
||||
.build();
|
||||
return AltsTsiHandshaker.newClient(
|
||||
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
|
||||
}
|
||||
};
|
||||
return negotiatorForTest =
|
||||
AltsProtocolNegotiator.createClientNegotiator(
|
||||
altsHandshakerFactory, lazyHandshakerChannel);
|
||||
}
|
||||
ProtocolNegotiator getProtocolNegotiatorForTest() {
|
||||
return new ClientAltsProtocolNegotiatorFactory(
|
||||
targetServiceAccountsBuilder.build(), handshakerChannelPool)
|
||||
.buildProtocolNegotiator();
|
||||
}
|
||||
|
||||
/** An implementation of {@link ClientInterceptor} that fails each call. */
|
||||
|
|
|
|||
|
|
@ -33,14 +33,7 @@ import io.grpc.ServerServiceDefinition;
|
|||
import io.grpc.ServerStreamTracer.Factory;
|
||||
import io.grpc.ServerTransportFilter;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.alts.internal.AltsHandshakerOptions;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
|
||||
import io.grpc.alts.internal.AltsTsiHandshaker;
|
||||
import io.grpc.alts.internal.HandshakerServiceGrpc;
|
||||
import io.grpc.alts.internal.RpcProtocolVersionsUtil;
|
||||
import io.grpc.alts.internal.TsiHandshaker;
|
||||
import io.grpc.alts.internal.TsiHandshakerFactory;
|
||||
import io.grpc.internal.ObjectPool;
|
||||
import io.grpc.internal.SharedResourcePool;
|
||||
import io.grpc.netty.NettyServerBuilder;
|
||||
|
|
@ -192,18 +185,8 @@ public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
|
|||
}
|
||||
}
|
||||
|
||||
final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
|
||||
delegate.protocolNegotiator(
|
||||
AltsProtocolNegotiator.createServerNegotiator(
|
||||
new TsiHandshakerFactory() {
|
||||
@Override
|
||||
public TsiHandshaker newHandshaker(String authority) {
|
||||
return AltsTsiHandshaker.newServer(
|
||||
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()),
|
||||
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()));
|
||||
}
|
||||
},
|
||||
lazyHandshakerChannel));
|
||||
AltsProtocolNegotiator.serverAltsProtocolNegotiator(handshakerChannelPool));
|
||||
return delegate.build();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,23 +18,18 @@ package io.grpc.alts;
|
|||
|
||||
import com.google.auth.oauth2.ComputeEngineCredentials;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import io.grpc.CallCredentials;
|
||||
import io.grpc.ForwardingChannelBuilder;
|
||||
import io.grpc.ManagedChannelBuilder;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.alts.internal.AltsClientOptions;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
|
||||
import io.grpc.alts.internal.AltsTsiHandshaker;
|
||||
import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator;
|
||||
import io.grpc.alts.internal.HandshakerServiceGrpc;
|
||||
import io.grpc.alts.internal.RpcProtocolVersionsUtil;
|
||||
import io.grpc.alts.internal.TsiHandshaker;
|
||||
import io.grpc.alts.internal.TsiHandshakerFactory;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory;
|
||||
import io.grpc.auth.MoreCallCredentials;
|
||||
import io.grpc.internal.GrpcUtil;
|
||||
import io.grpc.internal.SharedResourcePool;
|
||||
import io.grpc.netty.GrpcSslContexts;
|
||||
import io.grpc.netty.InternalNettyChannelBuilder;
|
||||
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
|
||||
import io.grpc.netty.NettyChannelBuilder;
|
||||
import io.netty.handler.ssl.SslContext;
|
||||
import javax.net.ssl.SSLException;
|
||||
|
|
@ -47,12 +42,21 @@ public final class ComputeEngineChannelBuilder
|
|||
extends ForwardingChannelBuilder<GoogleDefaultChannelBuilder> {
|
||||
|
||||
private final NettyChannelBuilder delegate;
|
||||
private GoogleDefaultProtocolNegotiator negotiatorForTest;
|
||||
|
||||
private ComputeEngineChannelBuilder(String target) {
|
||||
delegate = NettyChannelBuilder.forTarget(target);
|
||||
SslContext sslContext;
|
||||
try {
|
||||
sslContext = GrpcSslContexts.forClient().build();
|
||||
} catch (SSLException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
|
||||
delegate(), new ProtocolNegotiatorFactory());
|
||||
delegate(),
|
||||
new GoogleDefaultProtocolNegotiatorFactory(
|
||||
/* targetServiceAccounts= */ ImmutableList.<String>of(),
|
||||
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL),
|
||||
sslContext));
|
||||
CallCredentials credentials = MoreCallCredentials.from(ComputeEngineCredentials.create());
|
||||
Status status = Status.OK;
|
||||
if (!CheckGcpEnvironment.isOnGcp()) {
|
||||
|
|
@ -79,40 +83,17 @@ public final class ComputeEngineChannelBuilder
|
|||
}
|
||||
|
||||
@VisibleForTesting
|
||||
GoogleDefaultProtocolNegotiator getProtocolNegotiatorForTest() {
|
||||
return negotiatorForTest;
|
||||
}
|
||||
|
||||
private final class ProtocolNegotiatorFactory
|
||||
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
|
||||
|
||||
@Override
|
||||
public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() {
|
||||
final LazyChannel lazyHandshakerChannel =
|
||||
new LazyChannel(
|
||||
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL));
|
||||
TsiHandshakerFactory altsHandshakerFactory =
|
||||
new TsiHandshakerFactory() {
|
||||
@Override
|
||||
public TsiHandshaker newHandshaker(String authority) {
|
||||
AltsClientOptions handshakerOptions =
|
||||
new AltsClientOptions.Builder()
|
||||
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
|
||||
.setTargetName(authority)
|
||||
.build();
|
||||
return AltsTsiHandshaker.newClient(
|
||||
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
|
||||
}
|
||||
};
|
||||
SslContext sslContext;
|
||||
try {
|
||||
sslContext = GrpcSslContexts.forClient().build();
|
||||
} catch (SSLException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
return negotiatorForTest =
|
||||
new GoogleDefaultProtocolNegotiator(
|
||||
altsHandshakerFactory, lazyHandshakerChannel, sslContext);
|
||||
ProtocolNegotiator getProtocolNegotiatorForTest() {
|
||||
SslContext sslContext;
|
||||
try {
|
||||
sslContext = GrpcSslContexts.forClient().build();
|
||||
} catch (SSLException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return new GoogleDefaultProtocolNegotiatorFactory(
|
||||
/* targetServiceAccounts= */ ImmutableList.<String>of(),
|
||||
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL),
|
||||
sslContext)
|
||||
.buildProtocolNegotiator();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,23 +18,18 @@ package io.grpc.alts;
|
|||
|
||||
import com.google.auth.oauth2.GoogleCredentials;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import io.grpc.CallCredentials;
|
||||
import io.grpc.ForwardingChannelBuilder;
|
||||
import io.grpc.ManagedChannelBuilder;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.alts.internal.AltsClientOptions;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
|
||||
import io.grpc.alts.internal.AltsTsiHandshaker;
|
||||
import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator;
|
||||
import io.grpc.alts.internal.HandshakerServiceGrpc;
|
||||
import io.grpc.alts.internal.RpcProtocolVersionsUtil;
|
||||
import io.grpc.alts.internal.TsiHandshaker;
|
||||
import io.grpc.alts.internal.TsiHandshakerFactory;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory;
|
||||
import io.grpc.auth.MoreCallCredentials;
|
||||
import io.grpc.internal.GrpcUtil;
|
||||
import io.grpc.internal.SharedResourcePool;
|
||||
import io.grpc.netty.GrpcSslContexts;
|
||||
import io.grpc.netty.InternalNettyChannelBuilder;
|
||||
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
|
||||
import io.grpc.netty.NettyChannelBuilder;
|
||||
import io.netty.handler.ssl.SslContext;
|
||||
import java.io.IOException;
|
||||
|
|
@ -49,12 +44,21 @@ public final class GoogleDefaultChannelBuilder
|
|||
extends ForwardingChannelBuilder<GoogleDefaultChannelBuilder> {
|
||||
|
||||
private final NettyChannelBuilder delegate;
|
||||
private GoogleDefaultProtocolNegotiator negotiatorForTest;
|
||||
|
||||
private GoogleDefaultChannelBuilder(String target) {
|
||||
delegate = NettyChannelBuilder.forTarget(target);
|
||||
SslContext sslContext;
|
||||
try {
|
||||
sslContext = GrpcSslContexts.forClient().build();
|
||||
} catch (SSLException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
|
||||
delegate(), new ProtocolNegotiatorFactory());
|
||||
delegate(),
|
||||
new GoogleDefaultProtocolNegotiatorFactory(
|
||||
/* targetServiceAccounts= */ ImmutableList.<String>of(),
|
||||
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL),
|
||||
sslContext));
|
||||
@Nullable CallCredentials credentials = null;
|
||||
Status status = Status.OK;
|
||||
try {
|
||||
|
|
@ -84,40 +88,17 @@ public final class GoogleDefaultChannelBuilder
|
|||
}
|
||||
|
||||
@VisibleForTesting
|
||||
GoogleDefaultProtocolNegotiator getProtocolNegotiatorForTest() {
|
||||
return negotiatorForTest;
|
||||
}
|
||||
|
||||
private final class ProtocolNegotiatorFactory
|
||||
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
|
||||
|
||||
@Override
|
||||
public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() {
|
||||
final LazyChannel lazyHandshakerChannel =
|
||||
new LazyChannel(
|
||||
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL));
|
||||
TsiHandshakerFactory altsHandshakerFactory =
|
||||
new TsiHandshakerFactory() {
|
||||
@Override
|
||||
public TsiHandshaker newHandshaker(String authority) {
|
||||
AltsClientOptions handshakerOptions =
|
||||
new AltsClientOptions.Builder()
|
||||
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
|
||||
.setTargetName(authority)
|
||||
.build();
|
||||
return AltsTsiHandshaker.newClient(
|
||||
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
|
||||
}
|
||||
};
|
||||
SslContext sslContext;
|
||||
try {
|
||||
sslContext = GrpcSslContexts.forClient().build();
|
||||
} catch (SSLException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
return negotiatorForTest =
|
||||
new GoogleDefaultProtocolNegotiator(
|
||||
altsHandshakerFactory, lazyHandshakerChannel, sslContext);
|
||||
ProtocolNegotiator getProtocolNegotiatorForTest() {
|
||||
SslContext sslContext;
|
||||
try {
|
||||
sslContext = GrpcSslContexts.forClient().build();
|
||||
} catch (SSLException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return new GoogleDefaultProtocolNegotiatorFactory(
|
||||
/* targetServiceAccounts= */ ImmutableList.<String>of(),
|
||||
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL),
|
||||
sslContext)
|
||||
.buildProtocolNegotiator();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,8 +16,10 @@
|
|||
|
||||
package io.grpc.alts.internal;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkNotNull;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.protobuf.Any;
|
||||
import io.grpc.Attributes;
|
||||
import io.grpc.Channel;
|
||||
|
|
@ -27,102 +29,266 @@ import io.grpc.InternalChannelz.Security;
|
|||
import io.grpc.SecurityLevel;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.alts.internal.RpcProtocolVersionsUtil.RpcVersionsCheckResult;
|
||||
import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
|
||||
import io.grpc.internal.GrpcAttributes;
|
||||
import io.grpc.internal.ObjectPool;
|
||||
import io.grpc.netty.GrpcHttp2ConnectionHandler;
|
||||
import io.grpc.netty.InternalNettyChannelBuilder;
|
||||
import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory;
|
||||
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
|
||||
import io.grpc.netty.InternalProtocolNegotiators.AbstractBufferingHandler;
|
||||
import io.grpc.netty.InternalProtocolNegotiators;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.handler.ssl.SslContext;
|
||||
import io.netty.util.AsciiString;
|
||||
import java.util.logging.Level;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.List;
|
||||
import java.util.logging.Logger;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* A GRPC {@link ProtocolNegotiator} for ALTS. This class creates a Netty handler that provides ALTS
|
||||
* A gRPC {@link ProtocolNegotiator} for ALTS. This class creates a Netty handler that provides ALTS
|
||||
* security on the wire, similar to Netty's {@code SslHandler}.
|
||||
*/
|
||||
public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
|
||||
|
||||
// TODO(carl-mastrangelo): rename this AltsProtocolNegotiators.
|
||||
public final class AltsProtocolNegotiator {
|
||||
private static final Logger logger = Logger.getLogger(AltsProtocolNegotiator.class.getName());
|
||||
|
||||
@Grpc.TransportAttr
|
||||
public static final Attributes.Key<TsiPeer> TSI_PEER_KEY = Attributes.Key.create("TSI_PEER");
|
||||
|
||||
@Grpc.TransportAttr
|
||||
public static final Attributes.Key<AltsAuthContext> ALTS_CONTEXT_KEY =
|
||||
Attributes.Key.create("ALTS_CONTEXT_KEY");
|
||||
public static final Attributes.Key<Object> AUTH_CONTEXT_KEY =
|
||||
Attributes.Key.create("AUTH_CONTEXT_KEY");
|
||||
|
||||
private static final AsciiString scheme = AsciiString.of("https");
|
||||
private static final AsciiString SCHEME = AsciiString.of("https");
|
||||
|
||||
/** Creates a negotiator used for ALTS client. */
|
||||
public static AltsProtocolNegotiator createClientNegotiator(
|
||||
final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) {
|
||||
final class ClientAltsProtocolNegotiator extends AltsProtocolNegotiator {
|
||||
/**
|
||||
* ClientAltsProtocolNegotiatorFactory is a factory for doing client side negotiation of an ALTS
|
||||
* channel.
|
||||
*/
|
||||
public static final class ClientAltsProtocolNegotiatorFactory
|
||||
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
|
||||
|
||||
private final ImmutableList<String> targetServiceAccounts;
|
||||
private final LazyChannel lazyHandshakerChannel;
|
||||
|
||||
public ClientAltsProtocolNegotiatorFactory(
|
||||
List<String> targetServiceAccounts,
|
||||
ObjectPool<Channel> handshakerChannelPool) {
|
||||
this.targetServiceAccounts = ImmutableList.copyOf(targetServiceAccounts);
|
||||
this.lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ProtocolNegotiator buildProtocolNegotiator() {
|
||||
return new ClientAltsProtocolNegotiator(
|
||||
new ClientTsiHandshakerFactory(targetServiceAccounts, lazyHandshakerChannel),
|
||||
lazyHandshakerChannel);
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
private static final class ClientAltsProtocolNegotiator implements ProtocolNegotiator {
|
||||
private final TsiHandshakerFactory handshakerFactory;
|
||||
private final LazyChannel lazyHandshakerChannel;
|
||||
|
||||
ClientAltsProtocolNegotiator(
|
||||
TsiHandshakerFactory handshakerFactory, LazyChannel lazyHandshakerChannel) {
|
||||
this.handshakerFactory = checkNotNull(handshakerFactory, "handshakerFactory");
|
||||
this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel");
|
||||
}
|
||||
|
||||
@Override
|
||||
public AsciiString scheme() {
|
||||
return SCHEME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
|
||||
TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority());
|
||||
NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker);
|
||||
ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler);
|
||||
ChannelHandler thh =
|
||||
new TsiHandshakeHandler(gnh, nettyHandshaker, new AltsHandshakeValidator());
|
||||
ChannelHandler wuah = InternalProtocolNegotiators.waitUntilActiveHandler(thh);
|
||||
return wuah;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
lazyHandshakerChannel.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a protocol negotiator for ALTS on the server side.
|
||||
*/
|
||||
public static ProtocolNegotiator serverAltsProtocolNegotiator(
|
||||
ObjectPool<Channel> handshakerChannelPool) {
|
||||
final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
|
||||
final class ServerTsiHandshakerFactory implements TsiHandshakerFactory {
|
||||
|
||||
@Override
|
||||
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
|
||||
public TsiHandshaker newHandshaker(@Nullable String authority) {
|
||||
assert authority == null;
|
||||
return AltsTsiHandshaker.newServer(
|
||||
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()),
|
||||
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()));
|
||||
}
|
||||
}
|
||||
|
||||
return new ServerAltsProtocolNegotiator(
|
||||
new ServerTsiHandshakerFactory(), lazyHandshakerChannel);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static final class ServerAltsProtocolNegotiator implements ProtocolNegotiator {
|
||||
private final TsiHandshakerFactory handshakerFactory;
|
||||
private final LazyChannel lazyHandshakerChannel;
|
||||
|
||||
@VisibleForTesting
|
||||
ServerAltsProtocolNegotiator(
|
||||
TsiHandshakerFactory handshakerFactory, LazyChannel lazyHandshakerChannel) {
|
||||
this.handshakerFactory = checkNotNull(handshakerFactory, "handshakerFactory");
|
||||
this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel");
|
||||
}
|
||||
|
||||
@Override
|
||||
public AsciiString scheme() {
|
||||
return SCHEME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
|
||||
TsiHandshaker handshaker = handshakerFactory.newHandshaker(/* authority= */ null);
|
||||
NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker);
|
||||
ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler);
|
||||
ChannelHandler thh =
|
||||
new TsiHandshakeHandler(gnh, nettyHandshaker, new AltsHandshakeValidator());
|
||||
ChannelHandler wuah = InternalProtocolNegotiators.waitUntilActiveHandler(thh);
|
||||
return wuah;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
logger.finest("ALTS Server ProtocolNegotiator Closed");
|
||||
lazyHandshakerChannel.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A Protocol Negotiator factory which can switch between ALTS and TLS based on EAG Attrs.
|
||||
*/
|
||||
public static final class GoogleDefaultProtocolNegotiatorFactory
|
||||
implements ProtocolNegotiatorFactory {
|
||||
private final ImmutableList<String> targetServiceAccounts;
|
||||
private final LazyChannel lazyHandshakerChannel;
|
||||
private final SslContext sslContext;
|
||||
|
||||
/**
|
||||
* Creates Negotiator Factory, which will either use the targetServiceAccounts and
|
||||
* handshakerChannelPool, or the sslContext.
|
||||
*/
|
||||
public GoogleDefaultProtocolNegotiatorFactory(
|
||||
List<String> targetServiceAccounts,
|
||||
ObjectPool<Channel> handshakerChannelPool,
|
||||
SslContext sslContext) {
|
||||
this.targetServiceAccounts = ImmutableList.copyOf(targetServiceAccounts);
|
||||
this.lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
|
||||
this.sslContext = checkNotNull(sslContext, "sslContext");
|
||||
}
|
||||
|
||||
@Override
|
||||
public ProtocolNegotiator buildProtocolNegotiator() {
|
||||
return new GoogleDefaultProtocolNegotiator(
|
||||
new ClientTsiHandshakerFactory(targetServiceAccounts, lazyHandshakerChannel),
|
||||
lazyHandshakerChannel,
|
||||
sslContext);
|
||||
}
|
||||
}
|
||||
|
||||
private static final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator {
|
||||
private final TsiHandshakerFactory handshakerFactory;
|
||||
private final LazyChannel lazyHandshakerChannel;
|
||||
private final SslContext sslContext;
|
||||
|
||||
GoogleDefaultProtocolNegotiator(
|
||||
TsiHandshakerFactory handshakerFactory,
|
||||
LazyChannel lazyHandshakerChannel,
|
||||
SslContext sslContext) {
|
||||
this.handshakerFactory = checkNotNull(handshakerFactory, "handshakerFactory");
|
||||
this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel");
|
||||
this.sslContext = checkNotNull(sslContext, "checkNotNull");
|
||||
}
|
||||
|
||||
@Override
|
||||
public AsciiString scheme() {
|
||||
return SCHEME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
|
||||
ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler);
|
||||
ChannelHandler securityHandler;
|
||||
if (grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_ADDR_AUTHORITY) != null
|
||||
|| grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_PROVIDED_BACKEND) != null) {
|
||||
TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority());
|
||||
return new BufferUntilAltsNegotiatedHandler(
|
||||
grpcHandler,
|
||||
new TsiHandshakeHandler(new NettyTsiHandshaker(handshaker)),
|
||||
new TsiFrameHandler());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
logger.finest("ALTS Client ProtocolNegotiator Closed");
|
||||
lazyHandshakerChannel.close();
|
||||
NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker);
|
||||
securityHandler =
|
||||
new TsiHandshakeHandler(gnh, nettyHandshaker, new AltsHandshakeValidator());
|
||||
} else {
|
||||
securityHandler = InternalProtocolNegotiators.clientTlsHandler(
|
||||
gnh, sslContext, grpcHandler.getAuthority());
|
||||
}
|
||||
ChannelHandler wuah = InternalProtocolNegotiators.waitUntilActiveHandler(securityHandler);
|
||||
return wuah;
|
||||
}
|
||||
|
||||
return new ClientAltsProtocolNegotiator();
|
||||
@Override
|
||||
public void close() {
|
||||
logger.finest("ALTS Server ProtocolNegotiator Closed");
|
||||
lazyHandshakerChannel.close();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public final AsciiString scheme() {
|
||||
return scheme;
|
||||
}
|
||||
private static final class ClientTsiHandshakerFactory implements TsiHandshakerFactory {
|
||||
|
||||
/** Creates a negotiator used for ALTS server. */
|
||||
public static AltsProtocolNegotiator createServerNegotiator(
|
||||
final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) {
|
||||
final class ServerAltsProtocolNegotiator extends AltsProtocolNegotiator {
|
||||
private final ImmutableList<String> targetServiceAccounts;
|
||||
private final LazyChannel lazyHandshakerChannel;
|
||||
|
||||
@Override
|
||||
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
|
||||
TsiHandshaker handshaker = handshakerFactory.newHandshaker(/*authority=*/ null);
|
||||
return new BufferUntilAltsNegotiatedHandler(
|
||||
grpcHandler,
|
||||
new TsiHandshakeHandler(new NettyTsiHandshaker(handshaker)),
|
||||
new TsiFrameHandler());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
logger.finest("ALTS Server ProtocolNegotiator Closed");
|
||||
lazyHandshakerChannel.close();
|
||||
}
|
||||
ClientTsiHandshakerFactory(
|
||||
ImmutableList<String> targetServiceAccounts, LazyChannel lazyHandshakerChannel) {
|
||||
this.targetServiceAccounts = checkNotNull(targetServiceAccounts, "targetServiceAccounts");
|
||||
this.lazyHandshakerChannel = checkNotNull(lazyHandshakerChannel, "lazyHandshakerChannel");
|
||||
}
|
||||
|
||||
return new ServerAltsProtocolNegotiator();
|
||||
@Override
|
||||
public TsiHandshaker newHandshaker(@Nullable String authority) {
|
||||
AltsClientOptions handshakerOptions =
|
||||
new AltsClientOptions.Builder()
|
||||
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
|
||||
.setTargetServiceAccounts(targetServiceAccounts)
|
||||
.setTargetName(authority)
|
||||
.build();
|
||||
return AltsTsiHandshaker.newClient(
|
||||
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
|
||||
}
|
||||
}
|
||||
|
||||
/** Channel created from a channel pool lazily. */
|
||||
public static class LazyChannel {
|
||||
@VisibleForTesting
|
||||
static final class LazyChannel {
|
||||
private final ObjectPool<Channel> channelPool;
|
||||
private Channel channel;
|
||||
|
||||
public LazyChannel(ObjectPool<Channel> channelPool) {
|
||||
this.channelPool = channelPool;
|
||||
@VisibleForTesting
|
||||
LazyChannel(ObjectPool<Channel> channelPool) {
|
||||
this.channelPool = checkNotNull(channelPool, "channelPool");
|
||||
}
|
||||
|
||||
/**
|
||||
* If channel is null, gets a channel from the channel pool, otherwise, returns the cached
|
||||
* channel.
|
||||
*/
|
||||
public synchronized Channel get() {
|
||||
synchronized Channel get() {
|
||||
if (channel == null) {
|
||||
channel = channelPool.getObject();
|
||||
}
|
||||
|
|
@ -130,87 +296,37 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
|
|||
}
|
||||
|
||||
/** Returns the cached channel to the channel pool. */
|
||||
public synchronized void close() {
|
||||
synchronized void close() {
|
||||
if (channel != null) {
|
||||
channelPool.returnObject(channel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Buffers all writes until the ALTS handshake is complete. */
|
||||
@VisibleForTesting
|
||||
static final class BufferUntilAltsNegotiatedHandler extends AbstractBufferingHandler {
|
||||
|
||||
private final GrpcHttp2ConnectionHandler grpcHandler;
|
||||
|
||||
BufferUntilAltsNegotiatedHandler(
|
||||
GrpcHttp2ConnectionHandler grpcHandler, ChannelHandler... negotiationhandlers) {
|
||||
super(negotiationhandlers);
|
||||
// Save the gRPC handler. The ALTS handler doesn't support buffering before the handshake
|
||||
// completes, so we wait until the handshake was successful before adding the grpc handler.
|
||||
this.grpcHandler = grpcHandler;
|
||||
}
|
||||
|
||||
// TODO: Remove this once https://github.com/grpc/grpc-java/pull/3715 is in.
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
|
||||
logger.log(Level.FINEST, "Exception while buffering for ALTS Negotiation", cause);
|
||||
fail(ctx, cause);
|
||||
ctx.fireExceptionCaught(cause);
|
||||
}
|
||||
private static final class AltsHandshakeValidator extends TsiHandshakeHandler.HandshakeValidator {
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
|
||||
if (logger.isLoggable(Level.FINEST)) {
|
||||
logger.log(Level.FINEST, "User Event triggered while negotiating ALTS", new Object[] {evt});
|
||||
public SecurityDetails validatePeerObject(Object peerObject) throws GeneralSecurityException {
|
||||
AltsAuthContext altsAuthContext = (AltsAuthContext) peerObject;
|
||||
// Checks peer Rpc Protocol Versions in the ALTS auth context. Fails the connection if
|
||||
// Rpc Protocol Versions mismatch.
|
||||
RpcVersionsCheckResult checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersionsUtil.getRpcProtocolVersions(),
|
||||
altsAuthContext.getPeerRpcVersions());
|
||||
if (!checkResult.getResult()) {
|
||||
String errorMessage =
|
||||
"Local Rpc Protocol Versions "
|
||||
+ RpcProtocolVersionsUtil.getRpcProtocolVersions()
|
||||
+ " are not compatible with peer Rpc Protocol Versions "
|
||||
+ altsAuthContext.getPeerRpcVersions();
|
||||
throw Status.UNAVAILABLE.withDescription(errorMessage).asRuntimeException();
|
||||
}
|
||||
if (evt instanceof TsiHandshakeCompletionEvent) {
|
||||
TsiHandshakeCompletionEvent altsEvt = (TsiHandshakeCompletionEvent) evt;
|
||||
if (altsEvt.isSuccess()) {
|
||||
// Add the gRPC handler just before this handler. We only allow the grpcHandler to be
|
||||
// null to support testing. In production, a grpc handler will always be provided.
|
||||
if (grpcHandler != null) {
|
||||
ctx.pipeline().addBefore(ctx.name(), null, grpcHandler);
|
||||
AltsAuthContext altsContext = (AltsAuthContext) altsEvt.context();
|
||||
Preconditions.checkNotNull(altsContext);
|
||||
// Checks peer Rpc Protocol Versions in the ALTS auth context. Fails the connection if
|
||||
// Rpc Protocol Versions mismatch.
|
||||
RpcVersionsCheckResult checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersionsUtil.getRpcProtocolVersions(),
|
||||
altsContext.getPeerRpcVersions());
|
||||
if (!checkResult.getResult()) {
|
||||
String errorMessage =
|
||||
"Local Rpc Protocol Versions "
|
||||
+ RpcProtocolVersionsUtil.getRpcProtocolVersions().toString()
|
||||
+ "are not compatible with peer Rpc Protocol Versions "
|
||||
+ altsContext.getPeerRpcVersions().toString();
|
||||
logger.finest(errorMessage);
|
||||
fail(ctx, Status.UNAVAILABLE.withDescription(errorMessage).asRuntimeException());
|
||||
}
|
||||
grpcHandler.handleProtocolNegotiationCompleted(
|
||||
Attributes.newBuilder()
|
||||
.set(TSI_PEER_KEY, altsEvt.peer())
|
||||
.set(ALTS_CONTEXT_KEY, altsContext)
|
||||
.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress())
|
||||
.set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress())
|
||||
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY)
|
||||
.build(),
|
||||
new Security(new OtherSecurity("alts", Any.pack(altsContext.context))));
|
||||
}
|
||||
logger.finest("Flushing ALTS buffered data");
|
||||
// Now write any buffered data and remove this handler.
|
||||
writeBufferedAndRemove(ctx);
|
||||
} else {
|
||||
logger.log(Level.FINEST, "ALTS handshake failed", altsEvt.cause());
|
||||
fail(ctx, unavailableException("ALTS handshake failed", altsEvt.cause()));
|
||||
}
|
||||
}
|
||||
super.userEventTriggered(ctx, evt);
|
||||
}
|
||||
|
||||
private static RuntimeException unavailableException(String msg, Throwable cause) {
|
||||
return Status.UNAVAILABLE.withCause(cause).withDescription(msg).asRuntimeException();
|
||||
return new SecurityDetails(
|
||||
SecurityLevel.PRIVACY_AND_INTEGRITY,
|
||||
new Security(new OtherSecurity("alts", Any.pack(altsAuthContext.context))));
|
||||
}
|
||||
}
|
||||
|
||||
private AltsProtocolNegotiator() {}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,71 +0,0 @@
|
|||
/*
|
||||
* Copyright 2018 The gRPC Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.internal;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
|
||||
import io.grpc.internal.GrpcAttributes;
|
||||
import io.grpc.netty.GrpcHttp2ConnectionHandler;
|
||||
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
|
||||
import io.grpc.netty.InternalProtocolNegotiators;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.handler.ssl.SslContext;
|
||||
import io.netty.util.AsciiString;
|
||||
|
||||
/** A client-side GPRC {@link ProtocolNegotiator} for Google Default Channel. */
|
||||
public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator {
|
||||
|
||||
private final ProtocolNegotiator altsProtocolNegotiator;
|
||||
private final ProtocolNegotiator tlsProtocolNegotiator;
|
||||
|
||||
/** Constructor for protocol negotiator of Google Default Channel. */
|
||||
public GoogleDefaultProtocolNegotiator(
|
||||
TsiHandshakerFactory altsFactory, LazyChannel lazyHandshakerChannel, SslContext sslContext) {
|
||||
altsProtocolNegotiator =
|
||||
AltsProtocolNegotiator.createClientNegotiator(altsFactory, lazyHandshakerChannel);
|
||||
tlsProtocolNegotiator = InternalProtocolNegotiators.tls(sslContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AsciiString scheme() {
|
||||
assert tlsProtocolNegotiator.scheme().equals(altsProtocolNegotiator.scheme());
|
||||
return tlsProtocolNegotiator.scheme();
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
GoogleDefaultProtocolNegotiator(
|
||||
ProtocolNegotiator altsProtocolNegotiator, ProtocolNegotiator tlsProtocolNegotiator) {
|
||||
this.altsProtocolNegotiator = altsProtocolNegotiator;
|
||||
this.tlsProtocolNegotiator = tlsProtocolNegotiator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
|
||||
if (grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_ADDR_AUTHORITY) != null
|
||||
|| grpcHandler.getEagAttributes().get(GrpcAttributes.ATTR_LB_PROVIDED_BACKEND) != null) {
|
||||
return altsProtocolNegotiator.newHandler(grpcHandler);
|
||||
} else {
|
||||
return tlsProtocolNegotiator.newHandler(grpcHandler);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
altsProtocolNegotiator.close();
|
||||
tlsProtocolNegotiator.close();
|
||||
}
|
||||
}
|
||||
|
|
@ -19,9 +19,7 @@ package io.grpc.alts.internal;
|
|||
import static com.google.common.base.Preconditions.checkNotNull;
|
||||
import static com.google.common.base.Preconditions.checkState;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.alts.internal.TsiFrameProtector.Consumer;
|
||||
import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelException;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
|
|
@ -33,7 +31,6 @@ import java.net.SocketAddress;
|
|||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
|
|
@ -47,72 +44,33 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
|||
|
||||
private TsiFrameProtector protector;
|
||||
private PendingWriteQueue pendingUnprotectedWrites;
|
||||
private State state = State.HANDSHAKE_NOT_FINISHED;
|
||||
private boolean closeInitiated = false;
|
||||
private boolean closeInitiated;
|
||||
|
||||
@VisibleForTesting
|
||||
enum State {
|
||||
HANDSHAKE_NOT_FINISHED,
|
||||
PROTECTED,
|
||||
CLOSED,
|
||||
HANDSHAKE_FAILED
|
||||
public TsiFrameHandler(TsiFrameProtector protector) {
|
||||
this.protector = checkNotNull(protector, "protector");
|
||||
}
|
||||
|
||||
public TsiFrameHandler() {}
|
||||
|
||||
@Override
|
||||
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
|
||||
logger.finest("TsiFrameHandler added");
|
||||
super.handlerAdded(ctx);
|
||||
assert pendingUnprotectedWrites == null;
|
||||
pendingUnprotectedWrites = new PendingWriteQueue(checkNotNull(ctx));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(ChannelHandlerContext ctx, Object event) throws Exception {
|
||||
if (logger.isLoggable(Level.FINEST)) {
|
||||
logger.log(Level.FINEST, "TsiFrameHandler user event triggered", new Object[]{event});
|
||||
}
|
||||
if (event instanceof TsiHandshakeCompletionEvent) {
|
||||
TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event;
|
||||
if (tsiEvent.isSuccess()) {
|
||||
setProtector(tsiEvent.protector());
|
||||
} else {
|
||||
state = State.HANDSHAKE_FAILED;
|
||||
}
|
||||
// Ignore errors. Another handler in the pipeline must handle TSI Errors.
|
||||
}
|
||||
// Keep propagating the message, as others may want to read it.
|
||||
super.userEventTriggered(ctx, event);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
void setProtector(TsiFrameProtector protector) {
|
||||
logger.finest("TsiFrameHandler protector set");
|
||||
checkState(this.protector == null);
|
||||
this.protector = checkNotNull(protector);
|
||||
this.state = State.PROTECTED;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
|
||||
checkState(
|
||||
state == State.PROTECTED,
|
||||
"Cannot read frames while the TSI handshake is %s", state);
|
||||
checkState(protector != null, "decode() called after close()");
|
||||
protector.unprotect(in, out, ctx.alloc());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise)
|
||||
throws Exception {
|
||||
checkState(
|
||||
state == State.PROTECTED,
|
||||
"Cannot write frames while the TSI handshake state is %s", state);
|
||||
@SuppressWarnings("FutureReturnValueIgnored") // for setSuccess
|
||||
public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) {
|
||||
checkState(protector != null, "write() called after close()");
|
||||
ByteBuf msg = (ByteBuf) message;
|
||||
if (!msg.isReadable()) {
|
||||
// Nothing to encode.
|
||||
@SuppressWarnings("unused") // go/futurereturn-lsc
|
||||
Future<?> possiblyIgnoredError = promise.setSuccess();
|
||||
promise.setSuccess();
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -122,30 +80,11 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
|||
|
||||
@Override
|
||||
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
||||
if (!pendingUnprotectedWrites.isEmpty()) {
|
||||
if (pendingUnprotectedWrites != null && !pendingUnprotectedWrites.isEmpty()) {
|
||||
pendingUnprotectedWrites.removeAndFailAll(
|
||||
new ChannelException("Pending write on removal of TSI handler"));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
|
||||
pendingUnprotectedWrites.removeAndFailAll(cause);
|
||||
super.exceptionCaught(ctx, cause);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
|
||||
ctx.bind(localAddress, promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void connect(
|
||||
ChannelHandlerContext ctx,
|
||||
SocketAddress remoteAddress,
|
||||
SocketAddress localAddress,
|
||||
ChannelPromise promise) {
|
||||
ctx.connect(remoteAddress, localAddress, promise);
|
||||
destroyProtector();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -154,6 +93,12 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
|||
ctx.disconnect(promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||
doClose(ctx);
|
||||
ctx.close(promise);
|
||||
}
|
||||
|
||||
private void doClose(ChannelHandlerContext ctx) {
|
||||
if (closeInitiated) {
|
||||
return;
|
||||
|
|
@ -165,51 +110,34 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
|||
flush(ctx);
|
||||
}
|
||||
} catch (GeneralSecurityException e) {
|
||||
logger.log(Level.FINE, "Ignoring error on flush before close", e);
|
||||
logger.log(Level.FINE, "Ignored error on flush before close", e);
|
||||
} finally {
|
||||
state = State.CLOSED;
|
||||
release();
|
||||
pendingUnprotectedWrites = null;
|
||||
destroyProtector();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||
doClose(ctx);
|
||||
ctx.close(promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||
doClose(ctx);
|
||||
ctx.deregister(promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void read(ChannelHandlerContext ctx) {
|
||||
ctx.read();
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("FutureReturnValueIgnored") // for aggregatePromise.doneAllocatingPromises
|
||||
public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException {
|
||||
if (state == State.CLOSED || state == State.HANDSHAKE_FAILED) {
|
||||
logger.fine(
|
||||
String.format("FrameHandler is inactive(%s), channel id: %s",
|
||||
state, ctx.channel().id().asShortText()));
|
||||
if (protector == null) {
|
||||
// TODO(carl-mastrangelo): this should be a checkState. AbstractNettyHandler.exceptionCaught
|
||||
// transitively calls flush even after closed, for some reason.
|
||||
pendingUnprotectedWrites.removeAndFailAll(
|
||||
new ChannelException("Pending write on removal of TSI handler"));
|
||||
logger.fine("flush() called after close()");
|
||||
return;
|
||||
}
|
||||
checkState(
|
||||
state == State.PROTECTED, "Cannot write frames while the TSI handshake state is %s", state);
|
||||
final ProtectedPromise aggregatePromise =
|
||||
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
|
||||
|
||||
List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
|
||||
|
||||
if (pendingUnprotectedWrites.isEmpty()) {
|
||||
// Return early if there's nothing to write. Otherwise protector.protectFlush() below may
|
||||
// not check for "no-data" and go on writing the 0-byte "data" to the socket with the
|
||||
// protection framing.
|
||||
return;
|
||||
}
|
||||
final ProtectedPromise aggregatePromise =
|
||||
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
|
||||
List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
|
||||
|
||||
// Drain the unprotected writes.
|
||||
while (!pendingUnprotectedWrites.isEmpty()) {
|
||||
ByteBuf in = (ByteBuf) pendingUnprotectedWrites.current();
|
||||
|
|
@ -218,25 +146,54 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
|||
aggregatePromise.addUnprotectedPromise(pendingUnprotectedWrites.remove());
|
||||
}
|
||||
|
||||
protector.protectFlush(
|
||||
bufs,
|
||||
new Consumer<ByteBuf>() {
|
||||
@Override
|
||||
public void accept(ByteBuf b) {
|
||||
ctx.writeAndFlush(b, aggregatePromise.newPromise());
|
||||
}
|
||||
},
|
||||
ctx.alloc());
|
||||
final class ProtectedFrameWriteFlusher implements Consumer<ByteBuf> {
|
||||
|
||||
@Override
|
||||
public void accept(ByteBuf byteBuf) {
|
||||
ctx.writeAndFlush(byteBuf, aggregatePromise.newPromise());
|
||||
}
|
||||
}
|
||||
|
||||
protector.protectFlush(bufs, new ProtectedFrameWriteFlusher(), ctx.alloc());
|
||||
// We're done writing, start the flow of promise events.
|
||||
@SuppressWarnings("unused") // go/futurereturn-lsc
|
||||
Future<?> possiblyIgnoredError = aggregatePromise.doneAllocatingPromises();
|
||||
aggregatePromise.doneAllocatingPromises();
|
||||
}
|
||||
|
||||
private void release() {
|
||||
// Only here to fulfill ChannelOutboundHandler
|
||||
@Override
|
||||
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
|
||||
ctx.bind(localAddress, promise);
|
||||
}
|
||||
|
||||
// Only here to fulfill ChannelOutboundHandler
|
||||
@Override
|
||||
public void connect(
|
||||
ChannelHandlerContext ctx,
|
||||
SocketAddress remoteAddress,
|
||||
SocketAddress localAddress,
|
||||
ChannelPromise promise) {
|
||||
ctx.connect(remoteAddress, localAddress, promise);
|
||||
}
|
||||
|
||||
// Only here to fulfill ChannelOutboundHandler
|
||||
@Override
|
||||
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||
ctx.deregister(promise);
|
||||
}
|
||||
|
||||
// Only here to fulfill ChannelOutboundHandler
|
||||
@Override
|
||||
public void read(ChannelHandlerContext ctx) {
|
||||
ctx.read();
|
||||
}
|
||||
|
||||
private void destroyProtector() {
|
||||
if (protector != null) {
|
||||
protector.destroy();
|
||||
protector = null;
|
||||
try {
|
||||
protector.destroy();
|
||||
} finally {
|
||||
protector = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,19 +17,25 @@
|
|||
package io.grpc.alts.internal;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkNotNull;
|
||||
import static io.grpc.alts.internal.AltsProtocolNegotiator.AUTH_CONTEXT_KEY;
|
||||
import static io.grpc.alts.internal.AltsProtocolNegotiator.TSI_PEER_KEY;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.MoreObjects;
|
||||
import io.grpc.Attributes;
|
||||
import io.grpc.ChannelLogger.ChannelLogLevel;
|
||||
import io.grpc.InternalChannelz.Security;
|
||||
import io.grpc.SecurityLevel;
|
||||
import io.grpc.alts.internal.TsiHandshakeHandler.HandshakeValidator.SecurityDetails;
|
||||
import io.grpc.internal.GrpcAttributes;
|
||||
import io.grpc.netty.InternalProtocolNegotiationEvent;
|
||||
import io.grpc.netty.InternalProtocolNegotiators;
|
||||
import io.grpc.netty.ProtocolNegotiationEvent;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.handler.codec.ByteToMessageDecoder;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
|
|
@ -38,118 +44,63 @@ import javax.annotation.Nullable;
|
|||
*/
|
||||
public final class TsiHandshakeHandler extends ByteToMessageDecoder {
|
||||
|
||||
private static final Logger logger = Logger.getLogger(TsiHandshakeHandler.class.getName());
|
||||
/**
|
||||
* Validates a Tsi Peer object.
|
||||
*/
|
||||
public abstract static class HandshakeValidator {
|
||||
|
||||
public static final class SecurityDetails {
|
||||
|
||||
private final SecurityLevel securityLevel;
|
||||
private final Security security;
|
||||
|
||||
/**
|
||||
* Constructs SecurityDetails.
|
||||
*/
|
||||
public SecurityDetails(io.grpc.SecurityLevel securityLevel, @Nullable Security security) {
|
||||
this.securityLevel = checkNotNull(securityLevel, "securityLevel");
|
||||
this.security = security;
|
||||
}
|
||||
|
||||
public Security getSecurity() {
|
||||
return security;
|
||||
}
|
||||
|
||||
public SecurityLevel getSecurityLevel() {
|
||||
return securityLevel;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates a Tsi Peer object.
|
||||
*/
|
||||
public abstract SecurityDetails validatePeerObject(Object peerObject)
|
||||
throws GeneralSecurityException;
|
||||
}
|
||||
|
||||
private static final int HANDSHAKE_FRAME_SIZE = 1024;
|
||||
|
||||
private final NettyTsiHandshaker handshaker;
|
||||
private boolean started;
|
||||
private final HandshakeValidator handshakeValidator;
|
||||
private final ChannelHandler next;
|
||||
|
||||
private ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault();
|
||||
|
||||
/**
|
||||
* This buffer doesn't store any state. We just hold onto it in case we end up allocating a buffer
|
||||
* that ends up being unused.
|
||||
* Constructs a TsiHandshakeHandler.
|
||||
*/
|
||||
private ByteBuf buffer;
|
||||
|
||||
public TsiHandshakeHandler(NettyTsiHandshaker handshaker) {
|
||||
this.handshaker = checkNotNull(handshaker);
|
||||
}
|
||||
|
||||
/**
|
||||
* Event that is fired once the TSI handshake is complete, which may be because it was successful
|
||||
* or there was an error.
|
||||
*/
|
||||
public static final class TsiHandshakeCompletionEvent {
|
||||
|
||||
private final Throwable cause;
|
||||
private final TsiPeer peer;
|
||||
private final Object context;
|
||||
private final TsiFrameProtector protector;
|
||||
|
||||
/** Creates a new event that indicates a successful handshake. */
|
||||
@VisibleForTesting
|
||||
TsiHandshakeCompletionEvent(
|
||||
TsiFrameProtector protector, TsiPeer peer, @Nullable Object peerObject) {
|
||||
this.cause = null;
|
||||
this.peer = checkNotNull(peer);
|
||||
this.protector = checkNotNull(protector);
|
||||
this.context = peerObject;
|
||||
}
|
||||
|
||||
/** Creates a new event that indicates an unsuccessful handshake/. */
|
||||
TsiHandshakeCompletionEvent(Throwable cause) {
|
||||
this.cause = checkNotNull(cause);
|
||||
this.peer = null;
|
||||
this.protector = null;
|
||||
this.context = null;
|
||||
}
|
||||
|
||||
/** Return {@code true} if the handshake was successful. */
|
||||
public boolean isSuccess() {
|
||||
return cause == null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the {@link Throwable} if {@link #isSuccess()} returns {@code false} and so the
|
||||
* handshake failed.
|
||||
*/
|
||||
@Nullable
|
||||
public Throwable cause() {
|
||||
return cause;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public TsiPeer peer() {
|
||||
return peer;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public Object context() {
|
||||
return context;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
TsiFrameProtector protector() {
|
||||
return protector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return MoreObjects.toStringHelper(this)
|
||||
.add("peer", peer)
|
||||
.add("protector", protector)
|
||||
.add("context", context)
|
||||
.add("cause", cause)
|
||||
.toString();
|
||||
}
|
||||
public TsiHandshakeHandler(
|
||||
ChannelHandler next, NettyTsiHandshaker handshaker, HandshakeValidator handshakeValidator) {
|
||||
this.handshaker = checkNotNull(handshaker, "handshaker");
|
||||
this.handshakeValidator = checkNotNull(handshakeValidator, "handshakeValidator");
|
||||
this.next = checkNotNull(next, "next");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
|
||||
logger.finest("TsiHandshakeHandler added");
|
||||
maybeStart(ctx);
|
||||
super.handlerAdded(ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelActive(ChannelHandlerContext ctx) throws Exception {
|
||||
logger.finest("TsiHandshakeHandler channel active");
|
||||
maybeStart(ctx);
|
||||
super.channelActive(ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
||||
logger.finest("TsiHandshakeHandler handler removed");
|
||||
close();
|
||||
super.handlerRemoved0(ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
|
||||
logger.log(Level.FINEST, "Exception in TsiHandshakeHandler", cause);
|
||||
ctx.fireUserEventTriggered(new TsiHandshakeCompletionEvent(cause));
|
||||
super.exceptionCaught(ctx, cause);
|
||||
InternalProtocolNegotiators.negotiationLogger(ctx)
|
||||
.log(ChannelLogLevel.INFO, "TsiHandshake started");
|
||||
sendHandshake(ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -168,71 +119,72 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder {
|
|||
|
||||
// If the handshake is complete, transition to the framing state.
|
||||
if (!handshaker.isInProgress()) {
|
||||
TsiFrameProtector protector = null;
|
||||
TsiPeer peer = handshaker.extractPeer();
|
||||
Object authContext = handshaker.extractPeerObject();
|
||||
SecurityDetails details = handshakeValidator.validatePeerObject(authContext);
|
||||
// createFrameProtector must be called last.
|
||||
TsiFrameProtector protector = handshaker.createFrameProtector(ctx.alloc());
|
||||
TsiFrameHandler framer;
|
||||
boolean success = false;
|
||||
try {
|
||||
ctx.pipeline().remove(this);
|
||||
protector = handshaker.createFrameProtector(ctx.alloc());
|
||||
TsiHandshakeCompletionEvent evt = new TsiHandshakeCompletionEvent(
|
||||
protector,
|
||||
handshaker.extractPeer(),
|
||||
handshaker.extractPeerObject());
|
||||
protector = null;
|
||||
ctx.fireUserEventTriggered(evt);
|
||||
// No need to do anything with the in buffer, it will be re added to the pipeline when this
|
||||
// handler is removed.
|
||||
framer = new TsiFrameHandler(protector);
|
||||
// replace the current handler with the framer (instead of adding before) since there may
|
||||
// be pending data after the handshake frame. The data will need to be decoded before
|
||||
// being passed to the `next` handler.
|
||||
ctx.pipeline().replace(ctx.name(), null, framer);
|
||||
// Once the framer is in the pipeline, it will be cleaned up when the handler is removed.
|
||||
success = true;
|
||||
} finally {
|
||||
if (protector != null) {
|
||||
if (!success && protector != null) {
|
||||
protector.destroy();
|
||||
}
|
||||
close();
|
||||
}
|
||||
// Add the `next` handler as late as possible, as it will issue writes on being added.
|
||||
ctx.pipeline().addAfter(ctx.pipeline().context(framer).name(), null, next);
|
||||
fireProtocolNegotiationEvent(ctx, peer, authContext, details);
|
||||
}
|
||||
}
|
||||
|
||||
private void maybeStart(ChannelHandlerContext ctx) {
|
||||
if (!started && ctx.channel().isActive()) {
|
||||
started = true;
|
||||
sendHandshake(ctx);
|
||||
@Override
|
||||
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
|
||||
if (evt instanceof ProtocolNegotiationEvent) {
|
||||
pne = (ProtocolNegotiationEvent) evt;
|
||||
} else {
|
||||
super.userEventTriggered(ctx, evt);
|
||||
}
|
||||
}
|
||||
|
||||
private void fireProtocolNegotiationEvent(
|
||||
ChannelHandlerContext ctx, TsiPeer peer, Object authContext, SecurityDetails details) {
|
||||
InternalProtocolNegotiators.negotiationLogger(ctx)
|
||||
.log(ChannelLogLevel.INFO, "TsiHandshake finished");
|
||||
ProtocolNegotiationEvent localPne = pne;
|
||||
Attributes.Builder attrs = InternalProtocolNegotiationEvent.getAttributes(localPne).toBuilder()
|
||||
.set(TSI_PEER_KEY, peer)
|
||||
.set(AUTH_CONTEXT_KEY, authContext)
|
||||
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, details.getSecurityLevel());
|
||||
localPne = InternalProtocolNegotiationEvent.withAttributes(localPne, attrs.build());
|
||||
localPne = InternalProtocolNegotiationEvent.withSecurity(localPne, details.getSecurity());
|
||||
ctx.fireUserEventTriggered(localPne);
|
||||
}
|
||||
|
||||
/** Sends as many bytes as are available from the handshaker to the remote peer. */
|
||||
private void sendHandshake(ChannelHandlerContext ctx) {
|
||||
boolean needToFlush = false;
|
||||
|
||||
// Iterate until there is nothing left to write.
|
||||
@SuppressWarnings("FutureReturnValueIgnored") // for addListener
|
||||
private void sendHandshake(ChannelHandlerContext ctx) throws GeneralSecurityException {
|
||||
while (true) {
|
||||
buffer = getOrCreateBuffer(ctx.alloc());
|
||||
boolean written = false;
|
||||
ByteBuf buf = ctx.alloc().buffer(HANDSHAKE_FRAME_SIZE).retain(); // refcnt = 2
|
||||
try {
|
||||
handshaker.getBytesToSendToPeer(buffer);
|
||||
} catch (GeneralSecurityException e) {
|
||||
throw new RuntimeException(e);
|
||||
handshaker.getBytesToSendToPeer(buf);
|
||||
if (buf.isReadable()) {
|
||||
ctx.writeAndFlush(buf).addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
written = true;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} finally {
|
||||
buf.release(written ? 1 : 2);
|
||||
}
|
||||
if (!buffer.isReadable()) {
|
||||
break;
|
||||
}
|
||||
|
||||
needToFlush = true;
|
||||
@SuppressWarnings("unused") // go/futurereturn-lsc
|
||||
Future<?> possiblyIgnoredError = ctx.write(buffer);
|
||||
buffer = null;
|
||||
}
|
||||
|
||||
// If something was written, flush.
|
||||
if (needToFlush) {
|
||||
ctx.flush();
|
||||
}
|
||||
}
|
||||
|
||||
private ByteBuf getOrCreateBuffer(ByteBufAllocator alloc) {
|
||||
if (buffer == null) {
|
||||
buffer = alloc.buffer(HANDSHAKE_FRAME_SIZE);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
private void close() {
|
||||
ReferenceCountUtil.safeRelease(buffer);
|
||||
buffer = null;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ package io.grpc.alts;
|
|||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator;
|
||||
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
|
@ -28,17 +27,14 @@ import org.junit.runners.JUnit4;
|
|||
public final class AltsChannelBuilderTest {
|
||||
|
||||
@Test
|
||||
public void buildsNettyChannel() throws Exception {
|
||||
public void buildsNettyChannel() {
|
||||
AltsChannelBuilder builder =
|
||||
AltsChannelBuilder.forTarget("localhost:8080").enableUntrustedAltsForTesting();
|
||||
|
||||
ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest();
|
||||
assertThat(protocolNegotiator).isNull();
|
||||
|
||||
builder.build();
|
||||
|
||||
protocolNegotiator = builder.getProtocolNegotiatorForTest();
|
||||
assertThat(protocolNegotiator).isNotNull();
|
||||
assertThat(protocolNegotiator).isInstanceOf(AltsProtocolNegotiator.class);
|
||||
// Avoids exposing this class
|
||||
assertThat(protocolNegotiator.getClass().getSimpleName())
|
||||
.isEqualTo("ClientAltsProtocolNegotiator");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ package io.grpc.alts;
|
|||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator;
|
||||
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
|
@ -33,6 +32,7 @@ public final class ComputeEngineChannelBuilderTest {
|
|||
builder.build();
|
||||
|
||||
ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest();
|
||||
assertThat(protocolNegotiator).isInstanceOf(GoogleDefaultProtocolNegotiator.class);
|
||||
assertThat(protocolNegotiator.getClass().getSimpleName())
|
||||
.isEqualTo("GoogleDefaultProtocolNegotiator");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ package io.grpc.alts;
|
|||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator;
|
||||
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
|
@ -33,6 +32,7 @@ public final class GoogleDefaultChannelBuilderTest {
|
|||
builder.build();
|
||||
|
||||
ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest();
|
||||
assertThat(protocolNegotiator).isInstanceOf(GoogleDefaultProtocolNegotiator.class);
|
||||
assertThat(protocolNegotiator.getClass().getSimpleName())
|
||||
.isEqualTo("GoogleDefaultProtocolNegotiator");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ import io.grpc.Attributes;
|
|||
import io.grpc.Channel;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.InternalChannelz;
|
||||
import io.grpc.InternalChannelz.Security;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.SecurityLevel;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
|
||||
|
|
@ -80,6 +81,7 @@ import org.junit.runners.JUnit4;
|
|||
|
||||
/** Tests for {@link AltsProtocolNegotiator}. */
|
||||
@RunWith(JUnit4.class)
|
||||
@SuppressWarnings("FutureReturnValueIgnored")
|
||||
public class AltsProtocolNegotiatorTest {
|
||||
|
||||
private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler();
|
||||
|
|
@ -90,7 +92,6 @@ public class AltsProtocolNegotiatorTest {
|
|||
private EmbeddedChannel channel;
|
||||
private Throwable caughtException;
|
||||
|
||||
private volatile TsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent;
|
||||
private ChannelHandler handler;
|
||||
|
||||
private TsiPeer mockedTsiPeer = new TsiPeer(Collections.<Property<?>>emptyList());
|
||||
|
|
@ -102,12 +103,12 @@ public class AltsProtocolNegotiatorTest {
|
|||
private final TsiHandshaker mockHandshaker =
|
||||
new DelegatingTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer()) {
|
||||
@Override
|
||||
public TsiPeer extractPeer() throws GeneralSecurityException {
|
||||
public TsiPeer extractPeer() {
|
||||
return mockedTsiPeer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object extractPeerObject() throws GeneralSecurityException {
|
||||
public Object extractPeerObject() {
|
||||
return mockedAltsContext;
|
||||
}
|
||||
};
|
||||
|
|
@ -115,24 +116,13 @@ public class AltsProtocolNegotiatorTest {
|
|||
|
||||
@Before
|
||||
public void setup() throws Exception {
|
||||
ChannelHandler userEventHandler =
|
||||
new ChannelDuplexHandler() {
|
||||
@Override
|
||||
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
|
||||
if (evt instanceof TsiHandshakeHandler.TsiHandshakeCompletionEvent) {
|
||||
tsiEvent = (TsiHandshakeHandler.TsiHandshakeCompletionEvent) evt;
|
||||
} else {
|
||||
super.userEventTriggered(ctx, evt);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ChannelHandler uncaughtExceptionHandler =
|
||||
new ChannelDuplexHandler() {
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
|
||||
caughtException = cause;
|
||||
super.exceptionCaught(ctx, cause);
|
||||
ctx.close();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -157,9 +147,9 @@ public class AltsProtocolNegotiatorTest {
|
|||
ObjectPool<Channel> fakeChannelPool = new FixedObjectPool<Channel>(fakeChannel);
|
||||
LazyChannel lazyFakeChannel = new LazyChannel(fakeChannelPool);
|
||||
handler =
|
||||
AltsProtocolNegotiator.createServerNegotiator(handshakerFactory, lazyFakeChannel)
|
||||
new AltsProtocolNegotiator.ServerAltsProtocolNegotiator(handshakerFactory, lazyFakeChannel)
|
||||
.newHandler(grpcHandler);
|
||||
channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler);
|
||||
channel = new EmbeddedChannel(uncaughtExceptionHandler, handler);
|
||||
}
|
||||
|
||||
@After
|
||||
|
|
@ -182,6 +172,8 @@ public class AltsProtocolNegotiatorTest {
|
|||
@Test
|
||||
@SuppressWarnings("unchecked") // List cast
|
||||
public void protectShouldRoundtrip() throws Exception {
|
||||
doHandshake();
|
||||
|
||||
// Write the message 1 character at a time. The message should be buffered
|
||||
// and not interfere with the handshake.
|
||||
final AtomicInteger writeCount = new AtomicInteger();
|
||||
|
|
@ -204,10 +196,6 @@ public class AltsProtocolNegotiatorTest {
|
|||
}
|
||||
channel.flush();
|
||||
|
||||
// Now do the handshake. The buffered message will automatically be protected
|
||||
// and sent.
|
||||
doHandshake();
|
||||
|
||||
// Capture the protected data written to the wire.
|
||||
assertEquals(1, channel.outboundMessages().size());
|
||||
ByteBuf protectedData = channel.readOutbound();
|
||||
|
|
@ -351,7 +339,7 @@ public class AltsProtocolNegotiatorTest {
|
|||
doHandshake();
|
||||
|
||||
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.TSI_PEER_KEY)).isEqualTo(mockedTsiPeer);
|
||||
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.ALTS_CONTEXT_KEY))
|
||||
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY))
|
||||
.isEqualTo(mockedAltsContext);
|
||||
assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString())
|
||||
.isEqualTo("embedded");
|
||||
|
|
@ -388,7 +376,7 @@ public class AltsProtocolNegotiatorTest {
|
|||
if (caughtException != null) {
|
||||
throw new RuntimeException(caughtException);
|
||||
}
|
||||
assertNotNull(tsiEvent);
|
||||
assertNotNull(grpcHandler.attrs);
|
||||
}
|
||||
|
||||
private CapturingGrpcHttp2ConnectionHandler capturingGrpcHandler() {
|
||||
|
|
@ -408,6 +396,7 @@ public class AltsProtocolNegotiatorTest {
|
|||
private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
|
||||
|
||||
private Attributes attrs;
|
||||
private Security securityInfo;
|
||||
|
||||
private CapturingGrpcHttp2ConnectionHandler(
|
||||
Http2ConnectionDecoder decoder,
|
||||
|
|
@ -422,6 +411,7 @@ public class AltsProtocolNegotiatorTest {
|
|||
// If we are added to the pipeline, we need to remove ourselves. The HTTP2 handler
|
||||
channel.pipeline().remove(this);
|
||||
this.attrs = attrs;
|
||||
this.securityInfo = securityInfo;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,16 +16,27 @@
|
|||
|
||||
package io.grpc.alts.internal;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import io.grpc.Attributes;
|
||||
import io.grpc.Channel;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.inprocess.InProcessChannelBuilder;
|
||||
import io.grpc.internal.GrpcAttributes;
|
||||
import io.grpc.internal.ObjectPool;
|
||||
import io.grpc.netty.GrpcHttp2ConnectionHandler;
|
||||
import io.grpc.netty.GrpcSslContexts;
|
||||
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.ssl.SslContext;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
|
@ -33,16 +44,36 @@ import org.junit.runners.JUnit4;
|
|||
|
||||
@RunWith(JUnit4.class)
|
||||
public final class GoogleDefaultProtocolNegotiatorTest {
|
||||
private ProtocolNegotiator altsProtocolNegotiator;
|
||||
private ProtocolNegotiator tlsProtocolNegotiator;
|
||||
private GoogleDefaultProtocolNegotiator googleProtocolNegotiator;
|
||||
private ProtocolNegotiator googleProtocolNegotiator;
|
||||
|
||||
private final ObjectPool<Channel> handshakerChannelPool = new ObjectPool<Channel>() {
|
||||
|
||||
@Override
|
||||
public Channel getObject() {
|
||||
return InProcessChannelBuilder.forName("test").build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Channel returnObject(Object object) {
|
||||
((ManagedChannel) object).shutdownNow();
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
altsProtocolNegotiator = mock(ProtocolNegotiator.class);
|
||||
tlsProtocolNegotiator = mock(ProtocolNegotiator.class);
|
||||
googleProtocolNegotiator =
|
||||
new GoogleDefaultProtocolNegotiator(altsProtocolNegotiator, tlsProtocolNegotiator);
|
||||
public void setUp() throws Exception {
|
||||
SslContext sslContext = GrpcSslContexts.forClient().build();
|
||||
|
||||
googleProtocolNegotiator = new AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory(
|
||||
ImmutableList.<String>of(),
|
||||
handshakerChannelPool,
|
||||
sslContext)
|
||||
.buildProtocolNegotiator();
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
googleProtocolNegotiator.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -51,9 +82,24 @@ public final class GoogleDefaultProtocolNegotiatorTest {
|
|||
Attributes.newBuilder().set(GrpcAttributes.ATTR_LB_PROVIDED_BACKEND, true).build();
|
||||
GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class);
|
||||
when(mockHandler.getEagAttributes()).thenReturn(eagAttributes);
|
||||
googleProtocolNegotiator.newHandler(mockHandler);
|
||||
verify(altsProtocolNegotiator, times(1)).newHandler(mockHandler);
|
||||
verify(tlsProtocolNegotiator, never()).newHandler(mockHandler);
|
||||
|
||||
final AtomicReference<Throwable> failure = new AtomicReference<>();
|
||||
ChannelHandler exceptionCaught = new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
|
||||
failure.set(cause);
|
||||
super.exceptionCaught(ctx, cause);
|
||||
}
|
||||
};
|
||||
ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler);
|
||||
EmbeddedChannel chan = new EmbeddedChannel(exceptionCaught);
|
||||
// Add the negotiator handler last, but to the front. Putting this in ctor above would make it
|
||||
// throw early.
|
||||
chan.pipeline().addFirst(h);
|
||||
|
||||
// Check that the message complained about the ALTS code, rather than SSL. ALTS throws on
|
||||
// being added, so it's hard to catch it at the right time to make this assertion.
|
||||
assertThat(failure.get()).hasMessageThat().contains("TsiHandshakeHandler");
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -61,8 +107,11 @@ public final class GoogleDefaultProtocolNegotiatorTest {
|
|||
Attributes eagAttributes = Attributes.EMPTY;
|
||||
GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class);
|
||||
when(mockHandler.getEagAttributes()).thenReturn(eagAttributes);
|
||||
googleProtocolNegotiator.newHandler(mockHandler);
|
||||
verify(altsProtocolNegotiator, never()).newHandler(mockHandler);
|
||||
verify(tlsProtocolNegotiator, times(1)).newHandler(mockHandler);
|
||||
when(mockHandler.getAuthority()).thenReturn("authority");
|
||||
|
||||
ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler);
|
||||
EmbeddedChannel chan = new EmbeddedChannel(h);
|
||||
|
||||
assertThat(chan.pipeline().first().getClass().getSimpleName()).isEqualTo("SslHandler");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,18 +18,13 @@ package io.grpc.alts.internal;
|
|||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static com.google.common.truth.Truth.assertWithMessage;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import io.grpc.alts.internal.TsiFrameHandler.State;
|
||||
import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
|
||||
import io.grpc.alts.internal.TsiPeer.Property;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.util.CharsetUtil;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
|
|
@ -46,32 +41,17 @@ public class TsiFrameHandlerTest {
|
|||
@Rule
|
||||
public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(5));
|
||||
|
||||
private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler();
|
||||
private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler(new IdentityFrameProtector());
|
||||
private final EmbeddedChannel channel = new EmbeddedChannel(tsiFrameHandler);
|
||||
|
||||
@Test
|
||||
public void writeAndFlush_beforeHandshakeEventShouldBeIgnored() {
|
||||
ByteBuf msg = Unpooled.copiedBuffer("message before handshake finished", CharsetUtil.UTF_8);
|
||||
|
||||
channel.writeAndFlush(msg);
|
||||
|
||||
assertThat(channel.outboundMessages()).isEmpty();
|
||||
try {
|
||||
channel.checkException();
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
assertThat(e).hasMessageThat().contains(State.HANDSHAKE_NOT_FINISHED.name());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void writeAndFlush_handshakeSucceed() throws InterruptedException {
|
||||
channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent());
|
||||
ByteBuf msg = Unpooled.copiedBuffer("message after handshake finished", CharsetUtil.UTF_8);
|
||||
|
||||
channel.writeAndFlush(msg);
|
||||
Object actual = channel.readOutbound();
|
||||
|
||||
assertThat((Object) channel.readOutbound()).isEqualTo(msg);
|
||||
assertThat(actual).isEqualTo(msg);
|
||||
channel.close().sync();
|
||||
channel.checkException();
|
||||
}
|
||||
|
|
@ -92,40 +72,20 @@ public class TsiFrameHandlerTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void writeAndFlush_handshakeFailed() throws InterruptedException {
|
||||
channel.pipeline().fireUserEventTriggered(new TsiHandshakeCompletionEvent(new Exception()));
|
||||
ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8);
|
||||
|
||||
channel.writeAndFlush(msg);
|
||||
|
||||
assertThat(channel.outboundMessages()).isEmpty();
|
||||
channel.close().sync();
|
||||
channel.checkException();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void close_shouldFlushRemainingMessage() throws InterruptedException {
|
||||
channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent());
|
||||
|
||||
ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8);
|
||||
channel.write(msg);
|
||||
|
||||
assertThat(channel.outboundMessages()).isEmpty();
|
||||
|
||||
channel.close().sync();
|
||||
Object actual = channel.readOutbound();
|
||||
|
||||
assertWithMessage("pending write should be flushed on close")
|
||||
.that((Object) channel.readOutbound()).isEqualTo(msg);
|
||||
assertWithMessage("pending write should be flushed on close").that(actual).isEqualTo(msg);
|
||||
channel.checkException();
|
||||
}
|
||||
|
||||
private TsiHandshakeCompletionEvent getHandshakeSuccessEvent() {
|
||||
TsiFrameProtector protector = new IdentityFrameProtector();
|
||||
TsiPeer peer = new TsiPeer(new ArrayList<Property<?>>());
|
||||
return new TsiHandshakeCompletionEvent(protector, peer, new Object());
|
||||
}
|
||||
|
||||
private static final class IdentityFrameProtector implements TsiFrameProtector {
|
||||
|
||||
@Override
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
package io.grpc.netty;
|
||||
|
||||
import io.grpc.ChannelLogger;
|
||||
import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler;
|
||||
import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler;
|
||||
import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
|
|
@ -93,4 +94,9 @@ public final class InternalProtocolNegotiators {
|
|||
public static ChannelHandler grpcNegotiationHandler(GrpcHttp2ConnectionHandler next) {
|
||||
return new GrpcNegotiationHandler(next);
|
||||
}
|
||||
|
||||
public static ChannelHandler clientTlsHandler(
|
||||
ChannelHandler next, SslContext sslContext, String authority) {
|
||||
return new ClientTlsHandler(next, sslContext, authority);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue