alts, xds: backend handshake protocol selection support for xDS in directpath (#7783)

Attaches an attribute on endpoint addresses resolved/discovered using xDS plugin. The attribute indicates whether the endpoint address is a direct Google service endpoint or a CFE. This lets the GoogleDefault credentials choose between ALTS (direct Google service endpoint) and TLS (CFE).

Due to dependency relation between grpc-xds and grpc-alts, GoogleDefault credentials will use the attribute key defined in grpc-xds reflectively.
This commit is contained in:
Chengyuan Zhang 2021-01-15 16:51:57 -08:00 committed by GitHub
parent 23d279660c
commit 4130c5a1b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 123 additions and 28 deletions

View File

@ -69,6 +69,7 @@ public final class ComputeEngineChannelCredentials {
return new GoogleDefaultProtocolNegotiatorFactory( return new GoogleDefaultProtocolNegotiatorFactory(
/* targetServiceAccounts= */ ImmutableList.<String>of(), /* targetServiceAccounts= */ ImmutableList.<String>of(),
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL), SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL),
sslContext); sslContext,
null);
} }
} }

View File

@ -18,6 +18,7 @@ package io.grpc.alts;
import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import io.grpc.Attributes;
import io.grpc.CallCredentials; import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials; import io.grpc.ChannelCredentials;
import io.grpc.CompositeChannelCredentials; import io.grpc.CompositeChannelCredentials;
@ -31,6 +32,8 @@ import io.grpc.netty.InternalNettyChannelCredentials;
import io.grpc.netty.InternalProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiator;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import java.io.IOException; import java.io.IOException;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
/** /**
@ -39,6 +42,8 @@ import javax.net.ssl.SSLException;
*/ */
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/7479") @ExperimentalApi("https://github.com/grpc/grpc-java/issues/7479")
public final class GoogleDefaultChannelCredentials { public final class GoogleDefaultChannelCredentials {
private static Logger logger = Logger.getLogger(GoogleDefaultChannelCredentials.class.getName());
private GoogleDefaultChannelCredentials() {} private GoogleDefaultChannelCredentials() {}
/** /**
@ -61,6 +66,7 @@ public final class GoogleDefaultChannelCredentials {
return CompositeChannelCredentials.create(nettyCredentials, callCredentials); return CompositeChannelCredentials.create(nettyCredentials, callCredentials);
} }
@SuppressWarnings("unchecked")
private static InternalProtocolNegotiator.ClientFactory createClientFactory() { private static InternalProtocolNegotiator.ClientFactory createClientFactory() {
SslContext sslContext; SslContext sslContext;
try { try {
@ -68,9 +74,25 @@ public final class GoogleDefaultChannelCredentials {
} catch (SSLException e) { } catch (SSLException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
Attributes.Key<String> clusterNameAttrKey = null;
try {
Class<?> klass = Class.forName("io.grpc.xds.InternalXdsAttributes");
clusterNameAttrKey =
(Attributes.Key<String>) klass.getField("ATTR_CLUSTER_NAME").get(null);
} catch (ClassNotFoundException e) {
logger.log(Level.FINE,
"Unable to load xDS endpoint cluster name key, this may be expected", e);
} catch (NoSuchFieldException e) {
logger.log(Level.FINE,
"Unable to load xDS endpoint cluster name key, this may be expected", e);
} catch (IllegalAccessException e) {
logger.log(Level.FINE,
"Unable to load xDS endpoint cluster name key, this may be expected", e);
}
return new GoogleDefaultProtocolNegotiatorFactory( return new GoogleDefaultProtocolNegotiatorFactory(
/* targetServiceAccounts= */ ImmutableList.<String>of(), /* targetServiceAccounts= */ ImmutableList.<String>of(),
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL), SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL),
sslContext); sslContext,
clusterNameAttrKey);
} }
} }

View File

