From 4130c5a1b8016656081258f1ceff393b59b9c7b2 Mon Sep 17 00:00:00 2001 From: Chengyuan Zhang Date: Fri, 15 Jan 2021 16:51:57 -0800 Subject: [PATCH] alts, xds: backend handshake protocol selection support for xDS in directpath (#7783) Attaches an attribute on endpoint addresses resolved/discovered using xDS plugin. The attribute indicates whether the endpoint address is a direct Google service endpoint or a CFE. This lets the GoogleDefault credentials choose between ALTS (direct Google service endpoint) and TLS (CFE). Due to dependency relation between grpc-xds and grpc-alts, GoogleDefault credentials will use the attribute key defined in grpc-xds reflectively. --- .../alts/ComputeEngineChannelCredentials.java | 3 +- .../alts/GoogleDefaultChannelCredentials.java | 24 +++++++++++- .../alts/internal/AltsProtocolNegotiator.java | 21 ++++++++--- .../GoogleDefaultProtocolNegotiatorTest.java | 37 ++++++++++++++++--- .../io/grpc/xds/ClusterImplLoadBalancer.java | 20 +++++----- .../io/grpc/xds/InternalXdsAttributes.java | 9 +++++ .../grpc/xds/ClusterImplLoadBalancerTest.java | 37 ++++++++++++++++--- 7 files changed, 123 insertions(+), 28 deletions(-) diff --git a/alts/src/main/java/io/grpc/alts/ComputeEngineChannelCredentials.java b/alts/src/main/java/io/grpc/alts/ComputeEngineChannelCredentials.java index 387fe6da69..9ec4e05889 100644 --- a/alts/src/main/java/io/grpc/alts/ComputeEngineChannelCredentials.java +++ b/alts/src/main/java/io/grpc/alts/ComputeEngineChannelCredentials.java @@ -69,6 +69,7 @@ public final class ComputeEngineChannelCredentials { return new GoogleDefaultProtocolNegotiatorFactory( /* targetServiceAccounts= */ ImmutableList.of(), SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL), - sslContext); + sslContext, + null); } } diff --git a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java index 9c8b39cfc1..41c6b99827 100644 --- a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java +++ b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java @@ -18,6 +18,7 @@ package io.grpc.alts; import com.google.auth.oauth2.GoogleCredentials; import com.google.common.collect.ImmutableList; +import io.grpc.Attributes; import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; import io.grpc.CompositeChannelCredentials; @@ -31,6 +32,8 @@ import io.grpc.netty.InternalNettyChannelCredentials; import io.grpc.netty.InternalProtocolNegotiator; import io.netty.handler.ssl.SslContext; import java.io.IOException; +import java.util.logging.Level; +import java.util.logging.Logger; import javax.net.ssl.SSLException; /** @@ -39,6 +42,8 @@ import javax.net.ssl.SSLException; */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/7479") public final class GoogleDefaultChannelCredentials { + private static Logger logger = Logger.getLogger(GoogleDefaultChannelCredentials.class.getName()); + private GoogleDefaultChannelCredentials() {} /** @@ -61,6 +66,7 @@ public final class GoogleDefaultChannelCredentials { return CompositeChannelCredentials.create(nettyCredentials, callCredentials); } + @SuppressWarnings("unchecked") private static InternalProtocolNegotiator.ClientFactory createClientFactory() { SslContext sslContext; try { @@ -68,9 +74,25 @@ public final class GoogleDefaultChannelCredentials { } catch (SSLException e) { throw new RuntimeException(e); } + Attributes.Key clusterNameAttrKey = null; + try { + Class klass = Class.forName("io.grpc.xds.InternalXdsAttributes"); + clusterNameAttrKey = + (Attributes.Key) klass.getField("ATTR_CLUSTER_NAME").get(null); + } catch (ClassNotFoundException e) { + logger.log(Level.FINE, + "Unable to load xDS endpoint cluster name key, this may be expected", e); + } catch (NoSuchFieldException e) { + logger.log(Level.FINE, + "Unable to load xDS endpoint cluster name key, this may be expected", e); + } catch (IllegalAccessException e) { + logger.log(Level.FINE, + "Unable to load xDS endpoint cluster name key, this may be expected", e); + } return new GoogleDefaultProtocolNegotiatorFactory( /* targetServiceAccounts= */ ImmutableList.of(), SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL), - sslContext); + sslContext, + clusterNameAttrKey); } } diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index d2040dff02..0b4a90748d 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -186,6 +186,8 @@ public final class AltsProtocolNegotiator { private final ImmutableList targetServiceAccounts; private final ObjectPool handshakerChannelPool; private final SslContext sslContext; + @Nullable + private final Attributes.Key clusterNameAttrKey; /** * Creates Negotiator Factory, which will either use the targetServiceAccounts and @@ -194,10 +196,12 @@ public final class AltsProtocolNegotiator { public GoogleDefaultProtocolNegotiatorFactory( List targetServiceAccounts, ObjectPool handshakerChannelPool, - SslContext sslContext) { + SslContext sslContext, + @Nullable Attributes.Key clusterNameAttrKey) { this.targetServiceAccounts = ImmutableList.copyOf(targetServiceAccounts); this.handshakerChannelPool = checkNotNull(handshakerChannelPool, "handshakerChannelPool"); this.sslContext = checkNotNull(sslContext, "sslContext"); + this.clusterNameAttrKey = clusterNameAttrKey; } @Override @@ -205,7 +209,8 @@ public final class AltsProtocolNegotiator { return new GoogleDefaultProtocolNegotiator( targetServiceAccounts, handshakerChannelPool, - sslContext); + sslContext, + clusterNameAttrKey); } @Override @@ -218,15 +223,19 @@ public final class AltsProtocolNegotiator { private final TsiHandshakerFactory handshakerFactory; private final LazyChannel lazyHandshakerChannel; private final SslContext sslContext; + @Nullable + private final Attributes.Key clusterNameAttrKey; GoogleDefaultProtocolNegotiator( ImmutableList targetServiceAccounts, ObjectPool handshakerChannelPool, - SslContext sslContext) { + SslContext sslContext, + @Nullable Attributes.Key clusterNameAttrKey) { this.lazyHandshakerChannel = new LazyChannel(handshakerChannelPool); this.handshakerFactory = new ClientTsiHandshakerFactory(targetServiceAccounts, lazyHandshakerChannel); this.sslContext = checkNotNull(sslContext, "checkNotNull"); + this.clusterNameAttrKey = clusterNameAttrKey; } @Override @@ -238,9 +247,11 @@ public final class AltsProtocolNegotiator { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler); ChannelHandler securityHandler; + boolean isXdsDirectPath = clusterNameAttrKey != null + && !"google_cfe".equals(grpcHandler.getEagAttributes().get(clusterNameAttrKey)); if (grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_ADDR_AUTHORITY) != null - || grpcHandler.getEagAttributes().get( - GrpclbConstants.ATTR_LB_PROVIDED_BACKEND) != null) { + || grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND) != null + || isXdsDirectPath) { TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority()); NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker); securityHandler = diff --git a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java index 85603636ec..bc01f83102 100644 --- a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java @@ -36,17 +36,25 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.ssl.SslContext; +import java.util.Arrays; import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public final class GoogleDefaultProtocolNegotiatorTest { + @Parameterized.Parameter + public boolean withXds; + private ProtocolNegotiator googleProtocolNegotiator; + // Same as io.grpc.xds.InternalXdsAttributes.ATTR_CLUSTER_NAME + private final Attributes.Key clusterNameAttrKey = + Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.clusterName"); private final ObjectPool handshakerChannelPool = new ObjectPool() { @Override @@ -61,6 +69,11 @@ public final class GoogleDefaultProtocolNegotiatorTest { } }; + @Parameters(name = "Run with xDS : {0}") + public static Iterable data() { + return Arrays.asList(true, false); + } + @Before public void setUp() throws Exception { SslContext sslContext = GrpcSslContexts.forClient().build(); @@ -68,7 +81,8 @@ public final class GoogleDefaultProtocolNegotiatorTest { googleProtocolNegotiator = new AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory( ImmutableList.of(), handshakerChannelPool, - sslContext) + sslContext, + withXds ? clusterNameAttrKey : null) .newNegotiator(); } @@ -79,8 +93,14 @@ public final class GoogleDefaultProtocolNegotiatorTest { @Test public void altsHandler() { - Attributes eagAttributes = - Attributes.newBuilder().set(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND, true).build(); + Attributes eagAttributes; + if (withXds) { + eagAttributes = + Attributes.newBuilder().set(clusterNameAttrKey, "api.googleapis.com").build(); + } else { + eagAttributes = + Attributes.newBuilder().set(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND, true).build(); + } GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); when(mockHandler.getEagAttributes()).thenReturn(eagAttributes); @@ -106,7 +126,12 @@ public final class GoogleDefaultProtocolNegotiatorTest { @Test public void tlsHandler() { - Attributes eagAttributes = Attributes.EMPTY; + Attributes eagAttributes; + if (withXds) { + eagAttributes = Attributes.newBuilder().set(clusterNameAttrKey, "google_cfe").build(); + } else { + eagAttributes = Attributes.EMPTY; + } GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); when(mockHandler.getEagAttributes()).thenReturn(eagAttributes); when(mockHandler.getAuthority()).thenReturn("authority"); diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index efa6d4c7a3..0ad72c8be8 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -195,18 +195,18 @@ final class ClusterImplLoadBalancer extends LoadBalancer { @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { - if (enableSecurity && sslContextProviderSupplier != null) { - List addresses = new ArrayList<>(); - for (EquivalentAddressGroup eag : args.getAddresses()) { - Attributes attributes = - eag.getAttributes().toBuilder() - .set(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, - sslContextProviderSupplier) - .build(); - addresses.add(new EquivalentAddressGroup(eag.getAddresses(), attributes)); + List addresses = new ArrayList<>(); + for (EquivalentAddressGroup eag : args.getAddresses()) { + Attributes.Builder attrBuilder = eag.getAttributes().toBuilder().set( + InternalXdsAttributes.ATTR_CLUSTER_NAME, cluster); + if (enableSecurity && sslContextProviderSupplier != null) { + attrBuilder.set( + InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + sslContextProviderSupplier); } - args = args.toBuilder().setAddresses(addresses).build(); + addresses.add(new EquivalentAddressGroup(eag.getAddresses(), attrBuilder.build())); } + args = args.toBuilder().setAddresses(addresses).build(); return delegate().createSubchannel(args); } diff --git a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java index 85f9ebc520..229bcff583 100644 --- a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java +++ b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java @@ -17,6 +17,7 @@ package io.grpc.xds; import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; import io.grpc.Internal; import io.grpc.NameResolver; @@ -53,6 +54,14 @@ public final class InternalXdsAttributes { static final Attributes.Key CALL_COUNTER_PROVIDER = Attributes.Key.create("io.grpc.xds.XdsAttributes.callCounterProvider"); + /** + * Name of the cluster that provides this EquivalentAddressGroup. + */ + @Internal + @EquivalentAddressGroup.Attr + public static final Attributes.Key ATTR_CLUSTER_NAME = + Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.clusterName"); + // TODO (chengyuanzhang): temporary solution for migrating to LRS policy. Should access // stats object via XdsClient interface. static final Attributes.Key ATTR_CLUSTER_SERVICE_LOAD_STATS_STORE = diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index de05332862..f770065ec3 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -372,19 +372,46 @@ public class ClusterImplLoadBalancerTest { } @Test - public void endpointConnectionWithTls_enableSecurity() { + public void endpointAddressesAttachedWithClusterName() { + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_NAME, + null, Collections.emptyList(), + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + // One locality with two endpoints. + EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality); + EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr2", locality); + deliverAddressesAndConfig(Arrays.asList(endpoint1, endpoint2), config); + assertThat(downstreamBalancers).hasSize(1); // one leaf balancer + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + assertThat(leafBalancer.name).isEqualTo("round_robin"); + // Simulates leaf load balancer creating subchannels. + CreateSubchannelArgs args = + CreateSubchannelArgs.newBuilder() + .setAddresses(leafBalancer.addresses) + .build(); + Subchannel subchannel = leafBalancer.helper.createSubchannel(args); + for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { + assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_CLUSTER_NAME)) + .isEqualTo(CLUSTER); + } + } + + @Test + public void endpointAddressesAttachedWithTlsConfig_enableSecurity() { boolean originalEnableSecurity = ClusterImplLoadBalancer.enableSecurity; ClusterImplLoadBalancer.enableSecurity = true; - subtest_endpointConnectionWithTls(true); + subtest_endpointAddressesAttachedWithTlsConfig(true); ClusterImplLoadBalancer.enableSecurity = originalEnableSecurity; } @Test - public void endpointConnectionWithTls_securityDisabledByDefault() { - subtest_endpointConnectionWithTls(false); + public void endpointAddressesAttachedWithTlsConfig_securityDisabledByDefault() { + subtest_endpointAddressesAttachedWithTlsConfig(false); } - private void subtest_endpointConnectionWithTls(boolean enableSecurity) { + private void subtest_endpointAddressesAttachedWithTlsConfig(boolean enableSecurity) { UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.CLIENT_KEY_FILE,