From c506190b0f1f19c389d9ef797d6c69fb41167850 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Thu, 30 Jan 2025 12:43:51 -0800 Subject: [PATCH] xds: Reuse filter interceptors across RPCs This moves the interceptor creation from the ConfigSelector to the resource update handling. The code structure changes will make adding support for filter lifecycles (for RLQS) a bit easier. The filter lifecycles will allow filters to share state across interceptors, and constructing all the interceptors on a single thread will mean filters wouldn't need to be thread-safe (but their interceptors would be thread-safe). --- .../main/java/io/grpc/xds/FaultFilter.java | 3 +- xds/src/main/java/io/grpc/xds/Filter.java | 3 +- .../io/grpc/xds/GcpAuthenticationFilter.java | 4 +- .../main/java/io/grpc/xds/RouterFilter.java | 3 +- .../java/io/grpc/xds/XdsNameResolver.java | 201 ++++++++++++------ .../grpc/xds/GcpAuthenticationFilterTest.java | 2 +- .../grpc/xds/GrpcXdsClientImplDataTest.java | 2 - 7 files changed, 136 insertions(+), 82 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/FaultFilter.java b/xds/src/main/java/io/grpc/xds/FaultFilter.java index b7f7fa9c22..c66861a9f1 100644 --- a/xds/src/main/java/io/grpc/xds/FaultFilter.java +++ b/xds/src/main/java/io/grpc/xds/FaultFilter.java @@ -37,7 +37,6 @@ import io.grpc.Context; import io.grpc.Deadline; import io.grpc.ForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; -import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -183,7 +182,7 @@ final class FaultFilter implements Filter, ClientInterceptorBuilder { @Nullable @Override public ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, + FilterConfig config, @Nullable FilterConfig overrideConfig, final ScheduledExecutorService scheduler) { checkNotNull(config, "config"); if (overrideConfig != null) { diff --git a/xds/src/main/java/io/grpc/xds/Filter.java b/xds/src/main/java/io/grpc/xds/Filter.java index 4b2767687f..29f8cc4e33 100644 --- a/xds/src/main/java/io/grpc/xds/Filter.java +++ b/xds/src/main/java/io/grpc/xds/Filter.java @@ -19,7 +19,6 @@ package io.grpc.xds; import com.google.common.base.MoreObjects; import com.google.protobuf.Message; import io.grpc.ClientInterceptor; -import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.ServerInterceptor; import java.util.Objects; import java.util.concurrent.ScheduledExecutorService; @@ -59,7 +58,7 @@ interface Filter { interface ClientInterceptorBuilder { @Nullable ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, + FilterConfig config, @Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler); } diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java index 6d05e8ffa9..f73494d74d 100644 --- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -31,7 +31,6 @@ import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.CompositeCallCredentials; -import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -97,8 +96,7 @@ final class GcpAuthenticationFilter implements Filter, ClientInterceptorBuilder @Nullable @Override public ClientInterceptor buildClientInterceptor(FilterConfig config, - @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, - ScheduledExecutorService scheduler) { + @Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) { ComputeEngineCredentials credentials = ComputeEngineCredentials.create(); LruCache callCredentialsCache = diff --git a/xds/src/main/java/io/grpc/xds/RouterFilter.java b/xds/src/main/java/io/grpc/xds/RouterFilter.java index 7f1adf86a6..8038c1b98a 100644 --- a/xds/src/main/java/io/grpc/xds/RouterFilter.java +++ b/xds/src/main/java/io/grpc/xds/RouterFilter.java @@ -18,7 +18,6 @@ package io.grpc.xds; import com.google.protobuf.Message; import io.grpc.ClientInterceptor; -import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.ServerInterceptor; import io.grpc.xds.Filter.ClientInterceptorBuilder; import io.grpc.xds.Filter.ServerInterceptorBuilder; @@ -64,7 +63,7 @@ enum RouterFilter implements Filter, ClientInterceptorBuilder, ServerInterceptor @Nullable @Override public ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, + FilterConfig config, @Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) { return null; } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 3c7f4455fd..5ac5376337 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -59,6 +59,7 @@ import io.grpc.xds.VirtualHost.Route.RouteAction; import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; import io.grpc.xds.VirtualHost.Route.RouteAction.RetryPolicy; +import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.client.Bootstrapper.AuthorityInfo; @@ -384,20 +385,17 @@ final class XdsNameResolver extends NameResolver { @Override public Result selectConfig(PickSubchannelArgs args) { String cluster = null; - Route selectedRoute = null; + ClientInterceptor filters = null; // null iff cluster is null + RouteData selectedRoute = null; RoutingConfig routingCfg; - Map selectedOverrideConfigs; - List filterInterceptors = new ArrayList<>(); Metadata headers = args.getHeaders(); do { routingCfg = routingConfig; - selectedOverrideConfigs = new HashMap<>(routingCfg.virtualHostOverrideConfig); - for (Route route : routingCfg.routes) { + for (RouteData route : routingCfg.routes) { if (RoutingUtils.matchRoute( - route.routeMatch(), "/" + args.getMethodDescriptor().getFullMethodName(), - headers, random)) { + route.routeMatch, "/" + args.getMethodDescriptor().getFullMethodName(), + headers, random)) { selectedRoute = route; - selectedOverrideConfigs.putAll(route.filterConfigOverrides()); break; } } @@ -405,13 +403,14 @@ final class XdsNameResolver extends NameResolver { return Result.forError( Status.UNAVAILABLE.withDescription("Could not find xDS route matching RPC")); } - if (selectedRoute.routeAction() == null) { + if (selectedRoute.routeAction == null) { return Result.forError(Status.UNAVAILABLE.withDescription( "Could not route RPC to Route with non-forwarding action")); } - RouteAction action = selectedRoute.routeAction(); + RouteAction action = selectedRoute.routeAction; if (action.cluster() != null) { cluster = prefixedClusterName(action.cluster()); + filters = selectedRoute.filterChoices.get(0); } else if (action.weightedClusters() != null) { long totalWeight = 0; for (ClusterWeight weightedCluster : action.weightedClusters()) { @@ -419,23 +418,25 @@ final class XdsNameResolver extends NameResolver { } long select = random.nextLong(totalWeight); long accumulator = 0; - for (ClusterWeight weightedCluster : action.weightedClusters()) { + for (int i = 0; i < action.weightedClusters().size(); i++) { + ClusterWeight weightedCluster = action.weightedClusters().get(i); accumulator += weightedCluster.weight(); if (select < accumulator) { cluster = prefixedClusterName(weightedCluster.name()); - selectedOverrideConfigs.putAll(weightedCluster.filterConfigOverrides()); + filters = selectedRoute.filterChoices.get(i); break; } } } else if (action.namedClusterSpecifierPluginConfig() != null) { cluster = prefixedClusterSpecifierPluginName(action.namedClusterSpecifierPluginConfig().name()); + filters = selectedRoute.filterChoices.get(0); } } while (!retainCluster(cluster)); Long timeoutNanos = null; if (enableTimeout) { if (selectedRoute != null) { - timeoutNanos = selectedRoute.routeAction().timeoutNano(); + timeoutNanos = selectedRoute.routeAction.timeoutNano(); } if (timeoutNanos == null) { timeoutNanos = routingCfg.fallbackTimeoutNano; @@ -445,7 +446,7 @@ final class XdsNameResolver extends NameResolver { } } RetryPolicy retryPolicy = - selectedRoute == null ? null : selectedRoute.routeAction().retryPolicy(); + selectedRoute == null ? null : selectedRoute.routeAction.retryPolicy(); // TODO(chengyuanzhang): avoid service config generation and parsing for each call. Map rawServiceConfig = generateServiceConfigWithMethodConfig(timeoutNanos, retryPolicy); @@ -457,24 +458,9 @@ final class XdsNameResolver extends NameResolver { parsedServiceConfig.getError().augmentDescription( "Failed to parse service config (method config)")); } - if (routingCfg.filterChain != null) { - for (NamedFilterConfig namedFilter : routingCfg.filterChain) { - FilterConfig filterConfig = namedFilter.filterConfig; - Filter filter = filterRegistry.get(filterConfig.typeUrl()); - if (filter instanceof ClientInterceptorBuilder) { - ClientInterceptor interceptor = ((ClientInterceptorBuilder) filter) - .buildClientInterceptor( - filterConfig, selectedOverrideConfigs.get(namedFilter.name), - args, scheduler); - if (interceptor != null) { - filterInterceptors.add(interceptor); - } - } - } - } final String finalCluster = cluster; - final long hash = generateHash(selectedRoute.routeAction().hashPolicies(), headers); - Route finalSelectedRoute = selectedRoute; + final long hash = generateHash(selectedRoute.routeAction.hashPolicies(), headers); + RouteData finalSelectedRoute = selectedRoute; class ClusterSelectionInterceptor implements ClientInterceptor { @Override public ClientCall interceptCall( @@ -483,7 +469,7 @@ final class XdsNameResolver extends NameResolver { CallOptions callOptionsForCluster = callOptions.withOption(CLUSTER_SELECTION_KEY, finalCluster) .withOption(RPC_HASH_KEY, hash); - if (finalSelectedRoute.routeAction().autoHostRewrite()) { + if (finalSelectedRoute.routeAction.autoHostRewrite()) { callOptionsForCluster = callOptionsForCluster.withOption(AUTO_HOST_REWRITE_KEY, true); } return new SimpleForwardingClientCall( @@ -514,11 +500,11 @@ final class XdsNameResolver extends NameResolver { } } - filterInterceptors.add(new ClusterSelectionInterceptor()); return Result.newBuilder() .setConfig(config) - .setInterceptor(combineInterceptors(filterInterceptors)) + .setInterceptor(combineInterceptors( + ImmutableList.of(filters, new ClusterSelectionInterceptor()))) .build(); } @@ -584,8 +570,18 @@ final class XdsNameResolver extends NameResolver { } } + static final class PassthroughClientInterceptor implements ClientInterceptor { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + } + private static ClientInterceptor combineInterceptors(final List interceptors) { - checkArgument(!interceptors.isEmpty(), "empty interceptors"); + if (interceptors.size() == 0) { + return new PassthroughClientInterceptor(); + } if (interceptors.size() == 1) { return interceptors.get(0); } @@ -722,6 +718,7 @@ final class XdsNameResolver extends NameResolver { } List routes = virtualHost.routes(); + ImmutableList.Builder routesData = ImmutableList.builder(); // Populate all clusters to which requests can be routed to through the virtual host. Set clusters = new HashSet<>(); @@ -732,26 +729,34 @@ final class XdsNameResolver extends NameResolver { for (Route route : routes) { RouteAction action = route.routeAction(); String prefixedName; - if (action != null) { - if (action.cluster() != null) { - prefixedName = prefixedClusterName(action.cluster()); + if (action == null) { + routesData.add(new RouteData(route.routeMatch(), null, ImmutableList.of())); + } else if (action.cluster() != null) { + prefixedName = prefixedClusterName(action.cluster()); + clusters.add(prefixedName); + clusterNameMap.put(prefixedName, action.cluster()); + ClientInterceptor filters = createFilters(filterConfigs, virtualHost, route, null); + routesData.add(new RouteData(route.routeMatch(), route.routeAction(), filters)); + } else if (action.weightedClusters() != null) { + ImmutableList.Builder filterList = ImmutableList.builder(); + for (ClusterWeight weightedCluster : action.weightedClusters()) { + prefixedName = prefixedClusterName(weightedCluster.name()); clusters.add(prefixedName); - clusterNameMap.put(prefixedName, action.cluster()); - } else if (action.weightedClusters() != null) { - for (ClusterWeight weighedCluster : action.weightedClusters()) { - prefixedName = prefixedClusterName(weighedCluster.name()); - clusters.add(prefixedName); - clusterNameMap.put(prefixedName, weighedCluster.name()); - } - } else if (action.namedClusterSpecifierPluginConfig() != null) { - PluginConfig pluginConfig = action.namedClusterSpecifierPluginConfig().config(); - if (pluginConfig instanceof RlsPluginConfig) { - prefixedName = prefixedClusterSpecifierPluginName( - action.namedClusterSpecifierPluginConfig().name()); - clusters.add(prefixedName); - rlsPluginConfigMap.put(prefixedName, (RlsPluginConfig) pluginConfig); - } + clusterNameMap.put(prefixedName, weightedCluster.name()); + filterList.add(createFilters(filterConfigs, virtualHost, route, weightedCluster)); } + routesData.add( + new RouteData(route.routeMatch(), route.routeAction(), filterList.build())); + } else if (action.namedClusterSpecifierPluginConfig() != null) { + PluginConfig pluginConfig = action.namedClusterSpecifierPluginConfig().config(); + if (pluginConfig instanceof RlsPluginConfig) { + prefixedName = prefixedClusterSpecifierPluginName( + action.namedClusterSpecifierPluginConfig().name()); + clusters.add(prefixedName); + rlsPluginConfigMap.put(prefixedName, (RlsPluginConfig) pluginConfig); + } + ClientInterceptor filters = createFilters(filterConfigs, virtualHost, route, null); + routesData.add(new RouteData(route.routeMatch(), route.routeAction(), filters)); } } @@ -796,10 +801,7 @@ final class XdsNameResolver extends NameResolver { } // Make newly added clusters selectable by config selector and deleted clusters no longer // selectable. - routingConfig = - new RoutingConfig( - httpMaxStreamDurationNano, routes, filterConfigs, - virtualHost.filterConfigOverrides()); + routingConfig = new RoutingConfig(httpMaxStreamDurationNano, routesData.build()); shouldUpdateResult = false; for (String cluster : deletedClusters) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); @@ -813,6 +815,37 @@ final class XdsNameResolver extends NameResolver { } } + private ClientInterceptor createFilters( + @Nullable List filterConfigs, + VirtualHost virtualHost, + Route route, + @Nullable ClusterWeight weightedCluster) { + if (filterConfigs == null) { + return new PassthroughClientInterceptor(); + } + Map selectedOverrideConfigs = + new HashMap<>(virtualHost.filterConfigOverrides()); + selectedOverrideConfigs.putAll(route.filterConfigOverrides()); + if (weightedCluster != null) { + selectedOverrideConfigs.putAll(weightedCluster.filterConfigOverrides()); + } + ImmutableList.Builder filterInterceptors = ImmutableList.builder(); + for (NamedFilterConfig namedFilter : filterConfigs) { + FilterConfig filterConfig = namedFilter.filterConfig; + Filter filter = filterRegistry.get(filterConfig.typeUrl()); + if (filter instanceof ClientInterceptorBuilder) { + ClientInterceptor interceptor = ((ClientInterceptorBuilder) filter) + .buildClientInterceptor( + filterConfig, selectedOverrideConfigs.get(namedFilter.name), + scheduler); + if (interceptor != null) { + filterInterceptors.add(interceptor); + } + } + } + return combineInterceptors(filterInterceptors.build()); + } + private void cleanUpRoutes(String error) { if (existingClusters != null) { for (String cluster : existingClusters) { @@ -903,22 +936,50 @@ final class XdsNameResolver extends NameResolver { */ private static class RoutingConfig { private final long fallbackTimeoutNano; - final List routes; - // Null if HttpFilter is not supported. - @Nullable final List filterChain; - final Map virtualHostOverrideConfig; + final ImmutableList routes; - private static RoutingConfig empty = new RoutingConfig( - 0, Collections.emptyList(), null, Collections.emptyMap()); + private static RoutingConfig empty = new RoutingConfig(0, ImmutableList.of()); - private RoutingConfig( - long fallbackTimeoutNano, List routes, @Nullable List filterChain, - Map virtualHostOverrideConfig) { + private RoutingConfig(long fallbackTimeoutNano, ImmutableList routes) { this.fallbackTimeoutNano = fallbackTimeoutNano; - this.routes = routes; - checkArgument(filterChain == null || !filterChain.isEmpty(), "filterChain is empty"); - this.filterChain = filterChain == null ? null : Collections.unmodifiableList(filterChain); - this.virtualHostOverrideConfig = Collections.unmodifiableMap(virtualHostOverrideConfig); + this.routes = checkNotNull(routes, "routes"); + } + } + + static final class RouteData { + final RouteMatch routeMatch; + /** null implies non-forwarding action. */ + @Nullable + final RouteAction routeAction; + /** + * Only one of these interceptors should be used per-RPC. There are only multiple values in the + * list for weighted clusters, in which case the order of the list mirrors the weighted + * clusters. + */ + final ImmutableList filterChoices; + + RouteData(RouteMatch routeMatch, @Nullable RouteAction routeAction, ClientInterceptor filter) { + this(routeMatch, routeAction, ImmutableList.of(filter)); + } + + RouteData( + RouteMatch routeMatch, + @Nullable RouteAction routeAction, + ImmutableList filterChoices) { + this.routeMatch = checkNotNull(routeMatch, "routeMatch"); + checkArgument( + routeAction == null || !filterChoices.isEmpty(), + "filter may be empty only for non-forwarding action"); + this.routeAction = routeAction; + if (routeAction != null && routeAction.weightedClusters() != null) { + checkArgument( + routeAction.weightedClusters().size() == filterChoices.size(), + "filter choices must match size of weighted clusters"); + } + for (ClientInterceptor filter : filterChoices) { + checkNotNull(filter, "entry in filterChoices is null"); + } + this.filterChoices = checkNotNull(filterChoices, "filterChoices"); } } diff --git a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java index ddd244c855..3ca240ab7c 100644 --- a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java @@ -92,7 +92,7 @@ public class GcpAuthenticationFilterTest { GcpAuthenticationFilter filter = new GcpAuthenticationFilter(); // Create interceptor - ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, null); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); // Mock channel and capture CallOptions diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java index 654e85143b..5b9fdda112 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java @@ -113,7 +113,6 @@ import io.envoyproxy.envoy.type.v3.Int64Range; import io.grpc.ClientInterceptor; import io.grpc.EquivalentAddressGroup; import io.grpc.InsecureChannelCredentials; -import io.grpc.LoadBalancer; import io.grpc.LoadBalancerRegistry; import io.grpc.Status.Code; import io.grpc.internal.JsonUtil; @@ -1266,7 +1265,6 @@ public class GrpcXdsClientImplDataTest { @Override public ClientInterceptor buildClientInterceptor(FilterConfig config, @Nullable FilterConfig overrideConfig, - LoadBalancer.PickSubchannelArgs args, ScheduledExecutorService scheduler) { return null; }