diff --git a/xds/src/main/java/io/grpc/xds/RouteMatch.java b/xds/src/main/java/io/grpc/xds/RouteMatch.java index 63af1629f8..75cbaa9e56 100644 --- a/xds/src/main/java/io/grpc/xds/RouteMatch.java +++ b/xds/src/main/java/io/grpc/xds/RouteMatch.java @@ -20,9 +20,12 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.MoreObjects.ToStringHelper; import com.google.re2j.Pattern; +import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.Set; import javax.annotation.Nullable; /** @@ -48,6 +51,27 @@ final class RouteMatch { Collections.emptyList(), null); } + /** + * Returns {@code true} if a request with the given path and headers passes all the rules + * specified by this RouteMatch. + * + *

The request's headers are given as a key-values mapping, where multiple values can + * be mapped to the same key. + * + *

Match is not deterministic if a runtime fraction match rule presents in this RouteMatch. + */ + boolean matches(String path, Map> headers) { + if (!pathMatch.matches(path)) { + return false; + } + for (HeaderMatcher headerMatcher : headerMatchers) { + if (!headerMatcher.matchesValue(headers.get(headerMatcher.getName()))) { + return false; + } + } + return fractionMatch == null || fractionMatch.matches(); + } + PathMatcher getPathMatch() { return pathMatch; } @@ -105,6 +129,15 @@ final class RouteMatch { this.regEx = regEx; } + private boolean matches(String fullMethodName) { + if (path != null) { + return path.equals(fullMethodName); + } else if (prefix != null) { + return fullMethodName.startsWith(prefix); + } + return regEx.matches(fullMethodName); + } + @Nullable String getPath() { return path; @@ -196,6 +229,39 @@ final class RouteMatch { this.isInvertedMatch = isInvertedMatch; } + private boolean matchesValue(@Nullable Set values) { + if (presentMatch != null) { + return (values == null) == presentMatch.equals(isInvertedMatch); + } + if (values == null) { + return false; + } + boolean baseMatch = false; + for (String value : values) { + if (exactMatch != null) { + baseMatch = exactMatch.equals(value); + } else if (safeRegExMatch != null) { + baseMatch = safeRegExMatch.matches(value); + } else if (rangeMatch != null) { + long numValue; + try { + numValue = Long.parseLong(value); + } catch (NumberFormatException ignored) { + continue; + } + baseMatch = rangeMatch.contains(numValue); + } else if (prefixMatch != null) { + baseMatch = value.startsWith(prefixMatch); + } else { + baseMatch = value.endsWith(suffixMatch); + } + if (baseMatch) { + break; + } + } + return baseMatch != isInvertedMatch; + } + String getName() { return name; } @@ -290,6 +356,10 @@ final class RouteMatch { this.end = end; } + boolean contains(long value) { + return value >= start && value < end; + } + long getStart() { return start; } @@ -329,10 +399,21 @@ final class RouteMatch { static final class FractionMatcher { private final int numerator; private final int denominator; + private final ThreadSafeRandom rand; FractionMatcher(int numerator, int denominator) { + this(numerator, denominator, ThreadSafeRandomImpl.instance); + } + + @VisibleForTesting + FractionMatcher(int numerator, int denominator, ThreadSafeRandom rand) { this.numerator = numerator; this.denominator = denominator; + this.rand = rand; + } + + private boolean matches() { + return rand.nextInt(denominator) < numerator; } int getNumerator() { diff --git a/xds/src/main/java/io/grpc/xds/XdsRoutingLoadBalancer.java b/xds/src/main/java/io/grpc/xds/XdsRoutingLoadBalancer.java index 2038179c4a..428164ec41 100644 --- a/xds/src/main/java/io/grpc/xds/XdsRoutingLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/XdsRoutingLoadBalancer.java @@ -23,12 +23,17 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.xds.XdsSubchannelPickers.BUFFER_PICKER; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; import io.grpc.ConnectivityState; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; +import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.GracefulSwitchLoadBalancer; @@ -37,87 +42,89 @@ import io.grpc.xds.XdsRoutingLoadBalancerProvider.Route; import io.grpc.xds.XdsRoutingLoadBalancerProvider.XdsRoutingConfig; import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; /** Load balancer for xds_routing policy. */ final class XdsRoutingLoadBalancer extends LoadBalancer { + @VisibleForTesting + static final int DELAYED_ACTION_DELETION_TIME_MINUTES = 15; + private final XdsLogger logger; private final Helper helper; - private final Map routeBalancers = new HashMap<>(); - private final Map routeHelpers = new HashMap<>(); + private final SynchronizationContext syncContext; + private final ScheduledExecutorService timeService; + private final Map childLbStates = new HashMap<>(); // keyed by action names - private Map actions = ImmutableMap.of(); private List routes = ImmutableList.of(); XdsRoutingLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); + this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); + this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); logger = XdsLogger.withLogId( InternalLogId.allocate("xds-routing-lb", helper.getAuthority())); logger.log(XdsLogLevel.INFO, "Created"); } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public void handleResolvedAddresses(final ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); XdsRoutingConfig xdsRoutingConfig = (XdsRoutingConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - checkNotNull(xdsRoutingConfig, "Missing xds_routing lb config"); - Map newActions = xdsRoutingConfig.actions; - for (String actionName : newActions.keySet()) { - PolicySelection action = newActions.get(actionName); - if (!actions.containsKey(actionName)) { - RouteHelper routeHelper = new RouteHelper(); - GracefulSwitchLoadBalancer routeBalancer = new GracefulSwitchLoadBalancer(routeHelper); - routeBalancer.switchTo(action.getProvider()); - routeHelpers.put(actionName, routeHelper); - routeBalancers.put(actionName, routeBalancer); - } else if (!action.getProvider().equals(actions.get(actionName).getProvider())) { - routeBalancers.get(actionName).switchTo(action.getProvider()); + for (final String actionName : newActions.keySet()) { + final PolicySelection action = newActions.get(actionName); + if (!childLbStates.containsKey(actionName)) { + childLbStates.put(actionName, new ChildLbState(actionName, action.getProvider())); + } else { + childLbStates.get(actionName).reactivate(action.getProvider()); } + syncContext.execute(new Runnable() { + @Override + public void run() { + childLbStates.get(actionName).lb + .handleResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(action.getConfig()) + .build()); + } + }); } - this.routes = xdsRoutingConfig.routes; - this.actions = newActions; - - for (String actionName : actions.keySet()) { - routeBalancers.get(actionName).handleResolvedAddresses( - resolvedAddresses.toBuilder() - .setLoadBalancingPolicyConfig(actions.get(actionName).getConfig()) - .build()); + Set diff = Sets.difference(childLbStates.keySet(), newActions.keySet()); + for (String actionName : diff) { + childLbStates.get(actionName).deactivate(); } - - // Cleanup removed actions. - // TODO(zdapeng): cache removed actions for 15 minutes. - for (String actionName : routeBalancers.keySet()) { - if (!actions.containsKey(actionName)) { - routeBalancers.get(actionName).shutdown(); - } - } - routeBalancers.keySet().retainAll(actions.keySet()); - routeHelpers.keySet().retainAll(actions.keySet()); } @Override public void handleNameResolutionError(Status error) { logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error); - if (routeBalancers.isEmpty()) { - helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(error)); + boolean gotoTransientFailure = true; + for (ChildLbState state : childLbStates.values()) { + if (!state.deactivated) { + gotoTransientFailure = false; + state.lb.handleNameResolutionError(error); + } } - for (LoadBalancer routeBalancer : routeBalancers.values()) { - routeBalancer.handleNameResolutionError(error); + if (gotoTransientFailure) { + helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(error)); } } @Override public void shutdown() { logger.log(XdsLogLevel.INFO, "Shutdown"); - for (LoadBalancer routeBalancer : routeBalancers.values()) { - routeBalancer.shutdown(); + for (ChildLbState state : childLbStates.values()) { + state.tearDown(); } } @@ -131,10 +138,9 @@ final class XdsRoutingLoadBalancer extends LoadBalancer { // Use LinkedHashMap to preserve the order of routes. Map routePickers = new LinkedHashMap<>(); for (Route route : routes) { - RouteHelper routeHelper = routeHelpers.get(route.getActionName()); - routePickers.put(route.getRouteMatch(), routeHelper.currentPicker); - ConnectivityState routeState = routeHelper.currentState; - overallState = aggregateState(overallState, routeState); + ChildLbState state = childLbStates.get(route.getActionName()); + routePickers.put(route.getRouteMatch(), state.currentPicker); + overallState = aggregateState(overallState, state.currentState); } if (overallState != null) { SubchannelPicker picker = new RouteMatchingSubchannelPicker(routePickers); @@ -142,8 +148,9 @@ final class XdsRoutingLoadBalancer extends LoadBalancer { } } + @VisibleForTesting @Nullable - private static ConnectivityState aggregateState( + static ConnectivityState aggregateState( @Nullable ConnectivityState overallState, ConnectivityState childState) { if (overallState == null) { return childState; @@ -160,28 +167,94 @@ final class XdsRoutingLoadBalancer extends LoadBalancer { return overallState; } - /** - * The lb helper for a single route balancer. - */ - private final class RouteHelper extends ForwardingLoadBalancerHelper { - ConnectivityState currentState = CONNECTING; - SubchannelPicker currentPicker = BUFFER_PICKER; + private final class ChildLbState { + private final String name; + private final GracefulSwitchLoadBalancer lb; + private LoadBalancerProvider policyProvider; + private ConnectivityState currentState = CONNECTING; + private SubchannelPicker currentPicker = BUFFER_PICKER; + private boolean deactivated; + @Nullable + ScheduledHandle deletionTimer; - @Override - public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { - currentState = newState; - currentPicker = newPicker; - updateOverallBalancingState(); + private ChildLbState(String name, LoadBalancerProvider policyProvider) { + this.name = name; + this.policyProvider = policyProvider; + lb = new GracefulSwitchLoadBalancer(new RouteHelper()); + lb.switchTo(policyProvider); } - @Override - protected Helper delegate() { - return helper; + void deactivate() { + if (deactivated) { + return; + } + + class DeletionTask implements Runnable { + @Override + public void run() { + tearDown(); + childLbStates.remove(name); + } + } + + deletionTimer = + syncContext.schedule( + new DeletionTask(), + DELAYED_ACTION_DELETION_TIME_MINUTES, + TimeUnit.MINUTES, + timeService); + deactivated = true; + logger.log(XdsLogLevel.DEBUG, "Route action {0} deactivated", name); + } + + void reactivate(LoadBalancerProvider policyProvider) { + if (deletionTimer != null && deletionTimer.isPending()) { + deletionTimer.cancel(); + deactivated = false; + logger.log(XdsLogLevel.DEBUG, "Route action {0} reactivated", name); + } + if (!this.policyProvider.getPolicyName().equals(policyProvider.getPolicyName())) { + logger.log( + XdsLogLevel.DEBUG, + "Action {0} switching policy from {1} to {2}", + name, this.policyProvider.getPolicyName(), policyProvider.getPolicyName()); + lb.switchTo(policyProvider); + this.policyProvider = policyProvider; + } + } + + void tearDown() { + deactivated = true; + if (deletionTimer != null && deletionTimer.isPending()) { + deletionTimer.cancel(); + } + lb.shutdown(); + logger.log(XdsLogLevel.DEBUG, "Route action {0} deleted", name); + } + + private final class RouteHelper extends ForwardingLoadBalancerHelper { + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + if (deactivated) { + return; + } + currentState = newState; + currentPicker = newPicker; + updateOverallBalancingState(); + } + + @Override + protected Helper delegate() { + return helper; + } } } - private static final class RouteMatchingSubchannelPicker extends SubchannelPicker { + @VisibleForTesting + static final class RouteMatchingSubchannelPicker extends SubchannelPicker { + @VisibleForTesting final Map routePickers; RouteMatchingSubchannelPicker(Map routePickers) { @@ -190,8 +263,28 @@ final class XdsRoutingLoadBalancer extends LoadBalancer { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - // TODO(chengyuanzhang): to be implemented. - return PickResult.withError(Status.INTERNAL.withDescription("routing picker unimplemented")); + // Index ASCII headers by keys. + Map> asciiHeaders = new HashMap<>(); + Metadata headers = args.getHeaders(); + for (String headerName : headers.keys()) { + if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + continue; + } + Set headerValues = new HashSet<>(); + Metadata.Key key = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER); + for (String value : headers.getAll(key)) { + headerValues.add(value); + } + asciiHeaders.put(headerName, headerValues); + } + for (Map.Entry entry : routePickers.entrySet()) { + RouteMatch routeMatch = entry.getKey(); + if (routeMatch.matches( + "/" + args.getMethodDescriptor().getFullMethodName(), asciiHeaders)) { + return entry.getValue().pickSubchannel(args); + } + } + return PickResult.withError(Status.UNAVAILABLE.withDescription("no matching route found")); } } } diff --git a/xds/src/main/java/io/grpc/xds/XdsRoutingLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/XdsRoutingLoadBalancerProvider.java index 17e8d9b6c8..ee49050bda 100644 --- a/xds/src/main/java/io/grpc/xds/XdsRoutingLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsRoutingLoadBalancerProvider.java @@ -269,7 +269,8 @@ public final class XdsRoutingLoadBalancerProvider extends LoadBalancerProvider { final List routes; final Map actions; - private XdsRoutingConfig(List routes, Map actions) { + @VisibleForTesting + XdsRoutingConfig(List routes, Map actions) { this.routes = ImmutableList.copyOf(routes); this.actions = ImmutableMap.copyOf(actions); } diff --git a/xds/src/test/java/io/grpc/xds/RouteMatchTest.java b/xds/src/test/java/io/grpc/xds/RouteMatchTest.java new file mode 100644 index 0000000000..283d137c1c --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/RouteMatchTest.java @@ -0,0 +1,166 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.re2j.Pattern; +import io.grpc.xds.RouteMatch.FractionMatcher; +import io.grpc.xds.RouteMatch.HeaderMatcher; +import io.grpc.xds.RouteMatch.HeaderMatcher.Range; +import io.grpc.xds.RouteMatch.PathMatcher; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.junit.Before; +import org.junit.Test; + +/** Tests for {@link RouteMatch}. */ +public class RouteMatchTest { + + private final Map> headers = new HashMap<>(); + + @Before + public void setUp() { + headers.put("content-type", Collections.singleton("application/grpc")); + headers.put("grpc-encoding", Collections.singleton("gzip")); + headers.put("user-agent", Collections.singleton("gRPC-Java")); + headers.put("content-length", Collections.singleton("1000")); + headers.put("custom-key", new HashSet<>(Arrays.asList("custom-value1", "custom-value2"))); + } + + @Test + public void routeMatching_pathOnly() { + RouteMatch routeMatch1 = + new RouteMatch( + new PathMatcher("/FooService/barMethod", null, null), + Collections.emptyList(), null); + assertThat(routeMatch1.matches("/FooService/barMethod", headers)).isTrue(); + assertThat(routeMatch1.matches("/FooService/bazMethod", headers)).isFalse(); + + RouteMatch routeMatch2 = + new RouteMatch( + new PathMatcher(null, "/FooService/", null), + Collections.emptyList(), null); + assertThat(routeMatch2.matches("/FooService/barMethod", headers)).isTrue(); + assertThat(routeMatch2.matches("/FooService/bazMethod", headers)).isTrue(); + assertThat(routeMatch2.matches("/BarService/bazMethod", headers)).isFalse(); + + RouteMatch routeMatch3 = + new RouteMatch( + new PathMatcher(null, null, Pattern.compile(".*Foo.*")), + Collections.emptyList(), null); + assertThat(routeMatch3.matches("/FooService/barMethod", headers)).isTrue(); + } + + @Test + public void routeMatching_withHeaders() { + RouteMatch routeMatch1 = new RouteMatch( + new PathMatcher("/FooService/barMethod", null, null), + Arrays.asList( + new HeaderMatcher( + "grpc-encoding", "gzip", null, null, null, null, null, false), + new HeaderMatcher( + "content-type", null, Pattern.compile(".*grpc.*"), null, null, null, + null, false), + new HeaderMatcher( + "content-length", null, null, new Range(100, 10000), null, null, null, false), + new HeaderMatcher("user-agent", null, null, null, true, null, null, false), + new HeaderMatcher("custom-key", null, null, null, null, "custom-", null, false), + new HeaderMatcher("custom-key", null, null, null, null, null, "value2", false)), + null); + assertThat(routeMatch1.matches("/FooService/barMethod", headers)).isTrue(); + + RouteMatch routeMatch2 = new RouteMatch( + new PathMatcher("/FooService/barMethod", null, null), + Collections.singletonList( + new HeaderMatcher( + "content-type", null, Pattern.compile(".*grpc.*"), null, null, null, + null, true)), + null); + assertThat(routeMatch2.matches("/FooService/barMethod", headers)).isFalse(); + + RouteMatch routeMatch3 = new RouteMatch( + new PathMatcher("/FooService/barMethod", null, null), + Collections.singletonList( + new HeaderMatcher( + "user-agent", "gRPC-Go", null, null, null, null, + null, false)), + null); + assertThat(routeMatch3.matches("/FooService/barMethod", headers)).isFalse(); + + RouteMatch routeMatch4 = new RouteMatch( + new PathMatcher("/FooService/barMethod", null, null), + Collections.singletonList( + new HeaderMatcher( + "user-agent", null, null, null, false, null, + null, false)), + null); + assertThat(routeMatch4.matches("/FooService/barMethod", headers)).isFalse(); + + RouteMatch routeMatch5 = new RouteMatch( + new PathMatcher("/FooService/barMethod", null, null), + Collections.singletonList( + new HeaderMatcher( + "user-agent", null, null, null, false, null, + null, true)), + null); + assertThat(routeMatch5.matches("/FooService/barMethod", headers)).isTrue(); + + RouteMatch routeMatch6 = new RouteMatch( + new PathMatcher("/FooService/barMethod", null, null), + Collections.singletonList( + new HeaderMatcher( + "user-agent", null, null, null, true, null, + null, true)), + null); + assertThat(routeMatch6.matches("/FooService/barMethod", headers)).isFalse(); + } + + @Test + public void routeMatching_withRuntimeFraction() { + RouteMatch routeMatch1 = + new RouteMatch( + new PathMatcher("/FooService/barMethod", null, null), + Collections.emptyList(), + new FractionMatcher(100, 1000, new FakeRandom(50))); + assertThat(routeMatch1.matches("/FooService/barMethod", headers)).isTrue(); + + RouteMatch routeMatch2 = + new RouteMatch( + new PathMatcher("/FooService/barMethod", null, null), + Collections.emptyList(), + new FractionMatcher(100, 1000, new FakeRandom(100))); + assertThat(routeMatch2.matches("/FooService/barMethod", headers)).isFalse(); + } + + private static final class FakeRandom implements ThreadSafeRandom { + private final int value; + + FakeRandom(int value) { + this.value = value; + } + + @Override + public int nextInt(int bound) { + return value; + } + } +} \ No newline at end of file diff --git a/xds/src/test/java/io/grpc/xds/XdsRoutingLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/XdsRoutingLoadBalancerTest.java index df67155b02..54646c37ea 100644 --- a/xds/src/test/java/io/grpc/xds/XdsRoutingLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsRoutingLoadBalancerTest.java @@ -16,13 +16,328 @@ package io.grpc.xds; -import org.junit.Ignore; +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import io.grpc.CallOptions; +import io.grpc.ConnectivityState; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancerProvider; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.internal.FakeClock; +import io.grpc.internal.PickSubchannelArgsImpl; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.RouteMatch.HeaderMatcher; +import io.grpc.xds.RouteMatch.PathMatcher; +import io.grpc.xds.XdsRoutingLoadBalancer.RouteMatchingSubchannelPicker; +import io.grpc.xds.XdsRoutingLoadBalancerProvider.Route; +import io.grpc.xds.XdsRoutingLoadBalancerProvider.XdsRoutingConfig; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.junit.Before; +import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; /** Tests for {@link XdsRoutingLoadBalancer}. */ @RunWith(JUnit4.class) -@Ignore public class XdsRoutingLoadBalancerTest { - // TODO(chengyuanzhang) + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + private final FakeClock fakeClock = new FakeClock(); + + @Mock + private LoadBalancer.Helper helper; + + private RouteMatch routeMatch1 = + new RouteMatch( + new PathMatcher("/FooService/barMethod", null, null), + Arrays.asList( + new HeaderMatcher("user-agent", "gRPC-Java", null, null, null, null, null, false), + new HeaderMatcher("grpc-encoding", "gzip", null, null, null, null, null, false)), + null); + private RouteMatch routeMatch2 = + new RouteMatch( + new PathMatcher("/FooService/bazMethod", null, null), + Collections.emptyList(), + null); + private RouteMatch routeMatch3 = + new RouteMatch( + new PathMatcher(null, "/", null), + Collections.emptyList(), + null); + private List childBalancers = new ArrayList<>(); + private LoadBalancer xdsRoutingLoadBalancer; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + when(helper.getSynchronizationContext()).thenReturn(syncContext); + when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService()); + xdsRoutingLoadBalancer = new XdsRoutingLoadBalancer(helper); + } + + @Test + public void typicalWorkflow() { + Object childConfig1 = new Object(); + Object childConfig2 = new Object(); + PolicySelection policyA = + new PolicySelection(new FakeLoadBalancerProvider("policy_a"), null, childConfig1); + PolicySelection policyB = + new PolicySelection(new FakeLoadBalancerProvider("policy_b"), null, childConfig2); + PolicySelection policyC = + new PolicySelection(new FakeLoadBalancerProvider("policy_c"), null , null); + + XdsRoutingConfig config = + new XdsRoutingConfig( + Arrays.asList( + new Route(routeMatch1, "action_a"), + new Route(routeMatch2, "action_b"), + new Route(routeMatch3, "action_a")), + ImmutableMap.of("action_a", policyA, "action_b", policyB)); + xdsRoutingLoadBalancer + .handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setLoadBalancingPolicyConfig(config) + .build()); + + assertThat(childBalancers).hasSize(2); + FakeLoadBalancer childBalancer1 = childBalancers.get(0); + FakeLoadBalancer childBalancer2 = childBalancers.get(1); + assertThat(childBalancer1.name).isEqualTo("policy_a"); + assertThat(childBalancer2.name).isEqualTo("policy_b"); + assertThat(childBalancer1.config).isEqualTo(childConfig1); + assertThat(childBalancer2.config).isEqualTo(childConfig2); + + // Receive an updated routing config. + config = + new XdsRoutingConfig( + Arrays.asList( + new Route(routeMatch1, "action_b"), + new Route(routeMatch2, "action_c"), + new Route(routeMatch3, "action_c")), + ImmutableMap.of("action_b", policyA, "action_c", policyC)); + xdsRoutingLoadBalancer + .handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setLoadBalancingPolicyConfig(config) + .build()); + + assertThat(childBalancer2.shutdown) + .isTrue(); // (immediate) shutdown because "action_b" changes policy (before ready) + assertThat(fakeClock.numPendingTasks()) + .isEqualTo(1); // (delayed) shutdown because "action_a" is removed + assertThat(childBalancer1.shutdown).isFalse(); + assertThat(childBalancers).hasSize(3); + FakeLoadBalancer childBalancer3 = childBalancers.get(1); + FakeLoadBalancer childBalancer4 = childBalancers.get(2); + assertThat(childBalancer3.name).isEqualTo("policy_a"); + assertThat(childBalancer3).isNotSameInstanceAs(childBalancer1); + assertThat(childBalancer4.name).isEqualTo("policy_c"); + + // Simulate subchannel state update from the leaf policy. + Subchannel subchannel1 = mock(Subchannel.class); + Subchannel subchannel2 = mock(Subchannel.class); + Subchannel subchannel3 = mock(Subchannel.class); + childBalancer1.deliverSubchannelState(subchannel1, ConnectivityState.READY); + childBalancer3.deliverSubchannelState(subchannel2, ConnectivityState.CONNECTING); + childBalancer4.deliverSubchannelState(subchannel3, ConnectivityState.READY); + + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(null); + verify(helper).updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); + RouteMatchingSubchannelPicker picker = (RouteMatchingSubchannelPicker) pickerCaptor.getValue(); + assertThat(picker.routePickers).hasSize(3); + assertThat( + picker.routePickers.get(routeMatch1) + .pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel()) + .isSameInstanceAs(subchannel2); // routeMatch1 -> action_b -> policy_a -> subchannel2 + assertThat( + picker.routePickers.get(routeMatch2) + .pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel()) + .isSameInstanceAs(subchannel3); // routeMatch2 -> action_c -> policy_c -> subchannel3 + assertThat( + picker.routePickers.get(routeMatch3) + .pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel()) + .isSameInstanceAs(subchannel3); // routeMatch3 -> action_c -> policy_c -> subchannel3 + + // Error propagation from upstream policies. + Status error = Status.UNAVAILABLE.withDescription("network error"); + xdsRoutingLoadBalancer.handleNameResolutionError(error); + assertThat(childBalancer1.upstreamError).isNull(); + assertThat(childBalancer3.upstreamError).isEqualTo(error); + assertThat(childBalancer4.upstreamError).isEqualTo(error); + fakeClock.forwardTime( + XdsRoutingLoadBalancer.DELAYED_ACTION_DELETION_TIME_MINUTES, TimeUnit.MINUTES); + assertThat(childBalancer1.shutdown).isTrue(); + + xdsRoutingLoadBalancer.shutdown(); + assertThat(childBalancer3.shutdown).isTrue(); + assertThat(childBalancer4.shutdown).isTrue(); + } + + @Test + public void routeMatchingSubchannelPicker_typicalRouting() { + Subchannel subchannel1 = mock(Subchannel.class); + Subchannel subchannel2 = mock(Subchannel.class); + Subchannel subchannel3 = mock(Subchannel.class); + RouteMatchingSubchannelPicker routeMatchingPicker = + new RouteMatchingSubchannelPicker( + ImmutableMap.of( + routeMatch1, pickerOf(subchannel1), + routeMatch2, pickerOf(subchannel2), + routeMatch3, pickerOf(subchannel3))); + + PickSubchannelArgs args1 = + createPickSubchannelArgs( + "FooService", "barMethod", + ImmutableMap.of("user-agent", "gRPC-Java", "grpc-encoding", "gzip")); + assertThat(routeMatchingPicker.pickSubchannel(args1).getSubchannel()) + .isSameInstanceAs(subchannel1); + + PickSubchannelArgs args2 = + createPickSubchannelArgs( + "FooService", "bazMethod", + ImmutableMap.of("user-agent", "gRPC-Java", "custom-key", "custom-value")); + assertThat(routeMatchingPicker.pickSubchannel(args2).getSubchannel()) + .isSameInstanceAs(subchannel2); + + PickSubchannelArgs args3 = + createPickSubchannelArgs( + "FooService", "barMethod", + ImmutableMap.of("user-agent", "gRPC-Java", "custom-key", "custom-value")); + assertThat(routeMatchingPicker.pickSubchannel(args3).getSubchannel()) + .isSameInstanceAs(subchannel3); + + PickSubchannelArgs args4 = + createPickSubchannelArgs( + "BazService", "fooMethod", + Collections.emptyMap()); + assertThat(routeMatchingPicker.pickSubchannel(args4).getSubchannel()) + .isSameInstanceAs(subchannel3); + } + + private static SubchannelPicker pickerOf(final Subchannel subchannel) { + return new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return PickResult.withSubchannel(subchannel); + } + }; + } + + private static PickSubchannelArgs createPickSubchannelArgs( + String service, String method, Map headers) { + MethodDescriptor methodDescriptor = + MethodDescriptor.newBuilder() + .setType(MethodType.UNARY).setFullMethodName(service + "/" + method) + .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) + .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) + .build(); + Metadata metadata = new Metadata(); + for (Map.Entry entry : headers.entrySet()) { + metadata.put( + Metadata.Key.of(entry.getKey(), Metadata.ASCII_STRING_MARSHALLER), entry.getValue()); + } + return new PickSubchannelArgsImpl(methodDescriptor, metadata, CallOptions.DEFAULT); + } + + private final class FakeLoadBalancerProvider extends LoadBalancerProvider { + private final String policyName; + + FakeLoadBalancerProvider(String policyName) { + this.policyName = policyName; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + FakeLoadBalancer balancer = new FakeLoadBalancer(policyName, helper); + childBalancers.add(balancer); + return balancer; + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 0; // doesn't matter + } + + @Override + public String getPolicyName() { + return policyName; + } + } + + private final class FakeLoadBalancer extends LoadBalancer { + private final String name; + private final Helper helper; + private Object config; + private Status upstreamError; + private boolean shutdown; + + FakeLoadBalancer(String name, Helper helper) { + this.name = name; + this.helper = helper; + } + + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + config = resolvedAddresses.getLoadBalancingPolicyConfig(); + } + + @Override + public void handleNameResolutionError(Status error) { + upstreamError = error; + } + + @Override + public void shutdown() { + shutdown = true; + childBalancers.remove(this); + } + + void deliverSubchannelState(final Subchannel subchannel, ConnectivityState state) { + SubchannelPicker picker = new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return PickResult.withSubchannel(subchannel); + } + }; + helper.updateBalancingState(state, picker); + } + } }