ALTS: release handshaker channel if no longer needed (#5210)

* ALTS: release handshaker channel if no longer needed
This commit is contained in:
Jiangtao Li 2019-01-11 14:57:08 -08:00 committed by GitHub
parent 7a547276da
commit 4d90b37a0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 84 additions and 34 deletions

View File

@ -30,6 +30,7 @@ import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.alts.internal.AltsClientOptions; import io.grpc.alts.internal.AltsClientOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator; import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.AltsTsiHandshaker; import io.grpc.alts.internal.AltsTsiHandshaker;
import io.grpc.alts.internal.HandshakerServiceGrpc; import io.grpc.alts.internal.HandshakerServiceGrpc;
import io.grpc.alts.internal.RpcProtocolVersionsUtil; import io.grpc.alts.internal.RpcProtocolVersionsUtil;
@ -104,8 +105,9 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
public AltsChannelBuilder setHandshakerAddressForTesting(String handshakerAddress) { public AltsChannelBuilder setHandshakerAddressForTesting(String handshakerAddress) {
// Instead of using the default shared channel to the handshaker service, create a separate // Instead of using the default shared channel to the handshaker service, create a separate
// resource to the test address. // resource to the test address.
handshakerChannelPool = SharedResourcePool.forResource( handshakerChannelPool =
HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress)); SharedResourcePool.forResource(
HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress));
return this; return this;
} }
@ -144,13 +146,11 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
@Override @Override
public AltsProtocolNegotiator buildProtocolNegotiator() { public AltsProtocolNegotiator buildProtocolNegotiator() {
final ImmutableList<String> targetServiceAccounts = targetServiceAccountsBuilder.build(); final ImmutableList<String> targetServiceAccounts = targetServiceAccountsBuilder.build();
final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
TsiHandshakerFactory altsHandshakerFactory = TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() { new TsiHandshakerFactory() {
@Override @Override
public TsiHandshaker newHandshaker(String authority) { public TsiHandshaker newHandshaker(String authority) {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
// TODO: Release the channel if it is not used.
// https://github.com/grpc/grpc-java/issues/4755.
AltsClientOptions handshakerOptions = AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder() new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()) .setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
@ -158,12 +158,12 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
.setTargetName(authority) .setTargetName(authority)
.build(); .build();
return AltsTsiHandshaker.newClient( return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(handshakerChannelPool.getObject()), HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
handshakerOptions);
} }
}; };
return negotiatorForTest = return negotiatorForTest =
AltsProtocolNegotiator.createClientNegotiator(altsHandshakerFactory); AltsProtocolNegotiator.createClientNegotiator(
altsHandshakerFactory, lazyHandshakerChannel);
} }
} }

View File

@ -35,6 +35,7 @@ import io.grpc.ServerTransportFilter;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.alts.internal.AltsHandshakerOptions; import io.grpc.alts.internal.AltsHandshakerOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator; import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.AltsTsiHandshaker; import io.grpc.alts.internal.AltsTsiHandshaker;
import io.grpc.alts.internal.HandshakerServiceGrpc; import io.grpc.alts.internal.HandshakerServiceGrpc;
import io.grpc.alts.internal.RpcProtocolVersionsUtil; import io.grpc.alts.internal.RpcProtocolVersionsUtil;
@ -92,8 +93,9 @@ public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
public AltsServerBuilder setHandshakerAddressForTesting(String handshakerAddress) { public AltsServerBuilder setHandshakerAddressForTesting(String handshakerAddress) {
// Instead of using the default shared channel to the handshaker service, create a separate // Instead of using the default shared channel to the handshaker service, create a separate
// resource to the test address. // resource to the test address.
handshakerChannelPool = SharedResourcePool.forResource( handshakerChannelPool =
HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress)); SharedResourcePool.forResource(
HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress));
return this; return this;
} }
@ -196,19 +198,18 @@ public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
} }
} }
final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
delegate.protocolNegotiator( delegate.protocolNegotiator(
AltsProtocolNegotiator.createServerNegotiator( AltsProtocolNegotiator.createServerNegotiator(
new TsiHandshakerFactory() { new TsiHandshakerFactory() {
@Override @Override
public TsiHandshaker newHandshaker(String authority) { public TsiHandshaker newHandshaker(String authority) {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
// TODO: Release the channel if it is not used.
// https://github.com/grpc/grpc-java/issues/4755.
return AltsTsiHandshaker.newServer( return AltsTsiHandshaker.newServer(
HandshakerServiceGrpc.newStub(handshakerChannelPool.getObject()), HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()),
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions())); new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()));
} }
})); },
lazyHandshakerChannel));
return delegate.build(); return delegate.build();
} }

