xds: add support for custom per-target credentials on the transport (#11951)

This commit is contained in:
Ashley Zhang 2025-03-21 15:19:40 -07:00 committed by GitHub
parent 94f8e93691
commit 1958e42370
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 198 additions and 51 deletions

View File

@ -19,6 +19,7 @@ package io.grpc.xds;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.CallCredentials;
import io.grpc.CallOptions;
import io.grpc.ChannelCredentials;
import io.grpc.ClientCall;
@ -34,35 +35,50 @@ import java.util.concurrent.TimeUnit;
final class GrpcXdsTransportFactory implements XdsTransportFactory {
static final GrpcXdsTransportFactory DEFAULT_XDS_TRANSPORT_FACTORY =
new GrpcXdsTransportFactory();
private final CallCredentials callCredentials;
GrpcXdsTransportFactory(CallCredentials callCredentials) {
this.callCredentials = callCredentials;
}
@Override
public XdsTransport create(Bootstrapper.ServerInfo serverInfo) {
return new GrpcXdsTransport(serverInfo);
return new GrpcXdsTransport(serverInfo, callCredentials);
}
@VisibleForTesting
public XdsTransport createForTest(ManagedChannel channel) {
return new GrpcXdsTransport(channel);
return new GrpcXdsTransport(channel, callCredentials);
}
@VisibleForTesting
static class GrpcXdsTransport implements XdsTransport {
private final ManagedChannel channel;
private final CallCredentials callCredentials;
public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) {
this(serverInfo, null);
}
@VisibleForTesting
public GrpcXdsTransport(ManagedChannel channel) {
this(channel, null);
}
public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) {
String target = serverInfo.target();
ChannelCredentials channelCredentials = (ChannelCredentials) serverInfo.implSpecificConfig();
this.channel = Grpc.newChannelBuilder(target, channelCredentials)
.keepAliveTime(5, TimeUnit.MINUTES)
.build();
this.callCredentials = callCredentials;
}
@VisibleForTesting
public GrpcXdsTransport(ManagedChannel channel) {
public GrpcXdsTransport(ManagedChannel channel, CallCredentials callCredentials) {
this.channel = checkNotNull(channel, "channel");
this.callCredentials = callCredentials;
}
@Override
@ -72,7 +88,8 @@ final class GrpcXdsTransportFactory implements XdsTransportFactory {
MethodDescriptor.Marshaller<RespT> respMarshaller) {
Context prevContext = Context.ROOT.attach();
try {
return new XdsStreamingCall<>(fullMethodName, reqMarshaller, respMarshaller);
return new XdsStreamingCall<>(
fullMethodName, reqMarshaller, respMarshaller, callCredentials);
} finally {
Context.ROOT.detach(prevContext);
}
@ -89,16 +106,21 @@ final class GrpcXdsTransportFactory implements XdsTransportFactory {
private final ClientCall<ReqT, RespT> call;
public XdsStreamingCall(String methodName, MethodDescriptor.Marshaller<ReqT> reqMarshaller,
MethodDescriptor.Marshaller<RespT> respMarshaller) {
this.call = channel.newCall(
MethodDescriptor.<ReqT, RespT>newBuilder()
.setFullMethodName(methodName)
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
.setRequestMarshaller(reqMarshaller)
.setResponseMarshaller(respMarshaller)
.build(),
CallOptions.DEFAULT); // TODO(zivy): support waitForReady
public XdsStreamingCall(
String methodName,
MethodDescriptor.Marshaller<ReqT> reqMarshaller,
MethodDescriptor.Marshaller<RespT> respMarshaller,
CallCredentials callCredentials) {
this.call =
channel.newCall(
MethodDescriptor.<ReqT, RespT>newBuilder()
.setFullMethodName(methodName)
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
.setRequestMarshaller(reqMarshaller)
.setResponseMarshaller(respMarshaller)
.build(),
CallOptions.DEFAULT.withCallCredentials(
callCredentials)); // TODO(zivy): support waitForReady
}
@Override

View File

@ -16,6 +16,7 @@
package io.grpc.xds;
import io.grpc.CallCredentials;
import io.grpc.Internal;
import io.grpc.MetricRecorder;
import io.grpc.internal.ObjectPool;
@ -42,6 +43,13 @@ public final class InternalSharedXdsClientPoolProvider {
public static ObjectPool<XdsClient> getOrCreate(String target, MetricRecorder metricRecorder)
throws XdsInitializationException {
return SharedXdsClientPoolProvider.getDefaultProvider().getOrCreate(target, metricRecorder);
return getOrCreate(target, metricRecorder, null);
}
public static ObjectPool<XdsClient> getOrCreate(
String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials)
throws XdsInitializationException {
return SharedXdsClientPoolProvider.getDefaultProvider()
.getOrCreate(target, metricRecorder, transportCallCredentials);
}
}

View File

@ -17,11 +17,11 @@
package io.grpc.xds;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.grpc.CallCredentials;
import io.grpc.MetricRecorder;
import io.grpc.internal.ExponentialBackoffPolicy;
import io.grpc.internal.GrpcUtil;
@ -87,6 +87,12 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory {
@Override
public ObjectPool<XdsClient> getOrCreate(String target, MetricRecorder metricRecorder)
throws XdsInitializationException {
return getOrCreate(target, metricRecorder, null);
}
public ObjectPool<XdsClient> getOrCreate(
String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials)
throws XdsInitializationException {
ObjectPool<XdsClient> ref = targetToXdsClientMap.get(target);
if (ref == null) {
synchronized (lock) {
@ -102,7 +108,9 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory {
if (bootstrapInfo.servers().isEmpty()) {
throw new XdsInitializationException("No xDS server provided");
}
ref = new RefCountedXdsClientObjectPool(bootstrapInfo, target, metricRecorder);
ref =
new RefCountedXdsClientObjectPool(
bootstrapInfo, target, metricRecorder, transportCallCredentials);
targetToXdsClientMap.put(target, ref);
}
}
@ -126,6 +134,7 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory {
private final BootstrapInfo bootstrapInfo;
private final String target; // The target associated with the xDS client.
private final MetricRecorder metricRecorder;
private final CallCredentials transportCallCredentials;
private final Object lock = new Object();
@GuardedBy("lock")
private ScheduledExecutorService scheduler;
@ -137,11 +146,21 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory {
private XdsClientMetricReporterImpl metricReporter;
@VisibleForTesting
RefCountedXdsClientObjectPool(BootstrapInfo bootstrapInfo, String target,
MetricRecorder metricRecorder) {
RefCountedXdsClientObjectPool(
BootstrapInfo bootstrapInfo, String target, MetricRecorder metricRecorder) {
this(bootstrapInfo, target, metricRecorder, null);
}
@VisibleForTesting
RefCountedXdsClientObjectPool(
BootstrapInfo bootstrapInfo,
String target,
MetricRecorder metricRecorder,
CallCredentials transportCallCredentials) {
this.bootstrapInfo = checkNotNull(bootstrapInfo);
this.target = target;
this.metricRecorder = metricRecorder;
this.transportCallCredentials = transportCallCredentials;
}
@Override
@ -153,16 +172,19 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory {
}
scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE);
metricReporter = new XdsClientMetricReporterImpl(metricRecorder, target);
xdsClient = new XdsClientImpl(
DEFAULT_XDS_TRANSPORT_FACTORY,
bootstrapInfo,
scheduler,
BACKOFF_POLICY_PROVIDER,
GrpcUtil.STOPWATCH_SUPPLIER,
TimeProvider.SYSTEM_TIME_PROVIDER,
MessagePrinter.INSTANCE,
new TlsContextManagerImpl(bootstrapInfo),
metricReporter);
GrpcXdsTransportFactory xdsTransportFactory =
new GrpcXdsTransportFactory(transportCallCredentials);
xdsClient =
new XdsClientImpl(
xdsTransportFactory,
bootstrapInfo,
scheduler,
BACKOFF_POLICY_PROVIDER,
GrpcUtil.STOPWATCH_SUPPLIER,
TimeProvider.SYSTEM_TIME_PROVIDER,
MessagePrinter.INSTANCE,
new TlsContextManagerImpl(bootstrapInfo),
metricReporter);
metricReporter.setXdsClient(xdsClient);
}
refCount++;

View File

@ -18,7 +18,6 @@ package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
@ -4193,7 +4192,7 @@ public abstract class GrpcXdsClientImplTestBase {
private XdsClientImpl createXdsClient(String serverUri) {
BootstrapInfo bootstrapInfo = buildBootStrap(serverUri);
return new XdsClientImpl(
DEFAULT_XDS_TRANSPORT_FACTORY,
new GrpcXdsTransportFactory(null),
bootstrapInfo,
fakeClock.getScheduledExecutorService(),
backoffPolicyProvider,

View File

@ -92,9 +92,10 @@ public class GrpcXdsTransportFactoryTest {
@Test
public void callApis() throws Exception {
XdsTransportFactory.XdsTransport xdsTransport =
GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.create(
Bootstrapper.ServerInfo.create("localhost:" + server.getPort(),
InsecureChannelCredentials.create()));
new GrpcXdsTransportFactory(null)
.create(
Bootstrapper.ServerInfo.create(
"localhost:" + server.getPort(), InsecureChannelCredentials.create()));
MethodDescriptor<DiscoveryRequest, DiscoveryResponse> methodDescriptor =
AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod();
XdsTransportFactory.StreamingCall<DiscoveryRequest, DiscoveryResponse> streamingCall =

View File

@ -178,11 +178,15 @@ public class LoadReportClientTest {
when(backoffPolicy2.nextBackoffNanos())
.thenReturn(TimeUnit.SECONDS.toNanos(2L), TimeUnit.SECONDS.toNanos(20L));
addFakeStatsData();
lrsClient = new LoadReportClient(loadStatsManager,
GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.createForTest(channel),
NODE,
syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider,
fakeClock.getStopwatchSupplier());
lrsClient =
new LoadReportClient(
loadStatsManager,
new GrpcXdsTransportFactory(null).createForTest(channel),
NODE,
syncContext,
fakeClock.getScheduledExecutorService(),
backoffPolicyProvider,
fakeClock.getStopwatchSupplier());
syncContext.execute(new Runnable() {
@Override
public void run() {

View File

@ -18,20 +18,36 @@ package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.Metadata.ASCII_STRING_MARSHALLER;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.InsecureServerCredentials;
import io.grpc.Metadata;
import io.grpc.MetricRecorder;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.ObjectPool;
import io.grpc.xds.SharedXdsClientPoolProvider.RefCountedXdsClientObjectPool;
import io.grpc.xds.XdsListenerResource.LdsUpdate;
import io.grpc.xds.client.Bootstrapper.BootstrapInfo;
import io.grpc.xds.client.Bootstrapper.ServerInfo;
import io.grpc.xds.client.EnvoyProtoData.Node;
import io.grpc.xds.client.XdsClient;
import io.grpc.xds.client.XdsClient.ResourceWatcher;
import io.grpc.xds.client.XdsInitializationException;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@ -54,9 +70,12 @@ public class SharedXdsClientPoolProviderTest {
private final Node node = Node.newBuilder().setId("SharedXdsClientPoolProviderTest").build();
private final MetricRecorder metricRecorder = new MetricRecorder() {};
private static final String DUMMY_TARGET = "dummy";
static final Metadata.Key<String> AUTHORIZATION_METADATA_KEY =
Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER);
@Mock
private GrpcBootstrapperImpl bootstrapper;
@Mock private ResourceWatcher<LdsUpdate> ldsResourceWatcher;
@Test
public void noServer() throws XdsInitializationException {
@ -138,4 +157,62 @@ public class SharedXdsClientPoolProviderTest {
assertThat(xdsClient2).isNotSameInstanceAs(xdsClient1);
xdsClientPool.returnObject(xdsClient2);
}
private class CallCredsServerInterceptor implements ServerInterceptor {
private SettableFuture<String> tokenFuture = SettableFuture.create();
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> serverCall,
Metadata metadata,
ServerCallHandler<ReqT, RespT> next) {
tokenFuture.set(metadata.get(AUTHORIZATION_METADATA_KEY));
return next.startCall(serverCall, metadata);
}
public String getTokenWithTimeout(long timeout, TimeUnit unit) throws Exception {
return tokenFuture.get(timeout, unit);
}
}
@Test
public void xdsClient_usesCallCredentials() throws Exception {
// Set up fake xDS server
XdsTestControlPlaneService fakeXdsService = new XdsTestControlPlaneService();
CallCredsServerInterceptor callCredentialsInterceptor = new CallCredsServerInterceptor();
Server xdsServer =
Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create())
.addService(fakeXdsService)
.intercept(callCredentialsInterceptor)
.build()
.start();
String xdsServerUri = "localhost:" + xdsServer.getPort();
// Set up bootstrap & xDS client pool provider
ServerInfo server = ServerInfo.create(xdsServerUri, InsecureChannelCredentials.create());
BootstrapInfo bootstrapInfo =
BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build();
when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo);
SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper);
// Create custom xDS transport CallCredentials
CallCredentials sampleCreds =
MoreCallCredentials.from(
OAuth2Credentials.create(new AccessToken("token", /* expirationTime= */ null)));
// Create xDS client that uses the CallCredentials on the transport
ObjectPool<XdsClient> xdsClientPool =
provider.getOrCreate("target", metricRecorder, sampleCreds);
XdsClient xdsClient = xdsClientPool.getObject();
xdsClient.watchXdsResource(
XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher);
// Wait for xDS server to get the request and verify that it received the CallCredentials
assertThat(callCredentialsInterceptor.getTokenWithTimeout(5, TimeUnit.SECONDS))
.isEqualTo("Bearer token");
// Clean up
xdsClientPool.returnObject(xdsClient);
xdsServer.shutdownNow();
}
}

View File

@ -18,7 +18,6 @@ package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
@ -442,9 +441,14 @@ public class XdsClientFallbackTest {
String garbageUri = "some. garbage";
String validUri = "localhost:" + mainXdsServer.getServer().getPort();
XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient(
Arrays.asList(garbageUri, validUri), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock,
new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter);
XdsClientImpl client =
CommonBootstrapperTestUtils.createXdsClient(
Arrays.asList(garbageUri, validUri),
new GrpcXdsTransportFactory(null),
fakeClock,
new ExponentialBackoffPolicy.Provider(),
MessagePrinter.INSTANCE,
xdsClientMetricReporter);
client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher);
fakeClock.forwardTime(20, TimeUnit.SECONDS);
@ -462,9 +466,14 @@ public class XdsClientFallbackTest {
String garbageUri = "some. garbage";
String validUri = "localhost:" + mainXdsServer.getServer().getPort();
XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient(
Arrays.asList(validUri, garbageUri), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock,
new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter);
XdsClientImpl client =
CommonBootstrapperTestUtils.createXdsClient(
Arrays.asList(validUri, garbageUri),
new GrpcXdsTransportFactory(null),
fakeClock,
new ExponentialBackoffPolicy.Provider(),
MessagePrinter.INSTANCE,
xdsClientMetricReporter);
client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher);
verify(ldsWatcher, timeout(5000)).onChanged(
@ -481,9 +490,14 @@ public class XdsClientFallbackTest {
String garbageUri1 = "some. garbage";
String garbageUri2 = "other garbage";
XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient(
Arrays.asList(garbageUri1, garbageUri2), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock,
new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter);
XdsClientImpl client =
CommonBootstrapperTestUtils.createXdsClient(
Arrays.asList(garbageUri1, garbageUri2),
new GrpcXdsTransportFactory(null),
fakeClock,
new ExponentialBackoffPolicy.Provider(),
MessagePrinter.INSTANCE,
xdsClientMetricReporter);
client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher);
fakeClock.forwardTime(20, TimeUnit.SECONDS);