From 1958e42370d58dc313ed8810f2476ba577e8a7df Mon Sep 17 00:00:00 2001 From: Ashley Zhang Date: Fri, 21 Mar 2025 15:19:40 -0700 Subject: [PATCH] xds: add support for custom per-target credentials on the transport (#11951) --- .../io/grpc/xds/GrpcXdsTransportFactory.java | 54 +++++++++---- .../InternalSharedXdsClientPoolProvider.java | 10 ++- .../grpc/xds/SharedXdsClientPoolProvider.java | 50 ++++++++---- .../grpc/xds/GrpcXdsClientImplTestBase.java | 3 +- .../grpc/xds/GrpcXdsTransportFactoryTest.java | 7 +- .../io/grpc/xds/LoadReportClientTest.java | 14 ++-- .../xds/SharedXdsClientPoolProviderTest.java | 77 +++++++++++++++++++ .../io/grpc/xds/XdsClientFallbackTest.java | 34 +++++--- 8 files changed, 198 insertions(+), 51 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java index 74c28ba2d2..0da51bf47f 100644 --- a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java +++ b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java @@ -19,6 +19,7 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ClientCall; @@ -34,35 +35,50 @@ import java.util.concurrent.TimeUnit; final class GrpcXdsTransportFactory implements XdsTransportFactory { - static final GrpcXdsTransportFactory DEFAULT_XDS_TRANSPORT_FACTORY = - new GrpcXdsTransportFactory(); + private final CallCredentials callCredentials; + + GrpcXdsTransportFactory(CallCredentials callCredentials) { + this.callCredentials = callCredentials; + } @Override public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { - return new GrpcXdsTransport(serverInfo); + return new GrpcXdsTransport(serverInfo, callCredentials); } @VisibleForTesting public XdsTransport createForTest(ManagedChannel channel) { - return new GrpcXdsTransport(channel); + return new GrpcXdsTransport(channel, callCredentials); } @VisibleForTesting static class GrpcXdsTransport implements XdsTransport { private final ManagedChannel channel; + private final CallCredentials callCredentials; public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) { + this(serverInfo, null); + } + + @VisibleForTesting + public GrpcXdsTransport(ManagedChannel channel) { + this(channel, null); + } + + public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) { String target = serverInfo.target(); ChannelCredentials channelCredentials = (ChannelCredentials) serverInfo.implSpecificConfig(); this.channel = Grpc.newChannelBuilder(target, channelCredentials) .keepAliveTime(5, TimeUnit.MINUTES) .build(); + this.callCredentials = callCredentials; } @VisibleForTesting - public GrpcXdsTransport(ManagedChannel channel) { + public GrpcXdsTransport(ManagedChannel channel, CallCredentials callCredentials) { this.channel = checkNotNull(channel, "channel"); + this.callCredentials = callCredentials; } @Override @@ -72,7 +88,8 @@ final class GrpcXdsTransportFactory implements XdsTransportFactory { MethodDescriptor.Marshaller respMarshaller) { Context prevContext = Context.ROOT.attach(); try { - return new XdsStreamingCall<>(fullMethodName, reqMarshaller, respMarshaller); + return new XdsStreamingCall<>( + fullMethodName, reqMarshaller, respMarshaller, callCredentials); } finally { Context.ROOT.detach(prevContext); } @@ -89,16 +106,21 @@ final class GrpcXdsTransportFactory implements XdsTransportFactory { private final ClientCall call; - public XdsStreamingCall(String methodName, MethodDescriptor.Marshaller reqMarshaller, - MethodDescriptor.Marshaller respMarshaller) { - this.call = channel.newCall( - MethodDescriptor.newBuilder() - .setFullMethodName(methodName) - .setType(MethodDescriptor.MethodType.BIDI_STREAMING) - .setRequestMarshaller(reqMarshaller) - .setResponseMarshaller(respMarshaller) - .build(), - CallOptions.DEFAULT); // TODO(zivy): support waitForReady + public XdsStreamingCall( + String methodName, + MethodDescriptor.Marshaller reqMarshaller, + MethodDescriptor.Marshaller respMarshaller, + CallCredentials callCredentials) { + this.call = + channel.newCall( + MethodDescriptor.newBuilder() + .setFullMethodName(methodName) + .setType(MethodDescriptor.MethodType.BIDI_STREAMING) + .setRequestMarshaller(reqMarshaller) + .setResponseMarshaller(respMarshaller) + .build(), + CallOptions.DEFAULT.withCallCredentials( + callCredentials)); // TODO(zivy): support waitForReady } @Override diff --git a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java index 85b59fabfa..9c98bba93c 100644 --- a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java @@ -16,6 +16,7 @@ package io.grpc.xds; +import io.grpc.CallCredentials; import io.grpc.Internal; import io.grpc.MetricRecorder; import io.grpc.internal.ObjectPool; @@ -42,6 +43,13 @@ public final class InternalSharedXdsClientPoolProvider { public static ObjectPool getOrCreate(String target, MetricRecorder metricRecorder) throws XdsInitializationException { - return SharedXdsClientPoolProvider.getDefaultProvider().getOrCreate(target, metricRecorder); + return getOrCreate(target, metricRecorder, null); + } + + public static ObjectPool getOrCreate( + String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) + throws XdsInitializationException { + return SharedXdsClientPoolProvider.getDefaultProvider() + .getOrCreate(target, metricRecorder, transportCallCredentials); } } diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index 2bc7be4a01..5302880d48 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -17,11 +17,11 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.CallCredentials; import io.grpc.MetricRecorder; import io.grpc.internal.ExponentialBackoffPolicy; import io.grpc.internal.GrpcUtil; @@ -87,6 +87,12 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory { @Override public ObjectPool getOrCreate(String target, MetricRecorder metricRecorder) throws XdsInitializationException { + return getOrCreate(target, metricRecorder, null); + } + + public ObjectPool getOrCreate( + String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) + throws XdsInitializationException { ObjectPool ref = targetToXdsClientMap.get(target); if (ref == null) { synchronized (lock) { @@ -102,7 +108,9 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory { if (bootstrapInfo.servers().isEmpty()) { throw new XdsInitializationException("No xDS server provided"); } - ref = new RefCountedXdsClientObjectPool(bootstrapInfo, target, metricRecorder); + ref = + new RefCountedXdsClientObjectPool( + bootstrapInfo, target, metricRecorder, transportCallCredentials); targetToXdsClientMap.put(target, ref); } } @@ -126,6 +134,7 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory { private final BootstrapInfo bootstrapInfo; private final String target; // The target associated with the xDS client. private final MetricRecorder metricRecorder; + private final CallCredentials transportCallCredentials; private final Object lock = new Object(); @GuardedBy("lock") private ScheduledExecutorService scheduler; @@ -137,11 +146,21 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory { private XdsClientMetricReporterImpl metricReporter; @VisibleForTesting - RefCountedXdsClientObjectPool(BootstrapInfo bootstrapInfo, String target, - MetricRecorder metricRecorder) { + RefCountedXdsClientObjectPool( + BootstrapInfo bootstrapInfo, String target, MetricRecorder metricRecorder) { + this(bootstrapInfo, target, metricRecorder, null); + } + + @VisibleForTesting + RefCountedXdsClientObjectPool( + BootstrapInfo bootstrapInfo, + String target, + MetricRecorder metricRecorder, + CallCredentials transportCallCredentials) { this.bootstrapInfo = checkNotNull(bootstrapInfo); this.target = target; this.metricRecorder = metricRecorder; + this.transportCallCredentials = transportCallCredentials; } @Override @@ -153,16 +172,19 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory { } scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); metricReporter = new XdsClientMetricReporterImpl(metricRecorder, target); - xdsClient = new XdsClientImpl( - DEFAULT_XDS_TRANSPORT_FACTORY, - bootstrapInfo, - scheduler, - BACKOFF_POLICY_PROVIDER, - GrpcUtil.STOPWATCH_SUPPLIER, - TimeProvider.SYSTEM_TIME_PROVIDER, - MessagePrinter.INSTANCE, - new TlsContextManagerImpl(bootstrapInfo), - metricReporter); + GrpcXdsTransportFactory xdsTransportFactory = + new GrpcXdsTransportFactory(transportCallCredentials); + xdsClient = + new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + scheduler, + BACKOFF_POLICY_PROVIDER, + GrpcUtil.STOPWATCH_SUPPLIER, + TimeProvider.SYSTEM_TIME_PROVIDER, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + metricReporter); metricReporter.setXdsClient(xdsClient); } refCount++; diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java index 51c07cb353..36131464d0 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java @@ -18,7 +18,6 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; @@ -4193,7 +4192,7 @@ public abstract class GrpcXdsClientImplTestBase { private XdsClientImpl createXdsClient(String serverUri) { BootstrapInfo bootstrapInfo = buildBootStrap(serverUri); return new XdsClientImpl( - DEFAULT_XDS_TRANSPORT_FACTORY, + new GrpcXdsTransportFactory(null), bootstrapInfo, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java index 703e429fa2..66e0d4b319 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java @@ -92,9 +92,10 @@ public class GrpcXdsTransportFactoryTest { @Test public void callApis() throws Exception { XdsTransportFactory.XdsTransport xdsTransport = - GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.create( - Bootstrapper.ServerInfo.create("localhost:" + server.getPort(), - InsecureChannelCredentials.create())); + new GrpcXdsTransportFactory(null) + .create( + Bootstrapper.ServerInfo.create( + "localhost:" + server.getPort(), InsecureChannelCredentials.create())); MethodDescriptor methodDescriptor = AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod(); XdsTransportFactory.StreamingCall streamingCall = diff --git a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java index c11a3a6e0d..9bdf86132b 100644 --- a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java @@ -178,11 +178,15 @@ public class LoadReportClientTest { when(backoffPolicy2.nextBackoffNanos()) .thenReturn(TimeUnit.SECONDS.toNanos(2L), TimeUnit.SECONDS.toNanos(20L)); addFakeStatsData(); - lrsClient = new LoadReportClient(loadStatsManager, - GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.createForTest(channel), - NODE, - syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, - fakeClock.getStopwatchSupplier()); + lrsClient = + new LoadReportClient( + loadStatsManager, + new GrpcXdsTransportFactory(null).createForTest(channel), + NODE, + syncContext, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier()); syncContext.execute(new Runnable() { @Override public void run() { diff --git a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java index 4fb77f0be4..86e4fc83a8 100644 --- a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java @@ -18,20 +18,36 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.OAuth2Credentials; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.CallCredentials; +import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; +import io.grpc.Metadata; import io.grpc.MetricRecorder; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.ObjectPool; import io.grpc.xds.SharedXdsClientPoolProvider.RefCountedXdsClientObjectPool; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceWatcher; import io.grpc.xds.client.XdsInitializationException; import java.util.Collections; +import java.util.concurrent.TimeUnit; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -54,9 +70,12 @@ public class SharedXdsClientPoolProviderTest { private final Node node = Node.newBuilder().setId("SharedXdsClientPoolProviderTest").build(); private final MetricRecorder metricRecorder = new MetricRecorder() {}; private static final String DUMMY_TARGET = "dummy"; + static final Metadata.Key AUTHORIZATION_METADATA_KEY = + Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER); @Mock private GrpcBootstrapperImpl bootstrapper; + @Mock private ResourceWatcher ldsResourceWatcher; @Test public void noServer() throws XdsInitializationException { @@ -138,4 +157,62 @@ public class SharedXdsClientPoolProviderTest { assertThat(xdsClient2).isNotSameInstanceAs(xdsClient1); xdsClientPool.returnObject(xdsClient2); } + + private class CallCredsServerInterceptor implements ServerInterceptor { + private SettableFuture tokenFuture = SettableFuture.create(); + + @Override + public ServerCall.Listener interceptCall( + ServerCall serverCall, + Metadata metadata, + ServerCallHandler next) { + tokenFuture.set(metadata.get(AUTHORIZATION_METADATA_KEY)); + return next.startCall(serverCall, metadata); + } + + public String getTokenWithTimeout(long timeout, TimeUnit unit) throws Exception { + return tokenFuture.get(timeout, unit); + } + } + + @Test + public void xdsClient_usesCallCredentials() throws Exception { + // Set up fake xDS server + XdsTestControlPlaneService fakeXdsService = new XdsTestControlPlaneService(); + CallCredsServerInterceptor callCredentialsInterceptor = new CallCredsServerInterceptor(); + Server xdsServer = + Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(fakeXdsService) + .intercept(callCredentialsInterceptor) + .build() + .start(); + String xdsServerUri = "localhost:" + xdsServer.getPort(); + + // Set up bootstrap & xDS client pool provider + ServerInfo server = ServerInfo.create(xdsServerUri, InsecureChannelCredentials.create()); + BootstrapInfo bootstrapInfo = + BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper); + + // Create custom xDS transport CallCredentials + CallCredentials sampleCreds = + MoreCallCredentials.from( + OAuth2Credentials.create(new AccessToken("token", /* expirationTime= */ null))); + + // Create xDS client that uses the CallCredentials on the transport + ObjectPool xdsClientPool = + provider.getOrCreate("target", metricRecorder, sampleCreds); + XdsClient xdsClient = xdsClientPool.getObject(); + xdsClient.watchXdsResource( + XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher); + + // Wait for xDS server to get the request and verify that it received the CallCredentials + assertThat(callCredentialsInterceptor.getTokenWithTimeout(5, TimeUnit.SECONDS)) + .isEqualTo("Bearer token"); + + // Clean up + xdsClientPool.returnObject(xdsClient); + xdsServer.shutdownNow(); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java index 97c2695f20..036b9f6f55 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java @@ -18,7 +18,6 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -442,9 +441,14 @@ public class XdsClientFallbackTest { String garbageUri = "some. garbage"; String validUri = "localhost:" + mainXdsServer.getServer().getPort(); - XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( - Arrays.asList(garbageUri, validUri), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock, - new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter); + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(garbageUri, validUri), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); fakeClock.forwardTime(20, TimeUnit.SECONDS); @@ -462,9 +466,14 @@ public class XdsClientFallbackTest { String garbageUri = "some. garbage"; String validUri = "localhost:" + mainXdsServer.getServer().getPort(); - XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( - Arrays.asList(validUri, garbageUri), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock, - new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter); + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(validUri, garbageUri), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); verify(ldsWatcher, timeout(5000)).onChanged( @@ -481,9 +490,14 @@ public class XdsClientFallbackTest { String garbageUri1 = "some. garbage"; String garbageUri2 = "other garbage"; - XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( - Arrays.asList(garbageUri1, garbageUri2), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock, - new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter); + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(garbageUri1, garbageUri2), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); fakeClock.forwardTime(20, TimeUnit.SECONDS);