View File

@ -28,6 +28,7 @@ import io.grpc.ManagedChannelBuilder;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.alts.internal.AltsClientOptions; import io.grpc.alts.internal.AltsClientOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.AltsTsiHandshaker; import io.grpc.alts.internal.AltsTsiHandshaker;
import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator; import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator;
import io.grpc.alts.internal.HandshakerServiceGrpc; import io.grpc.alts.internal.HandshakerServiceGrpc;
@ -36,7 +37,7 @@ import io.grpc.alts.internal.TsiHandshaker;
import io.grpc.alts.internal.TsiHandshakerFactory; import io.grpc.alts.internal.TsiHandshakerFactory;
import io.grpc.auth.MoreCallCredentials; import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.InternalNettyChannelBuilder; import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder;
@ -94,24 +95,23 @@ public final class GoogleDefaultChannelBuilder
private final class ProtocolNegotiatorFactory private final class ProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory { implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
@Override @Override
public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() { public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() {
final LazyChannel lazyHandshakerChannel =
new LazyChannel(
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL));
TsiHandshakerFactory altsHandshakerFactory = TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() { new TsiHandshakerFactory() {
@Override @Override
public TsiHandshaker newHandshaker(String authority) { public TsiHandshaker newHandshaker(String authority) {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
// TODO: Release the channel if it is not used.
// https://github.com/grpc/grpc-java/issues/4755.
Channel channel =
SharedResourceHolder.get(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL);
AltsClientOptions handshakerOptions = AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder() new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()) .setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.setTargetName(authority) .setTargetName(authority)
.build(); .build();
return AltsTsiHandshaker.newClient( return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(channel), handshakerOptions); HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
} }
}; };
SslContext sslContext; SslContext sslContext;
@ -121,7 +121,8 @@ public final class GoogleDefaultChannelBuilder
throw new RuntimeException(ex); throw new RuntimeException(ex);
} }
return negotiatorForTest = return negotiatorForTest =
new GoogleDefaultProtocolNegotiator(altsHandshakerFactory, sslContext); new GoogleDefaultProtocolNegotiator(
altsHandshakerFactory, lazyHandshakerChannel, sslContext);
} }
} }

View File

