From 2b87b01651fa5d99ab1b1d9cbc4f87af5a1f5d6a Mon Sep 17 00:00:00 2001 From: Sergii Tkachenko Date: Tue, 18 Feb 2025 10:47:01 -0800 Subject: [PATCH] xds: Change how xDS filters are created by introducing Filter.Provider (#11883) This is the first step towards supporting filter state retention in Java. The mechanism will be similar to the one described in [A83] (https://github.com/grpc/proposal/blob/master/A83-xds-gcp-authn-filter.md#filter-call-credentials-cache) for C-core, and will serve the same purpose. However, the implementation details are very different due to the different nature of xDS HTTP filter support in C-core and Java. In Java, xDS HTTP filters are backed by classes implementing `io.grpc.xds.Filter`, from here just called "Filters". To support Filter state retention (next PR), Java's xDS implementation must be able to create unique Filter instances per: - Per HCM `envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager` - Per filter name as specified in `envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter.name` This PR **does not** implements Filter state retention, but lays the groundwork for it by changing how filters are registered and instantiated. To achieve this, all existing Filter classes had to be updated to the new instantiation mechanism described below. Prior to these this PR, Filters had no livecycle. FilterRegistry provided singleton instances for a given typeUrl. This PR introduces a new interface `Filter.Provider`, which instantiates Filter classes. All functionality that doesn't need an instance of a Filter is moved to the Filter.Provider. This includes parsing filter config proto into FilterConfig and determining the filter kind (client-side, server-side, or both). This PR is limited to refactoring, and there's no changes to the existing behavior. Note that all Filter Providers still return singleton Filter instances. However, with this PR, it is now possible to create Providers that return a new Filter instance each time `newInstance` is called. --- .../main/java/io/grpc/xds/FaultFilter.java | 178 ++++++++++-------- xds/src/main/java/io/grpc/xds/Filter.java | 94 ++++++--- .../main/java/io/grpc/xds/FilterRegistry.java | 16 +- .../io/grpc/xds/GcpAuthenticationFilter.java | 84 +++++---- .../java/io/grpc/xds/InternalRbacFilter.java | 7 +- xds/src/main/java/io/grpc/xds/RbacFilter.java | 162 ++++++++-------- .../main/java/io/grpc/xds/RouterFilter.java | 71 ++++--- .../java/io/grpc/xds/XdsListenerResource.java | 29 ++- .../java/io/grpc/xds/XdsNameResolver.java | 31 +-- .../grpc/xds/XdsRouteConfigureResource.java | 6 +- .../java/io/grpc/xds/XdsServerWrapper.java | 66 ++++--- .../java/io/grpc/xds/FaultFilterTest.java | 19 +- .../grpc/xds/GcpAuthenticationFilterTest.java | 25 ++- .../grpc/xds/GrpcXdsClientImplDataTest.java | 78 ++++---- .../test/java/io/grpc/xds/RbacFilterTest.java | 35 ++-- .../java/io/grpc/xds/RouterFilterTest.java | 36 ++++ .../java/io/grpc/xds/XdsNameResolverTest.java | 19 +- .../io/grpc/xds/XdsServerWrapperTest.java | 61 +++--- 18 files changed, 593 insertions(+), 424 deletions(-) create mode 100644 xds/src/test/java/io/grpc/xds/RouterFilterTest.java diff --git a/xds/src/main/java/io/grpc/xds/FaultFilter.java b/xds/src/main/java/io/grpc/xds/FaultFilter.java index c66861a9f1..2012fd36b6 100644 --- a/xds/src/main/java/io/grpc/xds/FaultFilter.java +++ b/xds/src/main/java/io/grpc/xds/FaultFilter.java @@ -45,7 +45,6 @@ import io.grpc.internal.DelayedClientCall; import io.grpc.internal.GrpcUtil; import io.grpc.xds.FaultConfig.FaultAbort; import io.grpc.xds.FaultConfig.FaultDelay; -import io.grpc.xds.Filter.ClientInterceptorBuilder; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import java.util.Locale; import java.util.concurrent.Executor; @@ -56,10 +55,11 @@ import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nullable; /** HttpFault filter implementation. */ -final class FaultFilter implements Filter, ClientInterceptorBuilder { +final class FaultFilter implements Filter { - static final FaultFilter INSTANCE = + private static final FaultFilter INSTANCE = new FaultFilter(ThreadSafeRandomImpl.instance, new AtomicLong()); + @VisibleForTesting static final Metadata.Key HEADER_DELAY_KEY = Metadata.Key.of("x-envoy-fault-delay-request", Metadata.ASCII_STRING_MARSHALLER); @@ -87,96 +87,108 @@ final class FaultFilter implements Filter, ClientInterceptorBuilder { this.activeFaultCounter = activeFaultCounter; } - @Override - public String[] typeUrls() { - return new String[] { TYPE_URL }; - } + static final class Provider implements Filter.Provider { + @Override + public String[] typeUrls() { + return new String[]{TYPE_URL}; + } - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - HTTPFault httpFaultProto; - if (!(rawProtoMessage instanceof Any)) { - return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + @Override + public boolean isClientFilter() { + return true; } - Any anyMessage = (Any) rawProtoMessage; - try { - httpFaultProto = anyMessage.unpack(HTTPFault.class); - } catch (InvalidProtocolBufferException e) { - return ConfigOrError.fromError("Invalid proto: " + e); - } - return parseHttpFault(httpFaultProto); - } - private static ConfigOrError parseHttpFault(HTTPFault httpFault) { - FaultDelay faultDelay = null; - FaultAbort faultAbort = null; - if (httpFault.hasDelay()) { - faultDelay = parseFaultDelay(httpFault.getDelay()); + @Override + public FaultFilter newInstance() { + return INSTANCE; } - if (httpFault.hasAbort()) { - ConfigOrError faultAbortOrError = parseFaultAbort(httpFault.getAbort()); - if (faultAbortOrError.errorDetail != null) { - return ConfigOrError.fromError( - "HttpFault contains invalid FaultAbort: " + faultAbortOrError.errorDetail); + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + HTTPFault httpFaultProto; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); } - faultAbort = faultAbortOrError.config; + Any anyMessage = (Any) rawProtoMessage; + try { + httpFaultProto = anyMessage.unpack(HTTPFault.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + return parseHttpFault(httpFaultProto); } - Integer maxActiveFaults = null; - if (httpFault.hasMaxActiveFaults()) { - maxActiveFaults = httpFault.getMaxActiveFaults().getValue(); - if (maxActiveFaults < 0) { - maxActiveFaults = Integer.MAX_VALUE; + + @Override + public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { + return parseFilterConfig(rawProtoMessage); + } + + private static ConfigOrError parseHttpFault(HTTPFault httpFault) { + FaultDelay faultDelay = null; + FaultAbort faultAbort = null; + if (httpFault.hasDelay()) { + faultDelay = parseFaultDelay(httpFault.getDelay()); + } + if (httpFault.hasAbort()) { + ConfigOrError faultAbortOrError = parseFaultAbort(httpFault.getAbort()); + if (faultAbortOrError.errorDetail != null) { + return ConfigOrError.fromError( + "HttpFault contains invalid FaultAbort: " + faultAbortOrError.errorDetail); + } + faultAbort = faultAbortOrError.config; + } + Integer maxActiveFaults = null; + if (httpFault.hasMaxActiveFaults()) { + maxActiveFaults = httpFault.getMaxActiveFaults().getValue(); + if (maxActiveFaults < 0) { + maxActiveFaults = Integer.MAX_VALUE; + } + } + return ConfigOrError.fromConfig(FaultConfig.create(faultDelay, faultAbort, maxActiveFaults)); + } + + private static FaultDelay parseFaultDelay( + io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay faultDelay) { + FaultConfig.FractionalPercent percent = parsePercent(faultDelay.getPercentage()); + if (faultDelay.hasHeaderDelay()) { + return FaultDelay.forHeader(percent); + } + return FaultDelay.forFixedDelay(Durations.toNanos(faultDelay.getFixedDelay()), percent); + } + + @VisibleForTesting + static ConfigOrError parseFaultAbort( + io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort faultAbort) { + FaultConfig.FractionalPercent percent = parsePercent(faultAbort.getPercentage()); + switch (faultAbort.getErrorTypeCase()) { + case HEADER_ABORT: + return ConfigOrError.fromConfig(FaultAbort.forHeader(percent)); + case HTTP_STATUS: + return ConfigOrError.fromConfig(FaultAbort.forStatus( + GrpcUtil.httpStatusToGrpcStatus(faultAbort.getHttpStatus()), percent)); + case GRPC_STATUS: + return ConfigOrError.fromConfig(FaultAbort.forStatus( + Status.fromCodeValue(faultAbort.getGrpcStatus()), percent)); + case ERRORTYPE_NOT_SET: + default: + return ConfigOrError.fromError( + "Unknown error type case: " + faultAbort.getErrorTypeCase()); } } - return ConfigOrError.fromConfig(FaultConfig.create(faultDelay, faultAbort, maxActiveFaults)); - } - private static FaultDelay parseFaultDelay( - io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay faultDelay) { - FaultConfig.FractionalPercent percent = parsePercent(faultDelay.getPercentage()); - if (faultDelay.hasHeaderDelay()) { - return FaultDelay.forHeader(percent); + private static FaultConfig.FractionalPercent parsePercent(FractionalPercent proto) { + switch (proto.getDenominator()) { + case HUNDRED: + return FaultConfig.FractionalPercent.perHundred(proto.getNumerator()); + case TEN_THOUSAND: + return FaultConfig.FractionalPercent.perTenThousand(proto.getNumerator()); + case MILLION: + return FaultConfig.FractionalPercent.perMillion(proto.getNumerator()); + case UNRECOGNIZED: + default: + throw new IllegalArgumentException("Unknown denominator type: " + proto.getDenominator()); + } } - return FaultDelay.forFixedDelay(Durations.toNanos(faultDelay.getFixedDelay()), percent); - } - - @VisibleForTesting - static ConfigOrError parseFaultAbort( - io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort faultAbort) { - FaultConfig.FractionalPercent percent = parsePercent(faultAbort.getPercentage()); - switch (faultAbort.getErrorTypeCase()) { - case HEADER_ABORT: - return ConfigOrError.fromConfig(FaultAbort.forHeader(percent)); - case HTTP_STATUS: - return ConfigOrError.fromConfig(FaultAbort.forStatus( - GrpcUtil.httpStatusToGrpcStatus(faultAbort.getHttpStatus()), percent)); - case GRPC_STATUS: - return ConfigOrError.fromConfig(FaultAbort.forStatus( - Status.fromCodeValue(faultAbort.getGrpcStatus()), percent)); - case ERRORTYPE_NOT_SET: - default: - return ConfigOrError.fromError( - "Unknown error type case: " + faultAbort.getErrorTypeCase()); - } - } - - private static FaultConfig.FractionalPercent parsePercent(FractionalPercent proto) { - switch (proto.getDenominator()) { - case HUNDRED: - return FaultConfig.FractionalPercent.perHundred(proto.getNumerator()); - case TEN_THOUSAND: - return FaultConfig.FractionalPercent.perTenThousand(proto.getNumerator()); - case MILLION: - return FaultConfig.FractionalPercent.perMillion(proto.getNumerator()); - case UNRECOGNIZED: - default: - throw new IllegalArgumentException("Unknown denominator type: " + proto.getDenominator()); - } - } - - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - return parseFilterConfig(rawProtoMessage); } @Nullable diff --git a/xds/src/main/java/io/grpc/xds/Filter.java b/xds/src/main/java/io/grpc/xds/Filter.java index 29f8cc4e33..ab61ba2b57 100644 --- a/xds/src/main/java/io/grpc/xds/Filter.java +++ b/xds/src/main/java/io/grpc/xds/Filter.java @@ -25,48 +25,82 @@ import java.util.concurrent.ScheduledExecutorService; import javax.annotation.Nullable; /** - * Defines the parsing functionality of an HTTP filter. A Filter may optionally implement either - * {@link ClientInterceptorBuilder} or {@link ServerInterceptorBuilder} or both, indicating it is - * capable of working on the client side or server side or both, respectively. + * Defines the parsing functionality of an HTTP filter. + * + *

A Filter may optionally implement either {@link Filter#buildClientInterceptor} or + * {@link Filter#buildServerInterceptor} or both, and return true from corresponding + * {@link Provider#isClientFilter()}, {@link Provider#isServerFilter()} to indicate that the filter + * is capable of working on the client side or server side or both, respectively. */ interface Filter { - /** - * The proto message types supported by this filter. A filter will be registered by each of its - * supported message types. - */ - String[] typeUrls(); - - /** - * Parses the top-level filter config from raw proto message. The message may be either a {@link - * com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. - */ - ConfigOrError parseFilterConfig(Message rawProtoMessage); - - /** - * Parses the per-filter override filter config from raw proto message. The message may be either - * a {@link com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. - */ - ConfigOrError parseFilterConfigOverride(Message rawProtoMessage); - /** Represents an opaque data structure holding configuration for a filter. */ interface FilterConfig { String typeUrl(); } + /** + * Common interface for filter providers. + */ + interface Provider { + /** + * The proto message types supported by this filter. A filter will be registered by each of its + * supported message types. + */ + String[] typeUrls(); + + /** + * Whether the filter can be installed on the client side. + * + *

Returns true if the filter implements {@link Filter#buildClientInterceptor}. + */ + default boolean isClientFilter() { + return false; + } + + /** + * Whether the filter can be installed into xDS-enabled servers. + * + *

Returns true if the filter implements {@link Filter#buildServerInterceptor}. + */ + default boolean isServerFilter() { + return false; + } + + /** + * Creates a new instance of the filter. + * + *

Returns a filter instance registered with the same typeUrls as the provider, + * capable of working with the same FilterConfig type returned by provider's parse functions. + */ + Filter newInstance(); + + /** + * Parses the top-level filter config from raw proto message. The message may be either a {@link + * com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. + */ + ConfigOrError parseFilterConfig(Message rawProtoMessage); + + /** + * Parses the per-filter override filter config from raw proto message. The message may be + * either a {@link com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. + */ + ConfigOrError parseFilterConfigOverride(Message rawProtoMessage); + } + /** Uses the FilterConfigs produced above to produce an HTTP filter interceptor for clients. */ - interface ClientInterceptorBuilder { - @Nullable - ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, - ScheduledExecutorService scheduler); + @Nullable + default ClientInterceptor buildClientInterceptor( + FilterConfig config, @Nullable FilterConfig overrideConfig, + ScheduledExecutorService scheduler) { + return null; } /** Uses the FilterConfigs produced above to produce an HTTP filter interceptor for the server. */ - interface ServerInterceptorBuilder { - @Nullable - ServerInterceptor buildServerInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig); + @Nullable + default ServerInterceptor buildServerInterceptor( + FilterConfig config, @Nullable FilterConfig overrideConfig) { + return null; } /** Filter config with instance name. */ diff --git a/xds/src/main/java/io/grpc/xds/FilterRegistry.java b/xds/src/main/java/io/grpc/xds/FilterRegistry.java index 7f1fe82c6c..426c6d1b3f 100644 --- a/xds/src/main/java/io/grpc/xds/FilterRegistry.java +++ b/xds/src/main/java/io/grpc/xds/FilterRegistry.java @@ -23,21 +23,21 @@ import javax.annotation.Nullable; /** * A registry for all supported {@link Filter}s. Filters can be queried from the registry - * by any of the {@link Filter#typeUrls() type URLs}. + * by any of the {@link Filter.Provider#typeUrls() type URLs}. */ final class FilterRegistry { private static FilterRegistry instance; - private final Map supportedFilters = new HashMap<>(); + private final Map supportedFilters = new HashMap<>(); private FilterRegistry() {} static synchronized FilterRegistry getDefaultRegistry() { if (instance == null) { instance = newRegistry().register( - FaultFilter.INSTANCE, - RouterFilter.INSTANCE, - RbacFilter.INSTANCE); + new FaultFilter.Provider(), + new RouterFilter.Provider(), + new RbacFilter.Provider()); } return instance; } @@ -48,8 +48,8 @@ final class FilterRegistry { } @VisibleForTesting - FilterRegistry register(Filter... filters) { - for (Filter filter : filters) { + FilterRegistry register(Filter.Provider... filters) { + for (Filter.Provider filter : filters) { for (String typeUrl : filter.typeUrls()) { supportedFilters.put(typeUrl, filter); } @@ -58,7 +58,7 @@ final class FilterRegistry { } @Nullable - Filter get(String typeUrl) { + Filter.Provider get(String typeUrl) { return supportedFilters.get(typeUrl); } } diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java index f73494d74d..7ed617c984 100644 --- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -35,7 +35,6 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; import io.grpc.auth.MoreCallCredentials; -import io.grpc.xds.Filter.ClientInterceptorBuilder; import io.grpc.xds.MetadataRegistry.MetadataValueParser; import java.util.LinkedHashMap; import java.util.Map; @@ -47,50 +46,63 @@ import javax.annotation.Nullable; * A {@link Filter} that injects a {@link CallCredentials} to handle * authentication for xDS credentials. */ -final class GcpAuthenticationFilter implements Filter, ClientInterceptorBuilder { +final class GcpAuthenticationFilter implements Filter { static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig"; - @Override - public String[] typeUrls() { - return new String[] { TYPE_URL }; - } - - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - GcpAuthnFilterConfig gcpAuthnProto; - if (!(rawProtoMessage instanceof Any)) { - return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); - } - Any anyMessage = (Any) rawProtoMessage; - - try { - gcpAuthnProto = anyMessage.unpack(GcpAuthnFilterConfig.class); - } catch (InvalidProtocolBufferException e) { - return ConfigOrError.fromError("Invalid proto: " + e); + static final class Provider implements Filter.Provider { + @Override + public String[] typeUrls() { + return new String[]{TYPE_URL}; } - long cacheSize = 10; - // Validate cache_config - if (gcpAuthnProto.hasCacheConfig()) { - TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig(); - cacheSize = cacheConfig.getCacheSize().getValue(); - if (cacheSize == 0) { - return ConfigOrError.fromError( - "cache_config.cache_size must be greater than zero"); + @Override + public boolean isClientFilter() { + return true; + } + + @Override + public GcpAuthenticationFilter newInstance() { + return new GcpAuthenticationFilter(); + } + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + GcpAuthnFilterConfig gcpAuthnProto; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); } - // LruCache's size is an int and briefly exceeds its maximum size before evicting entries - cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1); + Any anyMessage = (Any) rawProtoMessage; + + try { + gcpAuthnProto = anyMessage.unpack(GcpAuthnFilterConfig.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + + long cacheSize = 10; + // Validate cache_config + if (gcpAuthnProto.hasCacheConfig()) { + TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig(); + cacheSize = cacheConfig.getCacheSize().getValue(); + if (cacheSize == 0) { + return ConfigOrError.fromError( + "cache_config.cache_size must be greater than zero"); + } + // LruCache's size is an int and briefly exceeds its maximum size before evicting entries + cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1); + } + + GcpAuthenticationConfig config = new GcpAuthenticationConfig((int) cacheSize); + return ConfigOrError.fromConfig(config); } - GcpAuthenticationConfig config = new GcpAuthenticationConfig((int) cacheSize); - return ConfigOrError.fromConfig(config); - } - - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - return parseFilterConfig(rawProtoMessage); + @Override + public ConfigOrError parseFilterConfigOverride( + Message rawProtoMessage) { + return parseFilterConfig(rawProtoMessage); + } } @Nullable diff --git a/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java b/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java index 54e6c748cd..cedb3f4c85 100644 --- a/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java +++ b/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java @@ -19,8 +19,6 @@ package io.grpc.xds; import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC; import io.grpc.Internal; import io.grpc.ServerInterceptor; -import io.grpc.xds.RbacConfig; -import io.grpc.xds.RbacFilter; /** This class exposes some functionality in RbacFilter to other packages. */ @Internal @@ -30,11 +28,12 @@ public final class InternalRbacFilter { /** Parses RBAC filter config and creates AuthorizationServerInterceptor. */ public static ServerInterceptor createInterceptor(RBAC rbac) { - ConfigOrError filterConfig = RbacFilter.parseRbacConfig(rbac); + ConfigOrError filterConfig = RbacFilter.Provider.parseRbacConfig(rbac); if (filterConfig.errorDetail != null) { throw new IllegalArgumentException( String.format("Failed to parse Rbac policy: %s", filterConfig.errorDetail)); } - return new RbacFilter().buildServerInterceptor(filterConfig.config, null); + return new RbacFilter.Provider().newInstance() + .buildServerInterceptor(filterConfig.config, null); } } diff --git a/xds/src/main/java/io/grpc/xds/RbacFilter.java b/xds/src/main/java/io/grpc/xds/RbacFilter.java index 6a55f7f193..2bc4eeb846 100644 --- a/xds/src/main/java/io/grpc/xds/RbacFilter.java +++ b/xds/src/main/java/io/grpc/xds/RbacFilter.java @@ -18,7 +18,6 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; @@ -34,7 +33,6 @@ import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.Status; -import io.grpc.xds.Filter.ServerInterceptorBuilder; import io.grpc.xds.internal.MatcherParser; import io.grpc.xds.internal.Matchers; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine; @@ -66,10 +64,10 @@ import java.util.stream.Collectors; import javax.annotation.Nullable; /** RBAC Http filter implementation. */ -final class RbacFilter implements Filter, ServerInterceptorBuilder { +final class RbacFilter implements Filter { private static final Logger logger = Logger.getLogger(RbacFilter.class.getName()); - static final RbacFilter INSTANCE = new RbacFilter(); + private static final RbacFilter INSTANCE = new RbacFilter(); static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBAC"; @@ -77,87 +75,99 @@ final class RbacFilter implements Filter, ServerInterceptorBuilder { private static final String TYPE_URL_OVERRIDE_CONFIG = "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBACPerRoute"; - RbacFilter() {} + private RbacFilter() {} - @Override - public String[] typeUrls() { - return new String[] { TYPE_URL, TYPE_URL_OVERRIDE_CONFIG }; - } + static final class Provider implements Filter.Provider { + @Override + public String[] typeUrls() { + return new String[] {TYPE_URL, TYPE_URL_OVERRIDE_CONFIG}; + } - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - RBAC rbacProto; - if (!(rawProtoMessage instanceof Any)) { - return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + @Override + public boolean isServerFilter() { + return true; } - Any anyMessage = (Any) rawProtoMessage; - try { - rbacProto = anyMessage.unpack(RBAC.class); - } catch (InvalidProtocolBufferException e) { - return ConfigOrError.fromError("Invalid proto: " + e); - } - return parseRbacConfig(rbacProto); - } - @VisibleForTesting - static ConfigOrError parseRbacConfig(RBAC rbac) { - if (!rbac.hasRules()) { - return ConfigOrError.fromConfig(RbacConfig.create(null)); + @Override + public RbacFilter newInstance() { + return INSTANCE; } - io.envoyproxy.envoy.config.rbac.v3.RBAC rbacConfig = rbac.getRules(); - GrpcAuthorizationEngine.Action authAction; - switch (rbacConfig.getAction()) { - case ALLOW: - authAction = GrpcAuthorizationEngine.Action.ALLOW; - break; - case DENY: - authAction = GrpcAuthorizationEngine.Action.DENY; - break; - case LOG: - return ConfigOrError.fromConfig(RbacConfig.create(null)); - case UNRECOGNIZED: - default: - return ConfigOrError.fromError("Unknown rbacConfig action type: " + rbacConfig.getAction()); - } - List policyMatchers = new ArrayList<>(); - List> sortedPolicyEntries = rbacConfig.getPoliciesMap().entrySet() - .stream() - .sorted((a,b) -> a.getKey().compareTo(b.getKey())) - .collect(Collectors.toList()); - for (Map.Entry entry: sortedPolicyEntries) { + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + RBAC rbacProto; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + Any anyMessage = (Any) rawProtoMessage; try { - Policy policy = entry.getValue(); - if (policy.hasCondition() || policy.hasCheckedCondition()) { - return ConfigOrError.fromError( - "Policy.condition and Policy.checked_condition must not set: " + entry.getKey()); - } - policyMatchers.add(PolicyMatcher.create(entry.getKey(), - parsePermissionList(policy.getPermissionsList()), - parsePrincipalList(policy.getPrincipalsList()))); - } catch (Exception e) { - return ConfigOrError.fromError("Encountered error parsing policy: " + e); + rbacProto = anyMessage.unpack(RBAC.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + return parseRbacConfig(rbacProto); + } + + @Override + public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { + RBACPerRoute rbacPerRoute; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + Any anyMessage = (Any) rawProtoMessage; + try { + rbacPerRoute = anyMessage.unpack(RBACPerRoute.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + if (rbacPerRoute.hasRbac()) { + return parseRbacConfig(rbacPerRoute.getRbac()); + } else { + return ConfigOrError.fromConfig(RbacConfig.create(null)); } } - return ConfigOrError.fromConfig(RbacConfig.create( - AuthConfig.create(policyMatchers, authAction))); - } - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - RBACPerRoute rbacPerRoute; - if (!(rawProtoMessage instanceof Any)) { - return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); - } - Any anyMessage = (Any) rawProtoMessage; - try { - rbacPerRoute = anyMessage.unpack(RBACPerRoute.class); - } catch (InvalidProtocolBufferException e) { - return ConfigOrError.fromError("Invalid proto: " + e); - } - if (rbacPerRoute.hasRbac()) { - return parseRbacConfig(rbacPerRoute.getRbac()); - } else { - return ConfigOrError.fromConfig(RbacConfig.create(null)); + static ConfigOrError parseRbacConfig(RBAC rbac) { + if (!rbac.hasRules()) { + return ConfigOrError.fromConfig(RbacConfig.create(null)); + } + io.envoyproxy.envoy.config.rbac.v3.RBAC rbacConfig = rbac.getRules(); + GrpcAuthorizationEngine.Action authAction; + switch (rbacConfig.getAction()) { + case ALLOW: + authAction = GrpcAuthorizationEngine.Action.ALLOW; + break; + case DENY: + authAction = GrpcAuthorizationEngine.Action.DENY; + break; + case LOG: + return ConfigOrError.fromConfig(RbacConfig.create(null)); + case UNRECOGNIZED: + default: + return ConfigOrError.fromError( + "Unknown rbacConfig action type: " + rbacConfig.getAction()); + } + List policyMatchers = new ArrayList<>(); + List> sortedPolicyEntries = rbacConfig.getPoliciesMap().entrySet() + .stream() + .sorted((a,b) -> a.getKey().compareTo(b.getKey())) + .collect(Collectors.toList()); + for (Map.Entry entry: sortedPolicyEntries) { + try { + Policy policy = entry.getValue(); + if (policy.hasCondition() || policy.hasCheckedCondition()) { + return ConfigOrError.fromError( + "Policy.condition and Policy.checked_condition must not set: " + entry.getKey()); + } + policyMatchers.add(PolicyMatcher.create(entry.getKey(), + parsePermissionList(policy.getPermissionsList()), + parsePrincipalList(policy.getPrincipalsList()))); + } catch (Exception e) { + return ConfigOrError.fromError("Encountered error parsing policy: " + e); + } + } + return ConfigOrError.fromConfig(RbacConfig.create( + AuthConfig.create(policyMatchers, authAction))); } } diff --git a/xds/src/main/java/io/grpc/xds/RouterFilter.java b/xds/src/main/java/io/grpc/xds/RouterFilter.java index 8038c1b98a..939bd0b12a 100644 --- a/xds/src/main/java/io/grpc/xds/RouterFilter.java +++ b/xds/src/main/java/io/grpc/xds/RouterFilter.java @@ -17,18 +17,12 @@ package io.grpc.xds; import com.google.protobuf.Message; -import io.grpc.ClientInterceptor; -import io.grpc.ServerInterceptor; -import io.grpc.xds.Filter.ClientInterceptorBuilder; -import io.grpc.xds.Filter.ServerInterceptorBuilder; -import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; /** * Router filter implementation. Currently this filter does not parse any field in the config. */ -enum RouterFilter implements Filter, ClientInterceptorBuilder, ServerInterceptorBuilder { - INSTANCE; +final class RouterFilter implements Filter { + private static final RouterFilter INSTANCE = new RouterFilter(); static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.router.v3.Router"; @@ -36,7 +30,7 @@ enum RouterFilter implements Filter, ClientInterceptorBuilder, ServerInterceptor static final FilterConfig ROUTER_CONFIG = new FilterConfig() { @Override public String typeUrl() { - return RouterFilter.TYPE_URL; + return TYPE_URL; } @Override @@ -45,33 +39,38 @@ enum RouterFilter implements Filter, ClientInterceptorBuilder, ServerInterceptor } }; - @Override - public String[] typeUrls() { - return new String[] { TYPE_URL }; + static final class Provider implements Filter.Provider { + @Override + public String[] typeUrls() { + return new String[]{TYPE_URL}; + } + + @Override + public boolean isClientFilter() { + return true; + } + + @Override + public boolean isServerFilter() { + return true; + } + + @Override + public RouterFilter newInstance() { + return INSTANCE; + } + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + return ConfigOrError.fromConfig(ROUTER_CONFIG); + } + + @Override + public ConfigOrError parseFilterConfigOverride( + Message rawProtoMessage) { + return ConfigOrError.fromError("Router Filter should not have override config"); + } } - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - return ConfigOrError.fromConfig(ROUTER_CONFIG); - } - - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - return ConfigOrError.fromError("Router Filter should not have override config"); - } - - @Nullable - @Override - public ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, - ScheduledExecutorService scheduler) { - return null; - } - - @Nullable - @Override - public ServerInterceptor buildServerInterceptor( - FilterConfig config, @Nullable Filter.FilterConfig overrideConfig) { - return null; - } + private RouterFilter() {} } diff --git a/xds/src/main/java/io/grpc/xds/XdsListenerResource.java b/xds/src/main/java/io/grpc/xds/XdsListenerResource.java index a18b093e38..4b554be174 100644 --- a/xds/src/main/java/io/grpc/xds/XdsListenerResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsListenerResource.java @@ -575,12 +575,8 @@ class XdsListenerResource extends XdsResourceType { String filterName = httpFilter.getName(); boolean isOptional = httpFilter.getIsOptional(); if (!httpFilter.hasTypedConfig()) { - if (isOptional) { - return null; - } else { - return StructOrError.fromError( - "HttpFilter [" + filterName + "] is not optional and has no typed config"); - } + return isOptional ? null : StructOrError.fromError( + "HttpFilter [" + filterName + "] is not optional and has no typed config"); } Message rawConfig = httpFilter.getTypedConfig(); String typeUrl = httpFilter.getTypedConfig().getTypeUrl(); @@ -600,18 +596,17 @@ class XdsListenerResource extends XdsResourceType { return StructOrError.fromError( "HttpFilter [" + filterName + "] contains invalid proto: " + e); } - Filter filter = filterRegistry.get(typeUrl); - if ((isForClient && !(filter instanceof Filter.ClientInterceptorBuilder)) - || (!isForClient && !(filter instanceof Filter.ServerInterceptorBuilder))) { - if (isOptional) { - return null; - } else { - return StructOrError.fromError( - "HttpFilter [" + filterName + "](" + typeUrl + ") is required but unsupported for " - + (isForClient ? "client" : "server")); - } + + Filter.Provider provider = filterRegistry.get(typeUrl); + if (provider == null + || (isForClient && !provider.isClientFilter()) + || (!isForClient && !provider.isServerFilter())) { + // Filter type not supported. + return isOptional ? null : StructOrError.fromError( + "HttpFilter [" + filterName + "](" + typeUrl + ") is required but unsupported for " + ( + isForClient ? "client" : "server")); } - ConfigOrError filterConfig = filter.parseFilterConfig(rawConfig); + ConfigOrError filterConfig = provider.parseFilterConfig(rawConfig); if (filterConfig.errorDetail != null) { return StructOrError.fromError( "Invalid filter config for HttpFilter [" + filterName + "]: " + filterConfig.errorDetail); diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 21f5d5efce..b7b1ed0bdb 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -49,7 +49,6 @@ import io.grpc.SynchronizationContext; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; import io.grpc.xds.ClusterSpecifierPlugin.PluginConfig; -import io.grpc.xds.Filter.ClientInterceptorBuilder; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; import io.grpc.xds.RouteLookupServiceClusterSpecifierPlugin.RlsPluginConfig; @@ -827,26 +826,36 @@ final class XdsNameResolver extends NameResolver { if (filterConfigs == null) { return new PassthroughClientInterceptor(); } + Map selectedOverrideConfigs = new HashMap<>(virtualHost.filterConfigOverrides()); selectedOverrideConfigs.putAll(route.filterConfigOverrides()); if (weightedCluster != null) { selectedOverrideConfigs.putAll(weightedCluster.filterConfigOverrides()); } + ImmutableList.Builder filterInterceptors = ImmutableList.builder(); for (NamedFilterConfig namedFilter : filterConfigs) { - FilterConfig filterConfig = namedFilter.filterConfig; - Filter filter = filterRegistry.get(filterConfig.typeUrl()); - if (filter instanceof ClientInterceptorBuilder) { - ClientInterceptor interceptor = ((ClientInterceptorBuilder) filter) - .buildClientInterceptor( - filterConfig, selectedOverrideConfigs.get(namedFilter.name), - scheduler); - if (interceptor != null) { - filterInterceptors.add(interceptor); - } + FilterConfig config = namedFilter.filterConfig; + String name = namedFilter.name; + String typeUrl = config.typeUrl(); + + Filter.Provider provider = filterRegistry.get(typeUrl); + if (provider == null || !provider.isClientFilter()) { + continue; + } + + Filter filter = provider.newInstance(); + + ClientInterceptor interceptor = + filter.buildClientInterceptor(config, selectedOverrideConfigs.get(name), scheduler); + if (interceptor != null) { + filterInterceptors.add(interceptor); } } + + // Combine interceptors produced by different filters into a single one that executes + // them sequentially. The order is preserved. return combineInterceptors(filterInterceptors.build()); } diff --git a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java index c5ca8d45cb..80a77cbb1d 100644 --- a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java @@ -245,8 +245,8 @@ class XdsRouteConfigureResource extends XdsResourceType { return StructOrError.fromError( "FilterConfig [" + name + "] contains invalid proto: " + e); } - Filter filter = filterRegistry.get(typeUrl); - if (filter == null) { + Filter.Provider provider = filterRegistry.get(typeUrl); + if (provider == null) { if (isOptional) { continue; } @@ -254,7 +254,7 @@ class XdsRouteConfigureResource extends XdsResourceType { "HttpFilter [" + name + "](" + typeUrl + ") is required but unsupported"); } ConfigOrError filterConfig = - filter.parseFilterConfigOverride(rawConfig); + provider.parseFilterConfigOverride(rawConfig); if (filterConfig.errorDetail != null) { return StructOrError.fromError( "Invalid filter config for HttpFilter [" + name + "]: " + filterConfig.errorDetail); diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index 3a9b98ee32..bbb17d9b61 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -47,7 +47,6 @@ import io.grpc.internal.SharedResourceHolder; import io.grpc.xds.EnvoyServerProtoData.FilterChain; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; -import io.grpc.xds.Filter.ServerInterceptorBuilder; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import io.grpc.xds.VirtualHost.Route; @@ -524,37 +523,56 @@ final class XdsServerWrapper extends Server { } private ImmutableMap generatePerRouteInterceptors( - List namedFilterConfigs, List virtualHosts) { + @Nullable List filterConfigs, List virtualHosts) { + // This should always be called from the sync context. + // Ideally we'd want to throw otherwise, but this breaks the tests now. + // syncContext.throwIfNotInThisSynchronizationContext(); + ImmutableMap.Builder perRouteInterceptors = new ImmutableMap.Builder<>(); + for (VirtualHost virtualHost : virtualHosts) { for (Route route : virtualHost.routes()) { - List filterInterceptors = new ArrayList<>(); - Map selectedOverrideConfigs = - new HashMap<>(virtualHost.filterConfigOverrides()); - selectedOverrideConfigs.putAll(route.filterConfigOverrides()); - if (namedFilterConfigs != null) { - for (NamedFilterConfig namedFilterConfig : namedFilterConfigs) { - FilterConfig filterConfig = namedFilterConfig.filterConfig; - Filter filter = filterRegistry.get(filterConfig.typeUrl()); - if (filter instanceof ServerInterceptorBuilder) { - ServerInterceptor interceptor = - ((ServerInterceptorBuilder) filter).buildServerInterceptor( - filterConfig, selectedOverrideConfigs.get(namedFilterConfig.name)); - if (interceptor != null) { - filterInterceptors.add(interceptor); - } - } else { - logger.log(Level.WARNING, "HttpFilterConfig(type URL: " - + filterConfig.typeUrl() + ") is not supported on server-side. " - + "Probably a bug at ClientXdsClient verification."); - } + // Short circuit. + if (filterConfigs == null) { + perRouteInterceptors.put(route, noopInterceptor); + continue; + } + + // Override vhost filter configs with more specific per-route configs. + Map perRouteOverrides = ImmutableMap.builder() + .putAll(virtualHost.filterConfigOverrides()) + .putAll(route.filterConfigOverrides()) + .buildKeepingLast(); + + // Interceptors for this vhost/route combo. + List interceptors = new ArrayList<>(filterConfigs.size()); + + for (NamedFilterConfig namedFilter : filterConfigs) { + FilterConfig config = namedFilter.filterConfig; + String name = namedFilter.name; + String typeUrl = config.typeUrl(); + + Filter.Provider provider = filterRegistry.get(typeUrl); + if (provider == null || !provider.isServerFilter()) { + logger.warning("HttpFilter[" + name + "]: not supported on server-side: " + typeUrl); + continue; + } + + Filter filter = provider.newInstance(); + ServerInterceptor interceptor = + filter.buildServerInterceptor(config, perRouteOverrides.get(name)); + if (interceptor != null) { + interceptors.add(interceptor); } } - ServerInterceptor interceptor = combineInterceptors(filterInterceptors); - perRouteInterceptors.put(route, interceptor); + + // Combine interceptors produced by different filters into a single one that executes + // them sequentially. The order is preserved. + perRouteInterceptors.put(route, combineInterceptors(interceptors)); } } + return perRouteInterceptors.buildOrThrow(); } diff --git a/xds/src/test/java/io/grpc/xds/FaultFilterTest.java b/xds/src/test/java/io/grpc/xds/FaultFilterTest.java index f85f29ec0a..8f0a33951b 100644 --- a/xds/src/test/java/io/grpc/xds/FaultFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/FaultFilterTest.java @@ -33,16 +33,23 @@ import org.junit.runners.JUnit4; /** Tests for {@link FaultFilter}. */ @RunWith(JUnit4.class) public class FaultFilterTest { + private static final FaultFilter.Provider FILTER_PROVIDER = new FaultFilter.Provider(); + + @Test + public void filterType_clientOnly() { + assertThat(FILTER_PROVIDER.isClientFilter()).isTrue(); + assertThat(FILTER_PROVIDER.isServerFilter()).isFalse(); + } @Test public void parseFaultAbort_convertHttpStatus() { Any rawConfig = Any.pack( HTTPFault.newBuilder().setAbort(FaultAbort.newBuilder().setHttpStatus(404)).build()); - FaultConfig faultConfig = FaultFilter.INSTANCE.parseFilterConfig(rawConfig).config; + FaultConfig faultConfig = FILTER_PROVIDER.parseFilterConfig(rawConfig).config; assertThat(faultConfig.faultAbort().status().getCode()) .isEqualTo(GrpcUtil.httpStatusToGrpcStatus(404).getCode()); - FaultConfig faultConfigOverride = - FaultFilter.INSTANCE.parseFilterConfigOverride(rawConfig).config; + + FaultConfig faultConfigOverride = FILTER_PROVIDER.parseFilterConfigOverride(rawConfig).config; assertThat(faultConfigOverride.faultAbort().status().getCode()) .isEqualTo(GrpcUtil.httpStatusToGrpcStatus(404).getCode()); } @@ -54,7 +61,7 @@ public class FaultFilterTest { .setPercentage(FractionalPercent.newBuilder() .setNumerator(20).setDenominator(DenominatorType.HUNDRED)) .setHeaderAbort(HeaderAbort.getDefaultInstance()).build(); - FaultConfig.FaultAbort faultAbort = FaultFilter.parseFaultAbort(proto).config; + FaultConfig.FaultAbort faultAbort = FaultFilter.Provider.parseFaultAbort(proto).config; assertThat(faultAbort.headerAbort()).isTrue(); assertThat(faultAbort.percent().numerator()).isEqualTo(20); assertThat(faultAbort.percent().denominatorType()) @@ -68,7 +75,7 @@ public class FaultFilterTest { .setPercentage(FractionalPercent.newBuilder() .setNumerator(100).setDenominator(DenominatorType.TEN_THOUSAND)) .setHttpStatus(400).build(); - FaultConfig.FaultAbort res = FaultFilter.parseFaultAbort(proto).config; + FaultConfig.FaultAbort res = FaultFilter.Provider.parseFaultAbort(proto).config; assertThat(res.percent().numerator()).isEqualTo(100); assertThat(res.percent().denominatorType()) .isEqualTo(FaultConfig.FractionalPercent.DenominatorType.TEN_THOUSAND); @@ -82,7 +89,7 @@ public class FaultFilterTest { .setPercentage(FractionalPercent.newBuilder() .setNumerator(600).setDenominator(DenominatorType.MILLION)) .setGrpcStatus(Code.DEADLINE_EXCEEDED.value()).build(); - FaultConfig.FaultAbort faultAbort = FaultFilter.parseFaultAbort(proto).config; + FaultConfig.FaultAbort faultAbort = FaultFilter.Provider.parseFaultAbort(proto).config; assertThat(faultAbort.percent().numerator()).isEqualTo(600); assertThat(faultAbort.percent().denominatorType()) .isEqualTo(FaultConfig.FractionalPercent.DenominatorType.MILLION); diff --git a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java index 3ca240ab7c..52efaf9bd7 100644 --- a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java @@ -35,6 +35,7 @@ import io.grpc.Channel; import io.grpc.ClientInterceptor; import io.grpc.MethodDescriptor; import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.GcpAuthenticationFilter.GcpAuthenticationConfig; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -43,6 +44,14 @@ import org.mockito.Mockito; @RunWith(JUnit4.class) public class GcpAuthenticationFilterTest { + private static final GcpAuthenticationFilter.Provider FILTER_PROVIDER = + new GcpAuthenticationFilter.Provider(); + + @Test + public void filterType_clientOnly() { + assertThat(FILTER_PROVIDER.isClientFilter()).isTrue(); + assertThat(FILTER_PROVIDER.isServerFilter()).isFalse(); + } @Test public void testParseFilterConfig_withValidConfig() { @@ -51,13 +60,11 @@ public class GcpAuthenticationFilterTest { .build(); Any anyMessage = Any.pack(config); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter(); - ConfigOrError result = filter.parseFilterConfig(anyMessage); + ConfigOrError result = FILTER_PROVIDER.parseFilterConfig(anyMessage); assertNotNull(result.config); assertNull(result.errorDetail); - assertEquals(20L, - ((GcpAuthenticationFilter.GcpAuthenticationConfig) result.config).getCacheSize()); + assertEquals(20L, result.config.getCacheSize()); } @Test @@ -67,8 +74,7 @@ public class GcpAuthenticationFilterTest { .build(); Any anyMessage = Any.pack(config); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter(); - ConfigOrError result = filter.parseFilterConfig(anyMessage); + ConfigOrError result = FILTER_PROVIDER.parseFilterConfig(anyMessage); assertNull(result.config); assertNotNull(result.errorDetail); @@ -77,9 +83,9 @@ public class GcpAuthenticationFilterTest { @Test public void testParseFilterConfig_withInvalidMessageType() { - GcpAuthenticationFilter filter = new GcpAuthenticationFilter(); Message invalidMessage = Empty.getDefaultInstance(); - ConfigOrError result = filter.parseFilterConfig(invalidMessage); + ConfigOrError result = + FILTER_PROVIDER.parseFilterConfig(invalidMessage); assertNull(result.config); assertThat(result.errorDetail).contains("Invalid config type"); @@ -87,8 +93,7 @@ public class GcpAuthenticationFilterTest { @Test public void testClientInterceptor_createsAndReusesCachedCredentials() { - GcpAuthenticationFilter.GcpAuthenticationConfig config = - new GcpAuthenticationFilter.GcpAuthenticationConfig(10); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); GcpAuthenticationFilter filter = new GcpAuthenticationFilter(); // Create interceptor diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java index 314b209448..610d147ccf 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java @@ -110,7 +110,6 @@ import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.envoyproxy.envoy.type.v3.FractionalPercent; import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; import io.envoyproxy.envoy.type.v3.Int64Range; -import io.grpc.ClientInterceptor; import io.grpc.EquivalentAddressGroup; import io.grpc.InsecureChannelCredentials; import io.grpc.LoadBalancerRegistry; @@ -150,9 +149,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import javax.annotation.Nullable; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -165,6 +162,10 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class GrpcXdsClientImplDataTest { + private static final FaultFilter.Provider FAULT_FILTER_PROVIDER = new FaultFilter.Provider(); + private static final RbacFilter.Provider RBAC_FILTER_PROVIDER = new RbacFilter.Provider(); + private static final RouterFilter.Provider ROUTER_FILTER_PROVIDER = new RouterFilter.Provider(); + private static final ServerInfo LRS_SERVER_INFO = ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); private static final String GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE = @@ -1243,36 +1244,39 @@ public class GrpcXdsClientImplDataTest { } } - private static class TestFilter implements io.grpc.xds.Filter, - io.grpc.xds.Filter.ClientInterceptorBuilder { - @Override - public String[] typeUrls() { - return new String[]{"test-url"}; - } + private static class TestFilter implements io.grpc.xds.Filter { - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); - } + static final class Provider implements io.grpc.xds.Filter.Provider { + @Override + public String[] typeUrls() { + return new String[]{"test-url"}; + } - @Override - public ConfigOrError parseFilterConfigOverride( - Message rawProtoMessage) { - return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); - } + @Override + public boolean isClientFilter() { + return true; + } - @Nullable - @Override - public ClientInterceptor buildClientInterceptor(FilterConfig config, - @Nullable FilterConfig overrideConfig, - ScheduledExecutorService scheduler) { - return null; + @Override + public TestFilter newInstance() { + return new TestFilter(); + } + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); + } + + @Override + public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { + return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); + } } } @Test public void parseHttpFilter_typedStructMigration() { - filterRegistry.register(new TestFilter()); + filterRegistry.register(new TestFilter.Provider()); Struct rawStruct = Struct.newBuilder() .putFields("name", Value.newBuilder().setStringValue("default").build()) .build(); @@ -1301,7 +1305,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseOverrideHttpFilter_typedStructMigration() { - filterRegistry.register(new TestFilter()); + filterRegistry.register(new TestFilter.Provider()); Struct rawStruct0 = Struct.newBuilder() .putFields("name", Value.newBuilder().setStringValue("default0").build()) .build(); @@ -1342,7 +1346,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseHttpFilter_routerFilterForClient() { - filterRegistry.register(RouterFilter.INSTANCE); + filterRegistry.register(ROUTER_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1356,7 +1360,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseHttpFilter_routerFilterForServer() { - filterRegistry.register(RouterFilter.INSTANCE); + filterRegistry.register(ROUTER_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1370,7 +1374,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseHttpFilter_faultConfigForClient() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1397,7 +1401,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseHttpFilter_faultConfigUnsupportedForServer() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1426,7 +1430,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseHttpFilter_rbacConfigForServer() { - filterRegistry.register(RbacFilter.INSTANCE); + filterRegistry.register(RBAC_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1453,7 +1457,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseHttpFilter_rbacConfigUnsupportedForClient() { - filterRegistry.register(RbacFilter.INSTANCE); + filterRegistry.register(RBAC_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1482,7 +1486,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseOverrideRbacFilterConfig() { - filterRegistry.register(RbacFilter.INSTANCE); + filterRegistry.register(RBAC_FILTER_PROVIDER); RBACPerRoute rbacPerRoute = RBACPerRoute.newBuilder() .setRbac( @@ -1508,7 +1512,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseOverrideFilterConfigs_unsupportedButOptional() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HTTPFault httpFault = HTTPFault.newBuilder() .setDelay(FaultDelay.newBuilder().setFixedDelay(Durations.fromNanos(3000))) .build(); @@ -1528,7 +1532,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseOverrideFilterConfigs_unsupportedAndRequired() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HTTPFault httpFault = HTTPFault.newBuilder() .setDelay(FaultDelay.newBuilder().setFixedDelay(Durations.fromNanos(3000))) .build(); @@ -1620,7 +1624,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseHttpConnectionManager_lastNotTerminal() throws ResourceInvalidException { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HttpConnectionManager hcm = HttpConnectionManager.newBuilder() .addHttpFilters( @@ -1638,7 +1642,7 @@ public class GrpcXdsClientImplDataTest { @Test public void parseHttpConnectionManager_terminalNotLast() throws ResourceInvalidException { - filterRegistry.register(RouterFilter.INSTANCE); + filterRegistry.register(ROUTER_FILTER_PROVIDER); HttpConnectionManager hcm = HttpConnectionManager.newBuilder() .addHttpFilters( diff --git a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java index 013b21e3f4..7f195693d8 100644 --- a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java @@ -78,6 +78,13 @@ public class RbacFilterTest { private static final String PATH = "auth"; private static final StringMatcher STRING_MATCHER = StringMatcher.newBuilder().setExact("/" + PATH).setIgnoreCase(true).build(); + private static final RbacFilter.Provider FILTER_PROVIDER = new RbacFilter.Provider(); + + @Test + public void filterType_serverOnly() { + assertThat(FILTER_PROVIDER.isClientFilter()).isFalse(); + assertThat(FILTER_PROVIDER.isServerFilter()).isTrue(); + } @Test @SuppressWarnings({"unchecked", "deprecation"}) @@ -252,7 +259,7 @@ public class RbacFilterTest { OrMatcher.create(AlwaysTrueMatcher.INSTANCE)); AuthConfig authconfig = AuthConfig.create(Collections.singletonList(policyMatcher), GrpcAuthorizationEngine.Action.ALLOW); - new RbacFilter().buildServerInterceptor(RbacConfig.create(authconfig), null) + FILTER_PROVIDER.newInstance().buildServerInterceptor(RbacConfig.create(authconfig), null) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler, never()).startCall(eq(mockServerCall), any(Metadata.class)); ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); @@ -264,7 +271,7 @@ public class RbacFilterTest { authconfig = AuthConfig.create(Collections.singletonList(policyMatcher), GrpcAuthorizationEngine.Action.DENY); - new RbacFilter().buildServerInterceptor(RbacConfig.create(authconfig), null) + FILTER_PROVIDER.newInstance().buildServerInterceptor(RbacConfig.create(authconfig), null) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler).startCall(eq(mockServerCall), any(Metadata.class)); } @@ -290,7 +297,7 @@ public class RbacFilterTest { .putPolicies("policy-name", Policy.newBuilder().setCondition(Expr.newBuilder().build()).build()) .build()).build(); - result = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + result = FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto)); assertThat(result.errorDetail).isNotNull(); } @@ -312,10 +319,10 @@ public class RbacFilterTest { RbacConfig original = RbacConfig.create(authconfig); RBACPerRoute rbacPerRoute = RBACPerRoute.newBuilder().build(); - RbacConfig override = - new RbacFilter().parseFilterConfigOverride(Any.pack(rbacPerRoute)).config; + RbacConfig override = FILTER_PROVIDER.parseFilterConfigOverride(Any.pack(rbacPerRoute)).config; assertThat(override).isEqualTo(RbacConfig.create(null)); - ServerInterceptor interceptor = new RbacFilter().buildServerInterceptor(original, override); + ServerInterceptor interceptor = + FILTER_PROVIDER.newInstance().buildServerInterceptor(original, override); assertThat(interceptor).isNull(); policyMatcher = PolicyMatcher.create("policy-matcher-override", @@ -325,7 +332,7 @@ public class RbacFilterTest { GrpcAuthorizationEngine.Action.ALLOW); override = RbacConfig.create(authconfig); - new RbacFilter().buildServerInterceptor(original, override) + FILTER_PROVIDER.newInstance().buildServerInterceptor(original, override) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler).startCall(eq(mockServerCall), any(Metadata.class)); verify(mockServerCall).getAttributes(); @@ -337,22 +344,22 @@ public class RbacFilterTest { Message rawProto = io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC.newBuilder() .setRules(RBAC.newBuilder().setAction(Action.LOG) .putPolicies("policy-name", Policy.newBuilder().build()).build()).build(); - ConfigOrError result = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError result = FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto)); assertThat(result.config).isEqualTo(RbacConfig.create(null)); } @Test public void testOrderIndependenceOfPolicies() { Message rawProto = buildComplexRbac(ImmutableList.of(1, 2, 3, 4, 5, 6), true); - ConfigOrError ascFirst = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError ascFirst = FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto)); rawProto = buildComplexRbac(ImmutableList.of(1, 2, 3, 4, 5, 6), false); - ConfigOrError ascLast = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError ascLast = FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto)); assertThat(ascFirst.config).isEqualTo(ascLast.config); rawProto = buildComplexRbac(ImmutableList.of(6, 5, 4, 3, 2, 1), true); - ConfigOrError decFirst = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError decFirst = FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto)); assertThat(ascFirst.config).isEqualTo(decFirst.config); } @@ -374,14 +381,14 @@ public class RbacFilterTest { private ConfigOrError parse(List permissionList, List principalList) { - return RbacFilter.parseRbacConfig(buildRbac(permissionList, principalList)); + return RbacFilter.Provider.parseRbacConfig(buildRbac(permissionList, principalList)); } private ConfigOrError parseRaw(List permissionList, List principalList) { Message rawProto = buildRbac(permissionList, principalList); Any proto = Any.pack(rawProto); - return new RbacFilter().parseFilterConfig(proto); + return FILTER_PROVIDER.parseFilterConfig(proto); } private io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC buildRbac( @@ -449,6 +456,6 @@ public class RbacFilterTest { RBACPerRoute rbacPerRoute = RBACPerRoute.newBuilder().setRbac( buildRbac(permissionList, principalList)).build(); Any proto = Any.pack(rbacPerRoute); - return new RbacFilter().parseFilterConfigOverride(proto); + return FILTER_PROVIDER.parseFilterConfigOverride(proto); } } diff --git a/xds/src/test/java/io/grpc/xds/RouterFilterTest.java b/xds/src/test/java/io/grpc/xds/RouterFilterTest.java new file mode 100644 index 0000000000..30fd8a6dc3 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/RouterFilterTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link RouterFilter}. */ +@RunWith(JUnit4.class) +public class RouterFilterTest { + private static final RouterFilter.Provider FILTER_PROVIDER = new RouterFilter.Provider(); + + @Test + public void filterType_clientAndServer() { + assertThat(FILTER_PROVIDER.isClientFilter()).isTrue(); + assertThat(FILTER_PROVIDER.isServerFilter()).isTrue(); + } + +} diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index d895cecdb1..f7309051f9 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -22,10 +22,12 @@ import static io.grpc.xds.FaultFilter.HEADER_ABORT_HTTP_STATUS_KEY; import static io.grpc.xds.FaultFilter.HEADER_ABORT_PERCENTAGE_KEY; import static io.grpc.xds.FaultFilter.HEADER_DELAY_KEY; import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; @@ -130,6 +132,9 @@ public class XdsNameResolverTest { private static final String RDS_RESOURCE_NAME = "route-configuration.googleapis.com"; private static final String FAULT_FILTER_INSTANCE_NAME = "envoy.fault"; private static final String ROUTER_FILTER_INSTANCE_NAME = "envoy.router"; + private static final FaultFilter.Provider FAULT_FILTER_PROVIDER = new FaultFilter.Provider(); + private static final RouterFilter.Provider ROUTER_FILTER_PROVIDER = new RouterFilter.Provider(); + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); private final SynchronizationContext syncContext = new SynchronizationContext( @@ -184,9 +189,19 @@ public class XdsNameResolverTest { originalEnableTimeout = XdsNameResolver.enableTimeout; XdsNameResolver.enableTimeout = true; + + // Replace FaultFilter.Provider with the one returning FaultFilter injected with mockRandom. + Filter.Provider faultFilterProvider = + mock(Filter.Provider.class, delegatesTo(FAULT_FILTER_PROVIDER)); + // Lenient: suppress [MockitoHint] Unused warning, only used in resolved_fault* tests. + lenient() + .doReturn(new FaultFilter(mockRandom, new AtomicLong())) + .when(faultFilterProvider).newInstance(); + FilterRegistry filterRegistry = FilterRegistry.newRegistry().register( - new FaultFilter(mockRandom, new AtomicLong()), - RouterFilter.INSTANCE); + ROUTER_FILTER_PROVIDER, + faultFilterProvider); + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, filterRegistry, null, metricRecorder); diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index 66ac1475d8..41f005ba58 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -31,7 +31,6 @@ import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.withSettings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -53,7 +52,6 @@ import io.grpc.testing.TestMethodDescriptors; import io.grpc.xds.EnvoyServerProtoData.FilterChain; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; -import io.grpc.xds.Filter.ServerInterceptorBuilder; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.VirtualHost.Route.RouteMatch; @@ -957,9 +955,11 @@ public class XdsServerWrapperTest { new AtomicReference<>(routingConfig)).build()); when(serverCall.getAuthority()).thenReturn("not-match.google.com"); - Filter filter = mock(Filter.class); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + filterRegistry.register(filterProvider); + ServerCallHandler next = mock(ServerCallHandler.class); interceptor.interceptCall(serverCall, new Metadata(), next); verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); @@ -998,9 +998,11 @@ public class XdsServerWrapperTest { when(serverCall.getMethodDescriptor()).thenReturn(createMethod("NotMatchMethod")); when(serverCall.getAuthority()).thenReturn("foo.google.com"); - Filter filter = mock(Filter.class); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + filterRegistry.register(filterProvider); + ServerCallHandler next = mock(ServerCallHandler.class); interceptor.interceptCall(serverCall, new Metadata(), next); verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); @@ -1044,9 +1046,11 @@ public class XdsServerWrapperTest { when(serverCall.getMethodDescriptor()).thenReturn(createMethod("FooService/barMethod")); when(serverCall.getAuthority()).thenReturn("foo.google.com"); - Filter filter = mock(Filter.class); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + filterRegistry.register(filterProvider); + ServerCallHandler next = mock(ServerCallHandler.class); interceptor.interceptCall(serverCall, new Metadata(), next); verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); @@ -1113,10 +1117,14 @@ public class XdsServerWrapperTest { RouteMatch.create( PathMatcher.fromPath("/FooService/barMethod", true), Collections.emptyList(), null); - Filter filter = mock(Filter.class, withSettings() - .extraInterfaces(ServerInterceptorBuilder.class)); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + + Filter filter = mock(Filter.class); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + when(filterProvider.newInstance()).thenReturn(filter); + filterRegistry.register(filterProvider); + FilterConfig f0 = mock(FilterConfig.class); FilterConfig f0Override = mock(FilterConfig.class); when(f0.typeUrl()).thenReturn("filter-type-url"); @@ -1137,10 +1145,8 @@ public class XdsServerWrapperTest { return next.startCall(call, headers); } }; - when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, null)) - .thenReturn(interceptor0); - when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, f0Override)) - .thenReturn(interceptor1); + when(filter.buildServerInterceptor(f0, null)).thenReturn(interceptor0); + when(filter.buildServerInterceptor(f0, f0Override)).thenReturn(interceptor1); Route route = Route.forAction(routeMatch, null, ImmutableMap.of()); VirtualHost virtualHost = VirtualHost.create( @@ -1185,10 +1191,13 @@ public class XdsServerWrapperTest { }); xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - Filter filter = mock(Filter.class, withSettings() - .extraInterfaces(ServerInterceptorBuilder.class)); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter filter = mock(Filter.class); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + when(filterProvider.newInstance()).thenReturn(filter); + filterRegistry.register(filterProvider); + FilterConfig f0 = mock(FilterConfig.class); FilterConfig f0Override = mock(FilterConfig.class); when(f0.typeUrl()).thenReturn("filter-type-url"); @@ -1209,10 +1218,8 @@ public class XdsServerWrapperTest { return next.startCall(call, headers); } }; - when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, null)) - .thenReturn(interceptor0); - when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, f0Override)) - .thenReturn(interceptor1); + when(filter.buildServerInterceptor(f0, null)).thenReturn(interceptor0); + when(filter.buildServerInterceptor(f0, f0Override)).thenReturn(interceptor1); RouteMatch routeMatch = RouteMatch.create( PathMatcher.fromPath("/FooService/barMethod", true),