@ -186,6 +186,8 @@ public final class AltsProtocolNegotiator {
private final ImmutableList<String> targetServiceAccounts; private final ImmutableList<String> targetServiceAccounts;
private final ObjectPool<Channel> handshakerChannelPool; private final ObjectPool<Channel> handshakerChannelPool;
private final SslContext sslContext; private final SslContext sslContext;
@Nullable
private final Attributes.Key<String> clusterNameAttrKey;
/** /**
* Creates Negotiator Factory, which will either use the targetServiceAccounts and * Creates Negotiator Factory, which will either use the targetServiceAccounts and
@ -194,10 +196,12 @@ public final class AltsProtocolNegotiator {
public GoogleDefaultProtocolNegotiatorFactory( public GoogleDefaultProtocolNegotiatorFactory(
List<String> targetServiceAccounts, List<String> targetServiceAccounts,
ObjectPool<Channel> handshakerChannelPool, ObjectPool<Channel> handshakerChannelPool,
SslContext sslContext) { SslContext sslContext,
@Nullable Attributes.Key<String> clusterNameAttrKey) {
this.targetServiceAccounts = ImmutableList.copyOf(targetServiceAccounts); this.targetServiceAccounts = ImmutableList.copyOf(targetServiceAccounts);
this.handshakerChannelPool = checkNotNull(handshakerChannelPool, "handshakerChannelPool"); this.handshakerChannelPool = checkNotNull(handshakerChannelPool, "handshakerChannelPool");
this.sslContext = checkNotNull(sslContext, "sslContext"); this.sslContext = checkNotNull(sslContext, "sslContext");
this.clusterNameAttrKey = clusterNameAttrKey;
} }
@Override @Override
@ -205,7 +209,8 @@ public final class AltsProtocolNegotiator {
return new GoogleDefaultProtocolNegotiator( return new GoogleDefaultProtocolNegotiator(
targetServiceAccounts, targetServiceAccounts,
handshakerChannelPool, handshakerChannelPool,
sslContext); sslContext,
clusterNameAttrKey);
} }
@Override @Override
@ -218,15 +223,19 @@ public final class AltsProtocolNegotiator {
private final TsiHandshakerFactory handshakerFactory; private final TsiHandshakerFactory handshakerFactory;
private final LazyChannel lazyHandshakerChannel; private final LazyChannel lazyHandshakerChannel;
private final SslContext sslContext; private final SslContext sslContext;
@Nullable
private final Attributes.Key<String> clusterNameAttrKey;
GoogleDefaultProtocolNegotiator( GoogleDefaultProtocolNegotiator(
ImmutableList<String> targetServiceAccounts, ImmutableList<String> targetServiceAccounts,
ObjectPool<Channel> handshakerChannelPool, ObjectPool<Channel> handshakerChannelPool,
SslContext sslContext) { SslContext sslContext,
@Nullable Attributes.Key<String> clusterNameAttrKey) {
this.lazyHandshakerChannel = new LazyChannel(handshakerChannelPool); this.lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
this.handshakerFactory = this.handshakerFactory =
new ClientTsiHandshakerFactory(targetServiceAccounts, lazyHandshakerChannel); new ClientTsiHandshakerFactory(targetServiceAccounts, lazyHandshakerChannel);
this.sslContext = checkNotNull(sslContext, "checkNotNull"); this.sslContext = checkNotNull(sslContext, "checkNotNull");
this.clusterNameAttrKey = clusterNameAttrKey;
} }
@Override @Override
@ -238,9 +247,11 @@ public final class AltsProtocolNegotiator {
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler); ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler);
ChannelHandler securityHandler; ChannelHandler securityHandler;
boolean isXdsDirectPath = clusterNameAttrKey != null
&& !"google_cfe".equals(grpcHandler.getEagAttributes().get(clusterNameAttrKey));
if (grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_ADDR_AUTHORITY) != null if (grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_ADDR_AUTHORITY) != null
|| grpcHandler.getEagAttributes().get( || grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND) != null
GrpclbConstants.ATTR_LB_PROVIDED_BACKEND) != null) { || isXdsDirectPath) {
TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority()); TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority());
NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker); NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker);
securityHandler = securityHandler =

View File

@ -36,17 +36,25 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
@RunWith(JUnit4.class) @RunWith(Parameterized.class)
public final class GoogleDefaultProtocolNegotiatorTest { public final class GoogleDefaultProtocolNegotiatorTest {
@Parameterized.Parameter
public boolean withXds;
private ProtocolNegotiator googleProtocolNegotiator; private ProtocolNegotiator googleProtocolNegotiator;
// Same as io.grpc.xds.InternalXdsAttributes.ATTR_CLUSTER_NAME
private final Attributes.Key<String> clusterNameAttrKey =
Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.clusterName");
private final ObjectPool<Channel> handshakerChannelPool = new ObjectPool<Channel>() { private final ObjectPool<Channel> handshakerChannelPool = new ObjectPool<Channel>() {
@Override @Override
@ -61,6 +69,11 @@ public final class GoogleDefaultProtocolNegotiatorTest {
} }
}; };
@Parameters(name = "Run with xDS : {0}")
public static Iterable<Boolean> data() {
return Arrays.asList(true, false);
}
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
SslContext sslContext = GrpcSslContexts.forClient().build(); SslContext sslContext = GrpcSslContexts.forClient().build();
@ -68,7 +81,8 @@ public final class GoogleDefaultProtocolNegotiatorTest {
googleProtocolNegotiator = new AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory( googleProtocolNegotiator = new AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory(
ImmutableList.<String>of(), ImmutableList.<String>of(),
handshakerChannelPool, handshakerChannelPool,
sslContext) sslContext,
withXds ? clusterNameAttrKey : null)
.newNegotiator(); .newNegotiator();
} }
@ -79,8 +93,14 @@ public final class GoogleDefaultProtocolNegotiatorTest {
@Test @Test
public void altsHandler() { public void altsHandler() {
Attributes eagAttributes = Attributes eagAttributes;
Attributes.newBuilder().set(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND, true).build(); if (withXds) {
eagAttributes =
Attributes.newBuilder().set(clusterNameAttrKey, "api.googleapis.com").build();
} else {
eagAttributes =
Attributes.newBuilder().set(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND, true).build();
}
GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class);
when(mockHandler.getEagAttributes()).thenReturn(eagAttributes); when(mockHandler.getEagAttributes()).thenReturn(eagAttributes);
@ -106,7 +126,12 @@ public final class GoogleDefaultProtocolNegotiatorTest {
@Test @Test
public void tlsHandler() { public void tlsHandler() {
Attributes eagAttributes = Attributes.EMPTY; Attributes eagAttributes;
if (withXds) {
eagAttributes = Attributes.newBuilder().set(clusterNameAttrKey, "google_cfe").build();
} else {
eagAttributes = Attributes.EMPTY;
}
GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class);
when(mockHandler.getEagAttributes()).thenReturn(eagAttributes); when(mockHandler.getEagAttributes()).thenReturn(eagAttributes);
when(mockHandler.getAuthority()).thenReturn("authority"); when(mockHandler.getAuthority()).thenReturn("authority");

View File

@ -195,18 +195,18 @@ final class ClusterImplLoadBalancer extends LoadBalancer {
@Override @Override
public Subchannel createSubchannel(CreateSubchannelArgs args) { public Subchannel createSubchannel(CreateSubchannelArgs args) {
if (enableSecurity && sslContextProviderSupplier != null) { List<EquivalentAddressGroup> addresses = new ArrayList<>();
List<EquivalentAddressGroup> addresses = new ArrayList<>(); for (EquivalentAddressGroup eag : args.getAddresses()) {
for (EquivalentAddressGroup eag : args.getAddresses()) { Attributes.Builder attrBuilder = eag.getAttributes().toBuilder().set(
Attributes attributes = InternalXdsAttributes.ATTR_CLUSTER_NAME, cluster);
eag.getAttributes().toBuilder() if (enableSecurity && sslContextProviderSupplier != null) {
.set(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, attrBuilder.set(
sslContextProviderSupplier) InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER,
.build(); sslContextProviderSupplier);
addresses.add(new EquivalentAddressGroup(eag.getAddresses(), attributes));
} }
args = args.toBuilder().setAddresses(addresses).build(); addresses.add(new EquivalentAddressGroup(eag.getAddresses(), attrBuilder.build()));
} }
args = args.toBuilder().setAddresses(addresses).build();
return delegate().createSubchannel(args); return delegate().createSubchannel(args);
} }

View File

@ -17,6 +17,7 @@
package io.grpc.xds; package io.grpc.xds;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.EquivalentAddressGroup;
import io.grpc.Grpc; import io.grpc.Grpc;
import io.grpc.Internal; import io.grpc.Internal;
import io.grpc.NameResolver; import io.grpc.NameResolver;
@ -53,6 +54,14 @@ public final class InternalXdsAttributes {
static final Attributes.Key<CallCounterProvider> CALL_COUNTER_PROVIDER = static final Attributes.Key<CallCounterProvider> CALL_COUNTER_PROVIDER =
Attributes.Key.create("io.grpc.xds.XdsAttributes.callCounterProvider"); Attributes.Key.create("io.grpc.xds.XdsAttributes.callCounterProvider");
/**
* Name of the cluster that provides this EquivalentAddressGroup.
*/
@Internal
@EquivalentAddressGroup.Attr
public static final Attributes.Key<String> ATTR_CLUSTER_NAME =
Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.clusterName");
// TODO (chengyuanzhang): temporary solution for migrating to LRS policy. Should access // TODO (chengyuanzhang): temporary solution for migrating to LRS policy. Should access
// stats object via XdsClient interface. // stats object via XdsClient interface.
static final Attributes.Key<LoadStatsStore> ATTR_CLUSTER_SERVICE_LOAD_STATS_STORE = static final Attributes.Key<LoadStatsStore> ATTR_CLUSTER_SERVICE_LOAD_STATS_STORE =

View File

@ -372,19 +372,46 @@ public class ClusterImplLoadBalancerTest {
} }
@Test @Test
public void endpointConnectionWithTls_enableSecurity() { public void endpointAddressesAttachedWithClusterName() {
LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider();
WeightedTargetConfig weightedTargetConfig =
buildWeightedTargetConfig(ImmutableMap.of(locality, 10));
ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_NAME,
null, Collections.<DropOverload>emptyList(),
new PolicySelection(weightedTargetProvider, weightedTargetConfig), null);
// One locality with two endpoints.
EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality);
EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr2", locality);
deliverAddressesAndConfig(Arrays.asList(endpoint1, endpoint2), config);
assertThat(downstreamBalancers).hasSize(1); // one leaf balancer
FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers);
assertThat(leafBalancer.name).isEqualTo("round_robin");
// Simulates leaf load balancer creating subchannels.
CreateSubchannelArgs args =
CreateSubchannelArgs.newBuilder()
.setAddresses(leafBalancer.addresses)
.build();
Subchannel subchannel = leafBalancer.helper.createSubchannel(args);
for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) {
assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_CLUSTER_NAME))
.isEqualTo(CLUSTER);
}
}
@Test
public void endpointAddressesAttachedWithTlsConfig_enableSecurity() {
boolean originalEnableSecurity = ClusterImplLoadBalancer.enableSecurity; boolean originalEnableSecurity = ClusterImplLoadBalancer.enableSecurity;
ClusterImplLoadBalancer.enableSecurity = true; ClusterImplLoadBalancer.enableSecurity = true;
subtest_endpointConnectionWithTls(true); subtest_endpointAddressesAttachedWithTlsConfig(true);
ClusterImplLoadBalancer.enableSecurity = originalEnableSecurity; ClusterImplLoadBalancer.enableSecurity = originalEnableSecurity;
} }
@Test @Test
public void endpointConnectionWithTls_securityDisabledByDefault() { public void endpointAddressesAttachedWithTlsConfig_securityDisabledByDefault() {
subtest_endpointConnectionWithTls(false); subtest_endpointAddressesAttachedWithTlsConfig(false);
} }
private void subtest_endpointConnectionWithTls(boolean enableSecurity) { private void subtest_endpointAddressesAttachedWithTlsConfig(boolean enableSecurity) {
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.CLIENT_KEY_FILE, CommonTlsContextTestsUtil.CLIENT_KEY_FILE,