@ -20,6 +20,7 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.protobuf.Any; import com.google.protobuf.Any;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.Channel;
import io.grpc.Grpc; import io.grpc.Grpc;
import io.grpc.InternalChannelz.OtherSecurity; import io.grpc.InternalChannelz.OtherSecurity;
import io.grpc.InternalChannelz.Security; import io.grpc.InternalChannelz.Security;
@ -28,6 +29,7 @@ import io.grpc.Status;
import io.grpc.alts.internal.RpcProtocolVersionsUtil.RpcVersionsCheckResult; import io.grpc.alts.internal.RpcProtocolVersionsUtil.RpcVersionsCheckResult;
import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent; import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.ProtocolNegotiator; import io.grpc.netty.ProtocolNegotiator;
import io.grpc.netty.ProtocolNegotiators.AbstractBufferingHandler; import io.grpc.netty.ProtocolNegotiators.AbstractBufferingHandler;
@ -47,15 +49,18 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
@Grpc.TransportAttr @Grpc.TransportAttr
public static final Attributes.Key<TsiPeer> TSI_PEER_KEY = Attributes.Key.create("TSI_PEER"); public static final Attributes.Key<TsiPeer> TSI_PEER_KEY = Attributes.Key.create("TSI_PEER");
@Grpc.TransportAttr @Grpc.TransportAttr
public static final Attributes.Key<AltsAuthContext> ALTS_CONTEXT_KEY = public static final Attributes.Key<AltsAuthContext> ALTS_CONTEXT_KEY =
Attributes.Key.create("ALTS_CONTEXT_KEY"); Attributes.Key.create("ALTS_CONTEXT_KEY");
private static final AsciiString scheme = AsciiString.of("https"); private static final AsciiString scheme = AsciiString.of("https");
/** Creates a negotiator used for ALTS client. */ /** Creates a negotiator used for ALTS client. */
public static AltsProtocolNegotiator createClientNegotiator( public static AltsProtocolNegotiator createClientNegotiator(
final TsiHandshakerFactory handshakerFactory) { final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) {
final class ClientAltsProtocolNegotiator extends AltsProtocolNegotiator { final class ClientAltsProtocolNegotiator extends AltsProtocolNegotiator {
@Override @Override
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority()); TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority());
@ -68,17 +73,18 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
@Override @Override
public void close() { public void close() {
logger.finest("ALTS Client ProtocolNegotiator Closed"); logger.finest("ALTS Client ProtocolNegotiator Closed");
// TODO(jiangtaoli2016): release resources lazyHandshakerChannel.close();
} }
} }
return new ClientAltsProtocolNegotiator(); return new ClientAltsProtocolNegotiator();
} }
/** Creates a negotiator used for ALTS server. */ /** Creates a negotiator used for ALTS server. */
public static AltsProtocolNegotiator createServerNegotiator( public static AltsProtocolNegotiator createServerNegotiator(
final TsiHandshakerFactory handshakerFactory) { final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) {
final class ServerAltsProtocolNegotiator extends AltsProtocolNegotiator { final class ServerAltsProtocolNegotiator extends AltsProtocolNegotiator {
@Override @Override
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
TsiHandshaker handshaker = handshakerFactory.newHandshaker(/*authority=*/ null); TsiHandshaker handshaker = handshakerFactory.newHandshaker(/*authority=*/ null);
@ -91,13 +97,41 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
@Override @Override
public void close() { public void close() {
logger.finest("ALTS Server ProtocolNegotiator Closed"); logger.finest("ALTS Server ProtocolNegotiator Closed");
// TODO(jiangtaoli2016): release resources lazyHandshakerChannel.close();
} }
} }
return new ServerAltsProtocolNegotiator(); return new ServerAltsProtocolNegotiator();
} }
/** Channel created from a channel pool lazily. */
public static class LazyChannel {
private final ObjectPool<Channel> channelPool;
private Channel channel;
public LazyChannel(ObjectPool<Channel> channelPool) {
this.channelPool = channelPool;
}
/**
* If channel is null, gets a channel from the channel pool, otherwise, returns the cached
* channel.
*/
public synchronized Channel get() {
if (channel == null) {
channel = channelPool.getObject();
}
return channel;
}
/** Returns the cached channel to the channel pool. */
public synchronized void close() {
if (channel != null) {
channelPool.returnObject(channel);
}
}
}
/** Buffers all writes until the ALTS handshake is complete. */ /** Buffers all writes until the ALTS handshake is complete. */
@VisibleForTesting @VisibleForTesting
static final class BufferUntilAltsNegotiatedHandler extends AbstractBufferingHandler static final class BufferUntilAltsNegotiatedHandler extends AbstractBufferingHandler
@ -129,7 +163,7 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
@Override @Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (logger.isLoggable(Level.FINEST)) { if (logger.isLoggable(Level.FINEST)) {
logger.log(Level.FINEST, "User Event triggered while negotiating ALTS", new Object[]{evt}); logger.log(Level.FINEST, "User Event triggered while negotiating ALTS", new Object[] {evt});
} }
if (evt instanceof TsiHandshakeCompletionEvent) { if (evt instanceof TsiHandshakeCompletionEvent) {
TsiHandshakeCompletionEvent altsEvt = (TsiHandshakeCompletionEvent) evt; TsiHandshakeCompletionEvent altsEvt = (TsiHandshakeCompletionEvent) evt;

View File

@ -17,6 +17,7 @@
package io.grpc.alts.internal; package io.grpc.alts.internal;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcAttributes;
import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.ProtocolNegotiator; import io.grpc.netty.ProtocolNegotiator;
@ -25,11 +26,15 @@ import io.netty.handler.ssl.SslContext;
/** A client-side GPRC {@link ProtocolNegotiator} for Google Default Channel. */ /** A client-side GPRC {@link ProtocolNegotiator} for Google Default Channel. */
public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator { public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator {
private final ProtocolNegotiator altsProtocolNegotiator; private final ProtocolNegotiator altsProtocolNegotiator;
private final ProtocolNegotiator tlsProtocolNegotiator; private final ProtocolNegotiator tlsProtocolNegotiator;
public GoogleDefaultProtocolNegotiator(TsiHandshakerFactory altsFactory, SslContext sslContext) { /** Constructor for protocol negotiator of Google Default Channel. */
altsProtocolNegotiator = AltsProtocolNegotiator.createClientNegotiator(altsFactory); public GoogleDefaultProtocolNegotiator(
TsiHandshakerFactory altsFactory, LazyChannel lazyHandshakerChannel, SslContext sslContext) {
altsProtocolNegotiator =
AltsProtocolNegotiator.createClientNegotiator(altsFactory, lazyHandshakerChannel);
tlsProtocolNegotiator = ProtocolNegotiators.tls(sslContext); tlsProtocolNegotiator = ProtocolNegotiators.tls(sslContext);
} }

View File

@ -24,13 +24,19 @@ import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.Channel;
import io.grpc.Grpc; import io.grpc.Grpc;
import io.grpc.InternalChannelz; import io.grpc.InternalChannelz;
import io.grpc.ManagedChannel;
import io.grpc.SecurityLevel; import io.grpc.SecurityLevel;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.TsiFrameProtector.Consumer; import io.grpc.alts.internal.TsiFrameProtector.Consumer;
import io.grpc.alts.internal.TsiPeer.Property; import io.grpc.alts.internal.TsiPeer.Property;
import io.grpc.internal.FixedObjectPool;
import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.CompositeByteBuf;
@ -147,8 +153,12 @@ public class AltsProtocolNegotiatorTest {
}; };
} }
}; };
ManagedChannel fakeChannel = NettyChannelBuilder.forTarget("localhost:8080").build();
ObjectPool<Channel> fakeChannelPool = new FixedObjectPool<Channel>(fakeChannel);
LazyChannel lazyFakeChannel = new LazyChannel(fakeChannelPool);
handler = handler =
AltsProtocolNegotiator.createServerNegotiator(handshakerFactory).newHandler(grpcHandler); AltsProtocolNegotiator.createServerNegotiator(handshakerFactory, lazyFakeChannel)
.newHandler(grpcHandler);
channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler); channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler);
} }
@ -340,8 +350,7 @@ public class AltsProtocolNegotiatorTest {
public void peerPropagated() throws Exception { public void peerPropagated() throws Exception {
doHandshake(); doHandshake();
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.TSI_PEER_KEY)) assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.TSI_PEER_KEY)).isEqualTo(mockedTsiPeer);
.isEqualTo(mockedTsiPeer);
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.ALTS_CONTEXT_KEY)) assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.ALTS_CONTEXT_KEY))
.isEqualTo(mockedAltsContext); .isEqualTo(mockedAltsContext);
assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString()) assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString())