xds: implement routing policy with all supported types of matcher (#7130)

Implement xds_routing LB policy with all kinds of matchers (path, header, runtime faction) supported.
This commit is contained in:
Chengyuan Zhang 2020-06-19 23:55:49 +00:00 committed by GitHub
parent ae7a482d9a
commit 43cf77de83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 723 additions and 67 deletions

View File

@ -20,9 +20,12 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.MoreObjects.ToStringHelper;
import com.google.re2j.Pattern;
import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nullable;
/**
@ -48,6 +51,27 @@ final class RouteMatch {
Collections.<HeaderMatcher>emptyList(), null);
}
/**
* Returns {@code true} if a request with the given path and headers passes all the rules
* specified by this RouteMatch.
*
* <p>The request's headers are given as a key-values mapping, where multiple values can
* be mapped to the same key.
*
* <p>Match is not deterministic if a runtime fraction match rule presents in this RouteMatch.
*/
boolean matches(String path, Map<String, Set<String>> headers) {
if (!pathMatch.matches(path)) {
return false;
}
for (HeaderMatcher headerMatcher : headerMatchers) {
if (!headerMatcher.matchesValue(headers.get(headerMatcher.getName()))) {
return false;
}
}
return fractionMatch == null || fractionMatch.matches();
}
PathMatcher getPathMatch() {
return pathMatch;
}
@ -105,6 +129,15 @@ final class RouteMatch {
this.regEx = regEx;
}
private boolean matches(String fullMethodName) {
if (path != null) {
return path.equals(fullMethodName);
} else if (prefix != null) {
return fullMethodName.startsWith(prefix);
}
return regEx.matches(fullMethodName);
}
@Nullable
String getPath() {
return path;
@ -196,6 +229,39 @@ final class RouteMatch {
this.isInvertedMatch = isInvertedMatch;
}
private boolean matchesValue(@Nullable Set<String> values) {
if (presentMatch != null) {
return (values == null) == presentMatch.equals(isInvertedMatch);
}
if (values == null) {
return false;
}
boolean baseMatch = false;
for (String value : values) {
if (exactMatch != null) {
baseMatch = exactMatch.equals(value);
} else if (safeRegExMatch != null) {
baseMatch = safeRegExMatch.matches(value);
} else if (rangeMatch != null) {
long numValue;
try {
numValue = Long.parseLong(value);
} catch (NumberFormatException ignored) {
continue;
}
baseMatch = rangeMatch.contains(numValue);
} else if (prefixMatch != null) {
baseMatch = value.startsWith(prefixMatch);
} else {
baseMatch = value.endsWith(suffixMatch);
}
if (baseMatch) {
break;
}
}
return baseMatch != isInvertedMatch;
}
String getName() {
return name;
}
@ -290,6 +356,10 @@ final class RouteMatch {
this.end = end;
}
boolean contains(long value) {
return value >= start && value < end;
}
long getStart() {
return start;
}
@ -329,10 +399,21 @@ final class RouteMatch {
static final class FractionMatcher {
private final int numerator;
private final int denominator;
private final ThreadSafeRandom rand;
FractionMatcher(int numerator, int denominator) {
this(numerator, denominator, ThreadSafeRandomImpl.instance);
}
@VisibleForTesting
FractionMatcher(int numerator, int denominator, ThreadSafeRandom rand) {
this.numerator = numerator;
this.denominator = denominator;
this.rand = rand;
}
private boolean matches() {
return rand.nextInt(denominator) < numerator;
}
int getNumerator() {

View File

@ -23,12 +23,17 @@ import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.xds.XdsSubchannelPickers.BUFFER_PICKER;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import io.grpc.ConnectivityState;
import io.grpc.InternalLogId;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.SynchronizationContext.ScheduledHandle;
import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import io.grpc.util.ForwardingLoadBalancerHelper;
import io.grpc.util.GracefulSwitchLoadBalancer;
@ -37,87 +42,89 @@ import io.grpc.xds.XdsRoutingLoadBalancerProvider.Route;
import io.grpc.xds.XdsRoutingLoadBalancerProvider.XdsRoutingConfig;
import io.grpc.xds.XdsSubchannelPickers.ErrorPicker;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
/** Load balancer for xds_routing policy. */
final class XdsRoutingLoadBalancer extends LoadBalancer {
@VisibleForTesting
static final int DELAYED_ACTION_DELETION_TIME_MINUTES = 15;
private final XdsLogger logger;
private final Helper helper;
private final Map<String, GracefulSwitchLoadBalancer> routeBalancers = new HashMap<>();
private final Map<String, RouteHelper> routeHelpers = new HashMap<>();
private final SynchronizationContext syncContext;
private final ScheduledExecutorService timeService;
private final Map<String, ChildLbState> childLbStates = new HashMap<>(); // keyed by action names
private Map<String, PolicySelection> actions = ImmutableMap.of();
private List<Route> routes = ImmutableList.of();
XdsRoutingLoadBalancer(Helper helper) {
this.helper = checkNotNull(helper, "helper");
this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
logger = XdsLogger.withLogId(
InternalLogId.allocate("xds-routing-lb", helper.getAuthority()));
logger.log(XdsLogLevel.INFO, "Created");
}
@Override
public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
public void handleResolvedAddresses(final ResolvedAddresses resolvedAddresses) {
logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses);
XdsRoutingConfig xdsRoutingConfig =
(XdsRoutingConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
checkNotNull(xdsRoutingConfig, "Missing xds_routing lb config");
Map<String, PolicySelection> newActions = xdsRoutingConfig.actions;
for (String actionName : newActions.keySet()) {
PolicySelection action = newActions.get(actionName);
if (!actions.containsKey(actionName)) {
RouteHelper routeHelper = new RouteHelper();
GracefulSwitchLoadBalancer routeBalancer = new GracefulSwitchLoadBalancer(routeHelper);
routeBalancer.switchTo(action.getProvider());
routeHelpers.put(actionName, routeHelper);
routeBalancers.put(actionName, routeBalancer);
} else if (!action.getProvider().equals(actions.get(actionName).getProvider())) {
routeBalancers.get(actionName).switchTo(action.getProvider());
for (final String actionName : newActions.keySet()) {
final PolicySelection action = newActions.get(actionName);
if (!childLbStates.containsKey(actionName)) {
childLbStates.put(actionName, new ChildLbState(actionName, action.getProvider()));
} else {
childLbStates.get(actionName).reactivate(action.getProvider());
}
syncContext.execute(new Runnable() {
@Override
public void run() {
childLbStates.get(actionName).lb
.handleResolvedAddresses(
resolvedAddresses.toBuilder()
.setLoadBalancingPolicyConfig(action.getConfig())
.build());
}
});
}
this.routes = xdsRoutingConfig.routes;
this.actions = newActions;
for (String actionName : actions.keySet()) {
routeBalancers.get(actionName).handleResolvedAddresses(
resolvedAddresses.toBuilder()
.setLoadBalancingPolicyConfig(actions.get(actionName).getConfig())
.build());
Set<String> diff = Sets.difference(childLbStates.keySet(), newActions.keySet());
for (String actionName : diff) {
childLbStates.get(actionName).deactivate();
}
// Cleanup removed actions.
// TODO(zdapeng): cache removed actions for 15 minutes.
for (String actionName : routeBalancers.keySet()) {
if (!actions.containsKey(actionName)) {
routeBalancers.get(actionName).shutdown();
}
}
routeBalancers.keySet().retainAll(actions.keySet());
routeHelpers.keySet().retainAll(actions.keySet());
}
@Override
public void handleNameResolutionError(Status error) {
logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error);
if (routeBalancers.isEmpty()) {
helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(error));
boolean gotoTransientFailure = true;
for (ChildLbState state : childLbStates.values()) {
if (!state.deactivated) {
gotoTransientFailure = false;
state.lb.handleNameResolutionError(error);
}
}
for (LoadBalancer routeBalancer : routeBalancers.values()) {
routeBalancer.handleNameResolutionError(error);
if (gotoTransientFailure) {
helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(error));
}
}
@Override
public void shutdown() {
logger.log(XdsLogLevel.INFO, "Shutdown");
for (LoadBalancer routeBalancer : routeBalancers.values()) {
routeBalancer.shutdown();
for (ChildLbState state : childLbStates.values()) {
state.tearDown();
}
}
@ -131,10 +138,9 @@ final class XdsRoutingLoadBalancer extends LoadBalancer {
// Use LinkedHashMap to preserve the order of routes.
Map<RouteMatch, SubchannelPicker> routePickers = new LinkedHashMap<>();
for (Route route : routes) {
RouteHelper routeHelper = routeHelpers.get(route.getActionName());
routePickers.put(route.getRouteMatch(), routeHelper.currentPicker);
ConnectivityState routeState = routeHelper.currentState;
overallState = aggregateState(overallState, routeState);
ChildLbState state = childLbStates.get(route.getActionName());
routePickers.put(route.getRouteMatch(), state.currentPicker);
overallState = aggregateState(overallState, state.currentState);
}
if (overallState != null) {
SubchannelPicker picker = new RouteMatchingSubchannelPicker(routePickers);
@ -142,8 +148,9 @@ final class XdsRoutingLoadBalancer extends LoadBalancer {
}
}
@VisibleForTesting
@Nullable
private static ConnectivityState aggregateState(
static ConnectivityState aggregateState(
@Nullable ConnectivityState overallState, ConnectivityState childState) {
if (overallState == null) {
return childState;
@ -160,28 +167,94 @@ final class XdsRoutingLoadBalancer extends LoadBalancer {
return overallState;
}
/**
* The lb helper for a single route balancer.
*/
private final class RouteHelper extends ForwardingLoadBalancerHelper {
ConnectivityState currentState = CONNECTING;
SubchannelPicker currentPicker = BUFFER_PICKER;
private final class ChildLbState {
private final String name;
private final GracefulSwitchLoadBalancer lb;
private LoadBalancerProvider policyProvider;
private ConnectivityState currentState = CONNECTING;
private SubchannelPicker currentPicker = BUFFER_PICKER;
private boolean deactivated;
@Nullable
ScheduledHandle deletionTimer;
@Override
public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
currentState = newState;
currentPicker = newPicker;
updateOverallBalancingState();
private ChildLbState(String name, LoadBalancerProvider policyProvider) {
this.name = name;
this.policyProvider = policyProvider;
lb = new GracefulSwitchLoadBalancer(new RouteHelper());
lb.switchTo(policyProvider);
}
@Override
protected Helper delegate() {
return helper;
void deactivate() {
if (deactivated) {
return;
}
class DeletionTask implements Runnable {
@Override
public void run() {
tearDown();
childLbStates.remove(name);
}
}
deletionTimer =
syncContext.schedule(
new DeletionTask(),
DELAYED_ACTION_DELETION_TIME_MINUTES,
TimeUnit.MINUTES,
timeService);
deactivated = true;
logger.log(XdsLogLevel.DEBUG, "Route action {0} deactivated", name);
}
void reactivate(LoadBalancerProvider policyProvider) {
if (deletionTimer != null && deletionTimer.isPending()) {
deletionTimer.cancel();
deactivated = false;
logger.log(XdsLogLevel.DEBUG, "Route action {0} reactivated", name);
}
if (!this.policyProvider.getPolicyName().equals(policyProvider.getPolicyName())) {
logger.log(
XdsLogLevel.DEBUG,
"Action {0} switching policy from {1} to {2}",
name, this.policyProvider.getPolicyName(), policyProvider.getPolicyName());
lb.switchTo(policyProvider);
this.policyProvider = policyProvider;
}
}
void tearDown() {
deactivated = true;
if (deletionTimer != null && deletionTimer.isPending()) {
deletionTimer.cancel();
}
lb.shutdown();
logger.log(XdsLogLevel.DEBUG, "Route action {0} deleted", name);
}
private final class RouteHelper extends ForwardingLoadBalancerHelper {
@Override
public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
if (deactivated) {
return;
}
currentState = newState;
currentPicker = newPicker;
updateOverallBalancingState();
}
@Override
protected Helper delegate() {
return helper;
}
}
}
private static final class RouteMatchingSubchannelPicker extends SubchannelPicker {
@VisibleForTesting
static final class RouteMatchingSubchannelPicker extends SubchannelPicker {
@VisibleForTesting
final Map<RouteMatch, SubchannelPicker> routePickers;
RouteMatchingSubchannelPicker(Map<RouteMatch, SubchannelPicker> routePickers) {
@ -190,8 +263,28 @@ final class XdsRoutingLoadBalancer extends LoadBalancer {
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
// TODO(chengyuanzhang): to be implemented.
return PickResult.withError(Status.INTERNAL.withDescription("routing picker unimplemented"));
// Index ASCII headers by keys.
Map<String, Set<String>> asciiHeaders = new HashMap<>();
Metadata headers = args.getHeaders();
for (String headerName : headers.keys()) {
if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
continue;
}
Set<String> headerValues = new HashSet<>();
Metadata.Key<String> key = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER);
for (String value : headers.getAll(key)) {
headerValues.add(value);
}
asciiHeaders.put(headerName, headerValues);
}
for (Map.Entry<RouteMatch, SubchannelPicker> entry : routePickers.entrySet()) {
RouteMatch routeMatch = entry.getKey();
if (routeMatch.matches(
"/" + args.getMethodDescriptor().getFullMethodName(), asciiHeaders)) {
return entry.getValue().pickSubchannel(args);
}
}
return PickResult.withError(Status.UNAVAILABLE.withDescription("no matching route found"));
}
}
}

View File

@ -269,7 +269,8 @@ public final class XdsRoutingLoadBalancerProvider extends LoadBalancerProvider {
final List<Route> routes;
final Map<String, PolicySelection> actions;
private XdsRoutingConfig(List<Route> routes, Map<String, PolicySelection> actions) {
@VisibleForTesting
XdsRoutingConfig(List<Route> routes, Map<String, PolicySelection> actions) {
this.routes = ImmutableList.copyOf(routes);
this.actions = ImmutableMap.copyOf(actions);
}

View File

@ -0,0 +1,166 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import com.google.re2j.Pattern;
import io.grpc.xds.RouteMatch.FractionMatcher;
import io.grpc.xds.RouteMatch.HeaderMatcher;
import io.grpc.xds.RouteMatch.HeaderMatcher.Range;
import io.grpc.xds.RouteMatch.PathMatcher;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.junit.Before;
import org.junit.Test;
/** Tests for {@link RouteMatch}. */
public class RouteMatchTest {
private final Map<String, Set<String>> headers = new HashMap<>();
@Before
public void setUp() {
headers.put("content-type", Collections.singleton("application/grpc"));
headers.put("grpc-encoding", Collections.singleton("gzip"));
headers.put("user-agent", Collections.singleton("gRPC-Java"));
headers.put("content-length", Collections.singleton("1000"));
headers.put("custom-key", new HashSet<>(Arrays.asList("custom-value1", "custom-value2")));
}
@Test
public void routeMatching_pathOnly() {
RouteMatch routeMatch1 =
new RouteMatch(
new PathMatcher("/FooService/barMethod", null, null),
Collections.<HeaderMatcher>emptyList(), null);
assertThat(routeMatch1.matches("/FooService/barMethod", headers)).isTrue();
assertThat(routeMatch1.matches("/FooService/bazMethod", headers)).isFalse();
RouteMatch routeMatch2 =
new RouteMatch(
new PathMatcher(null, "/FooService/", null),
Collections.<HeaderMatcher>emptyList(), null);
assertThat(routeMatch2.matches("/FooService/barMethod", headers)).isTrue();
assertThat(routeMatch2.matches("/FooService/bazMethod", headers)).isTrue();
assertThat(routeMatch2.matches("/BarService/bazMethod", headers)).isFalse();
RouteMatch routeMatch3 =
new RouteMatch(
new PathMatcher(null, null, Pattern.compile(".*Foo.*")),
Collections.<HeaderMatcher>emptyList(), null);
assertThat(routeMatch3.matches("/FooService/barMethod", headers)).isTrue();
}
@Test
public void routeMatching_withHeaders() {
RouteMatch routeMatch1 = new RouteMatch(
new PathMatcher("/FooService/barMethod", null, null),
Arrays.asList(
new HeaderMatcher(
"grpc-encoding", "gzip", null, null, null, null, null, false),
new HeaderMatcher(
"content-type", null, Pattern.compile(".*grpc.*"), null, null, null,
null, false),
new HeaderMatcher(
"content-length", null, null, new Range(100, 10000), null, null, null, false),
new HeaderMatcher("user-agent", null, null, null, true, null, null, false),
new HeaderMatcher("custom-key", null, null, null, null, "custom-", null, false),
new HeaderMatcher("custom-key", null, null, null, null, null, "value2", false)),
null);
assertThat(routeMatch1.matches("/FooService/barMethod", headers)).isTrue();
RouteMatch routeMatch2 = new RouteMatch(
new PathMatcher("/FooService/barMethod", null, null),
Collections.singletonList(
new HeaderMatcher(
"content-type", null, Pattern.compile(".*grpc.*"), null, null, null,
null, true)),
null);
assertThat(routeMatch2.matches("/FooService/barMethod", headers)).isFalse();
RouteMatch routeMatch3 = new RouteMatch(
new PathMatcher("/FooService/barMethod", null, null),
Collections.singletonList(
new HeaderMatcher(
"user-agent", "gRPC-Go", null, null, null, null,
null, false)),
null);
assertThat(routeMatch3.matches("/FooService/barMethod", headers)).isFalse();
RouteMatch routeMatch4 = new RouteMatch(
new PathMatcher("/FooService/barMethod", null, null),
Collections.singletonList(
new HeaderMatcher(
"user-agent", null, null, null, false, null,
null, false)),
null);
assertThat(routeMatch4.matches("/FooService/barMethod", headers)).isFalse();
RouteMatch routeMatch5 = new RouteMatch(
new PathMatcher("/FooService/barMethod", null, null),
Collections.singletonList(
new HeaderMatcher(
"user-agent", null, null, null, false, null,
null, true)),
null);
assertThat(routeMatch5.matches("/FooService/barMethod", headers)).isTrue();
RouteMatch routeMatch6 = new RouteMatch(
new PathMatcher("/FooService/barMethod", null, null),
Collections.singletonList(
new HeaderMatcher(
"user-agent", null, null, null, true, null,
null, true)),
null);
assertThat(routeMatch6.matches("/FooService/barMethod", headers)).isFalse();
}
@Test
public void routeMatching_withRuntimeFraction() {
RouteMatch routeMatch1 =
new RouteMatch(
new PathMatcher("/FooService/barMethod", null, null),
Collections.<HeaderMatcher>emptyList(),
new FractionMatcher(100, 1000, new FakeRandom(50)));
assertThat(routeMatch1.matches("/FooService/barMethod", headers)).isTrue();
RouteMatch routeMatch2 =
new RouteMatch(
new PathMatcher("/FooService/barMethod", null, null),
Collections.<HeaderMatcher>emptyList(),
new FractionMatcher(100, 1000, new FakeRandom(100)));
assertThat(routeMatch2.matches("/FooService/barMethod", headers)).isFalse();
}
private static final class FakeRandom implements ThreadSafeRandom {
private final int value;
FakeRandom(int value) {
this.value = value;
}
@Override
public int nextInt(int bound) {
return value;
}
}
}

View File

@ -16,13 +16,328 @@
package io.grpc.xds;
import org.junit.Ignore;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableMap;
import io.grpc.CallOptions;
import io.grpc.ConnectivityState;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.ResolvedAddresses;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancerProvider;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock;
import io.grpc.internal.PickSubchannelArgsImpl;
import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import io.grpc.testing.TestMethodDescriptors;
import io.grpc.xds.RouteMatch.HeaderMatcher;
import io.grpc.xds.RouteMatch.PathMatcher;
import io.grpc.xds.XdsRoutingLoadBalancer.RouteMatchingSubchannelPicker;
import io.grpc.xds.XdsRoutingLoadBalancerProvider.Route;
import io.grpc.xds.XdsRoutingLoadBalancerProvider.XdsRoutingConfig;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/** Tests for {@link XdsRoutingLoadBalancer}. */
@RunWith(JUnit4.class)
@Ignore
public class XdsRoutingLoadBalancerTest {
// TODO(chengyuanzhang)
private final SynchronizationContext syncContext = new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
throw new AssertionError(e);
}
});
private final FakeClock fakeClock = new FakeClock();
@Mock
private LoadBalancer.Helper helper;
private RouteMatch routeMatch1 =
new RouteMatch(
new PathMatcher("/FooService/barMethod", null, null),
Arrays.asList(
new HeaderMatcher("user-agent", "gRPC-Java", null, null, null, null, null, false),
new HeaderMatcher("grpc-encoding", "gzip", null, null, null, null, null, false)),
null);
private RouteMatch routeMatch2 =
new RouteMatch(
new PathMatcher("/FooService/bazMethod", null, null),
Collections.<HeaderMatcher>emptyList(),
null);
private RouteMatch routeMatch3 =
new RouteMatch(
new PathMatcher(null, "/", null),
Collections.<HeaderMatcher>emptyList(),
null);
private List<FakeLoadBalancer> childBalancers = new ArrayList<>();
private LoadBalancer xdsRoutingLoadBalancer;
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
when(helper.getSynchronizationContext()).thenReturn(syncContext);
when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService());
xdsRoutingLoadBalancer = new XdsRoutingLoadBalancer(helper);
}
@Test
public void typicalWorkflow() {
Object childConfig1 = new Object();
Object childConfig2 = new Object();
PolicySelection policyA =
new PolicySelection(new FakeLoadBalancerProvider("policy_a"), null, childConfig1);
PolicySelection policyB =
new PolicySelection(new FakeLoadBalancerProvider("policy_b"), null, childConfig2);
PolicySelection policyC =
new PolicySelection(new FakeLoadBalancerProvider("policy_c"), null , null);
XdsRoutingConfig config =
new XdsRoutingConfig(
Arrays.asList(
new Route(routeMatch1, "action_a"),
new Route(routeMatch2, "action_b"),
new Route(routeMatch3, "action_a")),
ImmutableMap.of("action_a", policyA, "action_b", policyB));
xdsRoutingLoadBalancer
.handleResolvedAddresses(
ResolvedAddresses.newBuilder()
.setAddresses(Collections.<EquivalentAddressGroup>emptyList())
.setLoadBalancingPolicyConfig(config)
.build());
assertThat(childBalancers).hasSize(2);
FakeLoadBalancer childBalancer1 = childBalancers.get(0);
FakeLoadBalancer childBalancer2 = childBalancers.get(1);
assertThat(childBalancer1.name).isEqualTo("policy_a");
assertThat(childBalancer2.name).isEqualTo("policy_b");
assertThat(childBalancer1.config).isEqualTo(childConfig1);
assertThat(childBalancer2.config).isEqualTo(childConfig2);
// Receive an updated routing config.
config =
new XdsRoutingConfig(
Arrays.asList(
new Route(routeMatch1, "action_b"),
new Route(routeMatch2, "action_c"),
new Route(routeMatch3, "action_c")),
ImmutableMap.of("action_b", policyA, "action_c", policyC));
xdsRoutingLoadBalancer
.handleResolvedAddresses(
ResolvedAddresses.newBuilder()
.setAddresses(Collections.<EquivalentAddressGroup>emptyList())
.setLoadBalancingPolicyConfig(config)
.build());
assertThat(childBalancer2.shutdown)
.isTrue(); // (immediate) shutdown because "action_b" changes policy (before ready)
assertThat(fakeClock.numPendingTasks())
.isEqualTo(1); // (delayed) shutdown because "action_a" is removed
assertThat(childBalancer1.shutdown).isFalse();
assertThat(childBalancers).hasSize(3);
FakeLoadBalancer childBalancer3 = childBalancers.get(1);
FakeLoadBalancer childBalancer4 = childBalancers.get(2);
assertThat(childBalancer3.name).isEqualTo("policy_a");
assertThat(childBalancer3).isNotSameInstanceAs(childBalancer1);
assertThat(childBalancer4.name).isEqualTo("policy_c");
// Simulate subchannel state update from the leaf policy.
Subchannel subchannel1 = mock(Subchannel.class);
Subchannel subchannel2 = mock(Subchannel.class);
Subchannel subchannel3 = mock(Subchannel.class);
childBalancer1.deliverSubchannelState(subchannel1, ConnectivityState.READY);
childBalancer3.deliverSubchannelState(subchannel2, ConnectivityState.CONNECTING);
childBalancer4.deliverSubchannelState(subchannel3, ConnectivityState.READY);
ArgumentCaptor<SubchannelPicker> pickerCaptor = ArgumentCaptor.forClass(null);
verify(helper).updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
RouteMatchingSubchannelPicker picker = (RouteMatchingSubchannelPicker) pickerCaptor.getValue();
assertThat(picker.routePickers).hasSize(3);
assertThat(
picker.routePickers.get(routeMatch1)
.pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel())
.isSameInstanceAs(subchannel2); // routeMatch1 -> action_b -> policy_a -> subchannel2
assertThat(
picker.routePickers.get(routeMatch2)
.pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel())
.isSameInstanceAs(subchannel3); // routeMatch2 -> action_c -> policy_c -> subchannel3
assertThat(
picker.routePickers.get(routeMatch3)
.pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel())
.isSameInstanceAs(subchannel3); // routeMatch3 -> action_c -> policy_c -> subchannel3
// Error propagation from upstream policies.
Status error = Status.UNAVAILABLE.withDescription("network error");
xdsRoutingLoadBalancer.handleNameResolutionError(error);
assertThat(childBalancer1.upstreamError).isNull();
assertThat(childBalancer3.upstreamError).isEqualTo(error);
assertThat(childBalancer4.upstreamError).isEqualTo(error);
fakeClock.forwardTime(
XdsRoutingLoadBalancer.DELAYED_ACTION_DELETION_TIME_MINUTES, TimeUnit.MINUTES);
assertThat(childBalancer1.shutdown).isTrue();
xdsRoutingLoadBalancer.shutdown();
assertThat(childBalancer3.shutdown).isTrue();
assertThat(childBalancer4.shutdown).isTrue();
}
@Test
public void routeMatchingSubchannelPicker_typicalRouting() {
Subchannel subchannel1 = mock(Subchannel.class);
Subchannel subchannel2 = mock(Subchannel.class);
Subchannel subchannel3 = mock(Subchannel.class);
RouteMatchingSubchannelPicker routeMatchingPicker =
new RouteMatchingSubchannelPicker(
ImmutableMap.of(
routeMatch1, pickerOf(subchannel1),
routeMatch2, pickerOf(subchannel2),
routeMatch3, pickerOf(subchannel3)));
PickSubchannelArgs args1 =
createPickSubchannelArgs(
"FooService", "barMethod",
ImmutableMap.of("user-agent", "gRPC-Java", "grpc-encoding", "gzip"));
assertThat(routeMatchingPicker.pickSubchannel(args1).getSubchannel())
.isSameInstanceAs(subchannel1);
PickSubchannelArgs args2 =
createPickSubchannelArgs(
"FooService", "bazMethod",
ImmutableMap.of("user-agent", "gRPC-Java", "custom-key", "custom-value"));
assertThat(routeMatchingPicker.pickSubchannel(args2).getSubchannel())
.isSameInstanceAs(subchannel2);
PickSubchannelArgs args3 =
createPickSubchannelArgs(
"FooService", "barMethod",
ImmutableMap.of("user-agent", "gRPC-Java", "custom-key", "custom-value"));
assertThat(routeMatchingPicker.pickSubchannel(args3).getSubchannel())
.isSameInstanceAs(subchannel3);
PickSubchannelArgs args4 =
createPickSubchannelArgs(
"BazService", "fooMethod",
Collections.<String, String>emptyMap());
assertThat(routeMatchingPicker.pickSubchannel(args4).getSubchannel())
.isSameInstanceAs(subchannel3);
}
private static SubchannelPicker pickerOf(final Subchannel subchannel) {
return new SubchannelPicker() {
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
return PickResult.withSubchannel(subchannel);
}
};
}
private static PickSubchannelArgs createPickSubchannelArgs(
String service, String method, Map<String, String> headers) {
MethodDescriptor<Void, Void> methodDescriptor =
MethodDescriptor.<Void, Void>newBuilder()
.setType(MethodType.UNARY).setFullMethodName(service + "/" + method)
.setRequestMarshaller(TestMethodDescriptors.voidMarshaller())
.setResponseMarshaller(TestMethodDescriptors.voidMarshaller())
.build();
Metadata metadata = new Metadata();
for (Map.Entry<String, String> entry : headers.entrySet()) {
metadata.put(
Metadata.Key.of(entry.getKey(), Metadata.ASCII_STRING_MARSHALLER), entry.getValue());
}
return new PickSubchannelArgsImpl(methodDescriptor, metadata, CallOptions.DEFAULT);
}
private final class FakeLoadBalancerProvider extends LoadBalancerProvider {
private final String policyName;
FakeLoadBalancerProvider(String policyName) {
this.policyName = policyName;
}
@Override
public LoadBalancer newLoadBalancer(Helper helper) {
FakeLoadBalancer balancer = new FakeLoadBalancer(policyName, helper);
childBalancers.add(balancer);
return balancer;
}
@Override
public boolean isAvailable() {
return true;
}
@Override
public int getPriority() {
return 0; // doesn't matter
}
@Override
public String getPolicyName() {
return policyName;
}
}
private final class FakeLoadBalancer extends LoadBalancer {
private final String name;
private final Helper helper;
private Object config;
private Status upstreamError;
private boolean shutdown;
FakeLoadBalancer(String name, Helper helper) {
this.name = name;
this.helper = helper;
}
@Override
public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
config = resolvedAddresses.getLoadBalancingPolicyConfig();
}
@Override
public void handleNameResolutionError(Status error) {
upstreamError = error;
}
@Override
public void shutdown() {
shutdown = true;
childBalancers.remove(this);
}
void deliverSubchannelState(final Subchannel subchannel, ConnectivityState state) {
SubchannelPicker picker = new SubchannelPicker() {
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
return PickResult.withSubchannel(subchannel);
}
};
helper.updateBalancingState(state, picker);
}
}
}