xds:Make Ring Hash LB a petiole policy (#10610)

* Update picker logic per A61 that it no longer pays attention to the first 2 elements, but rather takes the first ring element not in TF and uses that.
---------
Pulled in by rebase:
Eric Anderson  (android: Remove unneeded proguard rule 44723b6)
Terry Wilson (stub: Deprecate StreamObservers b5434e8)
This commit is contained in:
Larry Safran 2023-11-09 13:46:52 -08:00 committed by GitHub
parent 0346b40e4e
commit dfdd50bc79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 916 additions and 885 deletions

View File

@ -102,6 +102,7 @@ final class PickFirstLoadBalancer extends LoadBalancer {
subchannel.shutdown();
subchannel = null;
}
// NB(lukaszx0) Whether we should propagate the error unconditionally is arguable. It's fine
// for time being.
updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError(error)));

View File

@ -181,4 +181,8 @@ public final class GracefulSwitchLoadBalancer extends ForwardingLoadBalancer {
pendingLb.shutdown();
currentLb.shutdown();
}
public String delegateType() {
return delegate().getClass().getSimpleName();
}
}

View File

@ -26,6 +26,7 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.grpc.ConnectivityState;
import io.grpc.EquivalentAddressGroup;
import io.grpc.Internal;
@ -57,11 +58,9 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
private final Map<Object, ChildLbState> childLbStates = new LinkedHashMap<>();
private final Helper helper;
// Set to true if currently in the process of handling resolved addresses.
@VisibleForTesting
protected boolean resolvingAddresses;
protected final PickFirstLoadBalancerProvider pickFirstLbProvider =
new PickFirstLoadBalancerProvider();
protected final LoadBalancerProvider pickFirstLbProvider = new PickFirstLoadBalancerProvider();
protected ConnectivityState currentConnectivityState;
@ -85,6 +84,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
* Generally, the only reason to override this is to expose it to a test of a LB in a different
* package.
*/
protected ImmutableMap<Object, ChildLbState> getImmutableChildMap() {
return ImmutableMap.copyOf(childLbStates);
}
@VisibleForTesting
protected Collection<ChildLbState> getChildLbStates() {
return childLbStates.values();
@ -93,8 +96,7 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
/**
* Generally, the only reason to override this is to expose it to a test of a LB in a
* different package.
*/
*/
protected ChildLbState getChildLbState(Object key) {
if (key == null) {
return null;
@ -125,7 +127,8 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
if (existingChildLbState != null) {
childLbMap.put(endpoint, existingChildLbState);
} else {
childLbMap.put(endpoint, createChildLbState(endpoint, null, getInitialPicker()));
childLbMap.put(endpoint,
createChildLbState(endpoint, null, getInitialPicker(), resolvedAddresses));
}
}
return childLbMap;
@ -135,7 +138,7 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
* Override to create an instance of a subclass.
*/
protected ChildLbState createChildLbState(Object key, Object policyConfig,
SubchannelPicker initialPicker) {
SubchannelPicker initialPicker, ResolvedAddresses resolvedAddresses) {
return new ChildLbState(key, pickFirstLbProvider, policyConfig, initialPicker);
}
@ -146,7 +149,20 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
try {
resolvingAddresses = true;
return acceptResolvedAddressesInternal(resolvedAddresses);
// process resolvedAddresses to update children
AcceptResolvedAddressRetVal acceptRetVal =
acceptResolvedAddressesInternal(resolvedAddresses);
if (!acceptRetVal.status.isOk()) {
return acceptRetVal.status;
}
// Update the picker and our connectivity state
updateOverallBalancingState();
// shutdown removed children
shutdownRemoved(acceptRetVal.removedChildren);
return acceptRetVal.status;
} finally {
resolvingAddresses = false;
}
@ -161,15 +177,18 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
*/
protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses,
Object childConfig) {
Endpoint endpointKey;
if (key instanceof EquivalentAddressGroup) {
key = new Endpoint((EquivalentAddressGroup) key);
endpointKey = new Endpoint((EquivalentAddressGroup) key);
} else {
checkArgument(key instanceof Endpoint, "key is wrong type");
endpointKey = (Endpoint) key;
}
checkArgument(key instanceof Endpoint, "key is wrong type");
// Retrieve the non-stripped version
EquivalentAddressGroup eagToUse = null;
for (EquivalentAddressGroup currEag : resolvedAddresses.getAddresses()) {
if (key.equals(new Endpoint(currEag))) {
if (endpointKey.equals(new Endpoint(currEag))) {
eagToUse = currEag;
break;
}
@ -183,15 +202,21 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
.build();
}
private Status acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) {
/**
* This does the work to update the child map and calculate which children have been removed.
* You must call {@link #updateOverallBalancingState} to update the picker
* and call {@link #shutdownRemoved(List)} to shutdown the endpoints that have been removed.
*/
protected AcceptResolvedAddressRetVal acceptResolvedAddressesInternal(
ResolvedAddresses resolvedAddresses) {
logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses);
Map<Object, ChildLbState> newChildren = createChildLbMap(resolvedAddresses);
if (newChildren.isEmpty()) {
Status unavailableStatus = Status.UNAVAILABLE.withDescription(
"NameResolver returned no usable address. " + resolvedAddresses);
"NameResolver returned no usable address. " + resolvedAddresses);
handleNameResolutionError(unavailableStatus);
return unavailableStatus;
return new AcceptResolvedAddressRetVal(unavailableStatus, null);
}
// Do adds and updates
@ -204,33 +229,44 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
} else {
// Reuse the existing one
ChildLbState existingChildLbState = childLbStates.get(key);
if (existingChildLbState.isDeactivated()) {
if (existingChildLbState.isDeactivated() && reactivateChildOnReuse()) {
existingChildLbState.reactivate(childPolicyProvider);
}
}
LoadBalancer childLb = childLbStates.get(key).lb;
ChildLbState childLbState = childLbStates.get(key);
ResolvedAddresses childAddresses = getChildAddresses(key, resolvedAddresses, childConfig);
childLbStates.get(key).setResolvedAddresses(childAddresses); // update child state
childLb.handleResolvedAddresses(childAddresses); // update child LB
childLbStates.get(key).setResolvedAddresses(childAddresses); // update child
if (!childLbState.deactivated) {
childLbState.lb.handleResolvedAddresses(childAddresses); // update child LB
}
}
List<ChildLbState> removedChildren = new ArrayList<>();
// Do removals
for (Object key : ImmutableList.copyOf(childLbStates.keySet())) {
if (!newChildren.containsKey(key)) {
childLbStates.get(key).deactivate();
ChildLbState childLbState = childLbStates.get(key);
childLbState.deactivate();
removedChildren.add(childLbState);
}
}
// Must update channel picker before return so that new RPCs will not be routed to deleted
// clusters and resolver can remove them in service config.
updateOverallBalancingState();
return Status.OK;
return new AcceptResolvedAddressRetVal(Status.OK, removedChildren);
}
protected void shutdownRemoved(List<ChildLbState> removedChildren) {
// Do shutdowns after updating picker to reduce the chance of failing an RPC by picking a
// subchannel that has been shutdown.
for (ChildLbState childLbState : removedChildren) {
childLbState.shutdown();
}
}
@Override
public void handleNameResolutionError(Status error) {
if (currentConnectivityState != READY) {
updateHelperBalancingState(TRANSIENT_FAILURE, getErrorPicker(error));
helper.updateBalancingState(TRANSIENT_FAILURE, getErrorPicker(error));
}
}
@ -240,12 +276,22 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
/**
* If true, then when a subchannel state changes to idle, the corresponding child will
* have requestConnection called on its LB.
* have requestConnection called on its LB. Also causes the PickFirstLB to be created when
* the child is created or reused.
*/
protected boolean reconnectOnIdle() {
return true;
}
/**
* If true, then when {@link #acceptResolvedAddresses} sees a key that was already part of the
* child map which is deactivated, it will call reactivate on the child.
* If false, it will leave it deactivated.
*/
protected boolean reactivateChildOnReuse() {
return true;
}
@Override
public void shutdown() {
logger.log(Level.INFO, "Shutdown");
@ -265,17 +311,13 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
childPickers.put(childLbState.key, childLbState.currentPicker);
overallState = aggregateState(overallState, childLbState.currentState);
}
if (overallState != null) {
helper.updateBalancingState(overallState, getSubchannelPicker(childPickers));
currentConnectivityState = overallState;
}
}
protected final void updateHelperBalancingState(ConnectivityState newState,
SubchannelPicker newPicker) {
helper.updateBalancingState(newState, newPicker);
}
@Nullable
protected static ConnectivityState aggregateState(
@Nullable ConnectivityState overallState, ConnectivityState childState) {
@ -332,20 +374,31 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
private final Object key;
private ResolvedAddresses resolvedAddresses;
private final Object config;
private final GracefulSwitchLoadBalancer lb;
private LoadBalancerProvider policyProvider;
private ConnectivityState currentState = CONNECTING;
private final LoadBalancerProvider policyProvider;
private ConnectivityState currentState;
private SubchannelPicker currentPicker;
private boolean deactivated;
public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig,
SubchannelPicker initialPicker) {
this(key, policyProvider, childConfig, initialPicker, null, false);
}
public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig,
SubchannelPicker initialPicker, ResolvedAddresses resolvedAddrs, boolean deactivated) {
this.key = key;
this.policyProvider = policyProvider;
lb = new GracefulSwitchLoadBalancer(new ChildLbStateHelper());
lb.switchTo(policyProvider);
currentPicker = initialPicker;
config = childConfig;
this.deactivated = deactivated;
this.currentPicker = initialPicker;
this.config = childConfig;
this.lb = new GracefulSwitchLoadBalancer(new ChildLbStateHelper());
this.currentState = deactivated ? IDLE : CONNECTING;
this.resolvedAddresses = resolvedAddrs;
if (!deactivated) {
lb.switchTo(policyProvider);
}
}
@Override
@ -365,6 +418,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
return config;
}
protected GracefulSwitchLoadBalancer getLb() {
return lb;
}
public LoadBalancerProvider getPolicyProvider() {
return policyProvider;
}
@ -399,34 +456,41 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
deactivated = true;
}
protected void markReactivated() {
deactivated = false;
}
protected void setResolvedAddresses(ResolvedAddresses newAddresses) {
checkNotNull(newAddresses, "Missing address list for child");
resolvedAddresses = newAddresses;
}
/**
* The default implementation. This not only marks the lb policy as not active, it also removes
* this child from the map of children maintained by the petiole policy.
*
* <p>Note that this does not explicitly shutdown this child. That will generally be done by
* acceptResolvedAddresses on the LB, but can also be handled by an override such as is done
* in <a href=" https://github.com/grpc/grpc-java/blob/master/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java">ClusterManagerLoadBalancer</a>.
*
* <p>If you plan to reactivate, you will probably want to override this to not call
* childLbStates.remove() and handle that cleanup another way.
*/
protected void deactivate() {
if (deactivated) {
return;
}
shutdown();
childLbStates.remove(key);
childLbStates.remove(key); // This means it can't be reactivated again
deactivated = true;
logger.log(Level.FINE, "Child balancer {0} deactivated", key);
}
/**
* This base implementation does nothing but reset the flag. If you really want to both
* deactivate and reactivate you should override them both.
*/
protected void reactivate(LoadBalancerProvider policyProvider) {
if (!this.policyProvider.getPolicyName().equals(policyProvider.getPolicyName())) {
Object[] objects = {
key, this.policyProvider.getPolicyName(),policyProvider.getPolicyName()};
logger.log(Level.FINE, "Child balancer {0} switching policy from {1} to {2}", objects);
lb.switchTo(policyProvider);
this.policyProvider = policyProvider;
} else {
logger.log(Level.FINE, "Child balancer {0} reactivated", key);
lb.acceptResolvedAddresses(resolvedAddresses);
}
deactivated = false;
}
@ -443,6 +507,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
* <p>The ChildLbState updates happen during updateBalancingState. Otherwise, it is doing
* simple forwarding.
*/
protected ResolvedAddresses getResolvedAddresses() {
return resolvedAddresses;
}
private final class ChildLbStateHelper extends ForwardingLoadBalancerHelper {
@Override
@ -482,7 +550,7 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
final String[] addrs;
final int hashCode;
Endpoint(EquivalentAddressGroup eag) {
public Endpoint(EquivalentAddressGroup eag) {
checkNotNull(eag, "eag");
addrs = new String[eag.getAddresses().size()];
@ -525,4 +593,14 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
return Arrays.toString(addrs);
}
}
protected static class AcceptResolvedAddressRetVal {
public final Status status;
public final List<ChildLbState> removedChildren;
public AcceptResolvedAddressRetVal(Status status, List<ChildLbState> removedChildren) {
this.status = status;
this.removedChildren = removedChildren;
}
}
}

View File

@ -16,6 +16,7 @@
package io.grpc.util;
import static com.google.common.base.Preconditions.checkArgument;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
@ -25,9 +26,7 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import io.grpc.Attributes;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.Internal;
import io.grpc.LoadBalancer;
@ -48,10 +47,6 @@ import javax.annotation.Nonnull;
*/
@Internal
public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
@VisibleForTesting
static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO =
Attributes.Key.create("state-info");
private final Random random;
protected RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK);
@ -132,7 +127,7 @@ public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
private volatile int index;
public ReadyPicker(List<SubchannelPicker> list, int startIndex) {
Preconditions.checkArgument(!list.isEmpty(), "empty list");
checkArgument(!list.isEmpty(), "empty list");
this.subchannelPickers = list;
this.index = startIndex - 1;
}

View File

@ -74,6 +74,7 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
@ -352,6 +353,19 @@ public class RoundRobinLoadBalancerTest {
verifyNoMoreInteractions(mockHelper);
}
@Test
public void removingAddressShutsdownSubchannel() {
acceptAddresses(servers, affinity);
final Subchannel subchannel2 = subchannels.get(Collections.singletonList(servers.get(2)));
InOrder inOrder = Mockito.inOrder(mockHelper, subchannel2);
// send LB only the first 2 addresses
List<EquivalentAddressGroup> svs2 = Arrays.asList(servers.get(0), servers.get(1));
acceptAddresses(svs2, affinity);
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), any());
inOrder.verify(subchannel2).shutdown();
}
@Test
public void pickerRoundRobin() throws Exception {
Subchannel subchannel = mock(Subchannel.class);

View File

@ -52,6 +52,7 @@ import java.util.Map;
public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper {
private final Map<Subchannel, Subchannel> mockToRealSubChannelMap = new HashMap<>();
protected final Map<Subchannel, Subchannel> realToMockSubChannelMap = new HashMap<>();
private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners =
Maps.newLinkedHashMap();
@ -99,15 +100,20 @@ public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper {
public Subchannel createSubchannel(CreateSubchannelArgs args) {
Subchannel subchannel = getSubchannelMap().get(args.getAddresses());
if (subchannel == null) {
TestSubchannel delegate = new TestSubchannel(args);
TestSubchannel delegate = createRealSubchannel(args);
subchannel = mock(Subchannel.class, delegatesTo(delegate));
getSubchannelMap().put(args.getAddresses(), subchannel);
getMockToRealSubChannelMap().put(subchannel, delegate);
realToMockSubChannelMap.put(delegate, subchannel);
}
return subchannel;
}
protected TestSubchannel createRealSubchannel(CreateSubchannelArgs args) {
return new TestSubchannel(args);
}
@Override
public void refreshNameResolution() {
// no-op
@ -122,7 +128,7 @@ public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper {
return "Test Helper";
}
private class TestSubchannel extends ForwardingSubchannel {
protected class TestSubchannel extends ForwardingSubchannel {
CreateSubchannelArgs args;
Channel channel;

View File

@ -93,6 +93,31 @@ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer {
return newChildPolicies;
}
/**
* This is like the parent except that it doesn't shutdown the removed children since we want that
* to be done by the timer.
*/
@Override
public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
try {
resolvingAddresses = true;
// process resolvedAddresses to update children
AcceptResolvedAddressRetVal acceptRetVal =
acceptResolvedAddressesInternal(resolvedAddresses);
if (!acceptRetVal.status.isOk()) {
return acceptRetVal.status;
}
// Update the picker
updateOverallBalancingState();
return acceptRetVal.status;
} finally {
resolvingAddresses = false;
}
}
@Override
protected SubchannelPicker getSubchannelPicker(Map<Object, SubchannelPicker> childPickers) {
return new SubchannelPicker() {

View File

@ -145,13 +145,13 @@ final class LeastRequestLoadBalancer extends MultiChildLoadBalancer {
@Override
protected ChildLbState createChildLbState(Object key, Object policyConfig,
SubchannelPicker initialPicker) {
SubchannelPicker initialPicker, ResolvedAddresses unused) {
return new LeastRequestLbState(key, pickFirstLbProvider, policyConfig, initialPicker);
}
private void updateBalancingState(ConnectivityState state, LeastRequestPicker picker) {
if (state != currentConnectivityState || !picker.isEquivalentTo(currentPicker)) {
super.updateHelperBalancingState(state, picker);
getHelper().updateBalancingState(state, picker);
currentConnectivityState = state;
currentPicker = picker;
}

View File

@ -27,27 +27,28 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import com.google.common.base.MoreObjects;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Multiset;
import com.google.common.collect.Sets;
import com.google.common.primitives.UnsignedInteger;
import io.grpc.Attributes;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.InternalLogId;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.util.GracefulSwitchLoadBalancer;
import io.grpc.util.MultiChildLoadBalancer;
import io.grpc.xds.XdsLogger.XdsLogLevel;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@ -60,9 +61,7 @@ import javax.annotation.Nullable;
* number of times proportional to its weight. With the ring partitioned appropriately, the
* addition or removal of one host from a set of N hosts will affect only 1/N requests.
*/
final class RingHashLoadBalancer extends LoadBalancer {
private static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO =
Attributes.Key.create("state-info");
final class RingHashLoadBalancer extends MultiChildLoadBalancer {
private static final Status RPC_HASH_NOT_FOUND =
Status.INTERNAL.withDescription("RPC hash not found. Probably a bug because xds resolver"
+ " config selector always generates a hash.");
@ -70,16 +69,10 @@ final class RingHashLoadBalancer extends LoadBalancer {
private final XdsLogger logger;
private final SynchronizationContext syncContext;
private final Map<EquivalentAddressGroup, Subchannel> subchannels = new HashMap<>();
private final Helper helper;
private List<RingEntry> ring;
private ConnectivityState currentState;
private Iterator<Subchannel> connectionAttemptIterator = subchannels.values().iterator();
private final Random random = new Random();
RingHashLoadBalancer(Helper helper) {
this.helper = checkNotNull(helper, "helper");
super(helper);
syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
logger = XdsLogger.withLogId(InternalLogId.allocate("ring_hash_lb", helper.getAuthority()));
logger.log(XdsLogLevel.INFO, "Created");
@ -94,83 +87,159 @@ final class RingHashLoadBalancer extends LoadBalancer {
return addressValidityStatus;
}
Map<EquivalentAddressGroup, EquivalentAddressGroup> latestAddrs = stripAttrs(addrList);
Set<EquivalentAddressGroup> removedAddrs =
Sets.newHashSet(Sets.difference(subchannels.keySet(), latestAddrs.keySet()));
RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
Map<EquivalentAddressGroup, Long> serverWeights = new HashMap<>();
long totalWeight = 0L;
for (EquivalentAddressGroup eag : addrList) {
Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT);
// Support two ways of server weighing: either multiple instances of the same address
// or each address contains a per-address weight attribute. If a weight is not provided,
// each occurrence of the address will be counted a weight value of one.
if (weight == null) {
weight = 1L;
}
totalWeight += weight;
EquivalentAddressGroup addrKey = stripAttrs(eag);
if (serverWeights.containsKey(addrKey)) {
serverWeights.put(addrKey, serverWeights.get(addrKey) + weight);
} else {
serverWeights.put(addrKey, weight);
AcceptResolvedAddressRetVal acceptRetVal;
try {
resolvingAddresses = true;
// Update the child list by creating-adding, updating addresses, and removing
acceptRetVal = super.acceptResolvedAddressesInternal(resolvedAddresses);
if (!acceptRetVal.status.isOk()) {
addressValidityStatus = Status.UNAVAILABLE.withDescription(
"Ring hash lb error: EDS resolution was successful, but was not accepted by base class"
+ " (" + acceptRetVal.status + ")");
handleNameResolutionError(addressValidityStatus);
return addressValidityStatus;
}
Subchannel existingSubchannel = subchannels.get(addrKey);
if (existingSubchannel != null) {
existingSubchannel.updateAddresses(Collections.singletonList(eag));
continue;
// Now do the ringhash specific logic with weights and building the ring
RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
if (config == null) {
throw new IllegalArgumentException("Missing RingHash configuration");
}
Attributes attr = Attributes.newBuilder().set(
STATE_INFO, new Ref<>(ConnectivityStateInfo.forNonError(IDLE))).build();
final Subchannel subchannel = helper.createSubchannel(
CreateSubchannelArgs.newBuilder().setAddresses(eag).setAttributes(attr).build());
subchannel.start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo newState) {
processSubchannelState(subchannel, newState);
Map<EquivalentAddressGroup, Long> serverWeights = new HashMap<>();
long totalWeight = 0L;
for (EquivalentAddressGroup eag : addrList) {
Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT);
// Support two ways of server weighing: either multiple instances of the same address
// or each address contains a per-address weight attribute. If a weight is not provided,
// each occurrence of the address will be counted a weight value of one.
if (weight == null) {
weight = 1L;
}
});
subchannels.put(addrKey, subchannel);
}
long minWeight = Collections.min(serverWeights.values());
double normalizedMinWeight = (double) minWeight / totalWeight;
// Scale up the number of hashes per host such that the least-weighted host gets a whole
// number of hashes on the the ring. Other hosts might not end up with whole numbers, and
// that's fine (the ring-building algorithm can handle this). This preserves the original
// implementation's behavior: when weights aren't provided, all hosts should get an equal
// number of hashes. In the case where this number exceeds the max_ring_size, it's scaled
// back down to fit.
double scale = Math.min(
Math.ceil(normalizedMinWeight * config.minRingSize) / normalizedMinWeight,
(double) config.maxRingSize);
ring = buildRing(serverWeights, totalWeight, scale);
totalWeight += weight;
EquivalentAddressGroup addrKey = stripAttrs(eag);
if (serverWeights.containsKey(addrKey)) {
serverWeights.put(addrKey, serverWeights.get(addrKey) + weight);
} else {
serverWeights.put(addrKey, weight);
}
}
// Calculate scale
long minWeight = Collections.min(serverWeights.values());
double normalizedMinWeight = (double) minWeight / totalWeight;
// Scale up the number of hashes per host such that the least-weighted host gets a whole
// number of hashes on the the ring. Other hosts might not end up with whole numbers, and
// that's fine (the ring-building algorithm can handle this). This preserves the original
// implementation's behavior: when weights aren't provided, all hosts should get an equal
// number of hashes. In the case where this number exceeds the max_ring_size, it's scaled
// back down to fit.
double scale = Math.min(
Math.ceil(normalizedMinWeight * config.minRingSize) / normalizedMinWeight,
(double) config.maxRingSize);
// Shut down subchannels for delisted addresses.
List<Subchannel> removedSubchannels = new ArrayList<>();
for (EquivalentAddressGroup addr : removedAddrs) {
removedSubchannels.add(subchannels.remove(addr));
}
// If we need to proactively start connecting, iterate through all the subchannels, starting
// at a random position.
// Alternatively, we should better start at the same position.
connectionAttemptIterator = subchannels.values().iterator();
int randomAdvance = random.nextInt(subchannels.size());
while (randomAdvance-- > 0) {
connectionAttemptIterator.next();
}
// Build the ring
ring = buildRing(serverWeights, totalWeight, scale);
// Update the picker before shutting down the subchannels, to reduce the chance of race
// between picking a subchannel and shutting it down.
updateBalancingState();
for (Subchannel subchann : removedSubchannels) {
shutdownSubchannel(subchann);
// Must update channel picker before return so that new RPCs will not be routed to deleted
// clusters and resolver can remove them in service config.
updateOverallBalancingState();
shutdownRemoved(acceptRetVal.removedChildren);
} finally {
this.resolvingAddresses = false;
}
return Status.OK;
}
/**
* Updates the overall balancing state by aggregating the connectivity states of all subchannels.
*
* <p>Aggregation rules (in order of dominance):
* <ol>
* <li>If there is at least one subchannel in READY state, overall state is READY</li>
* <li>If there are <em>2 or more</em> subchannels in TRANSIENT_FAILURE, overall state is
* TRANSIENT_FAILURE (to allow timely failover to another policy)</li>
* <li>If there is at least one subchannel in CONNECTING state, overall state is
* CONNECTING</li>
* <li> If there is one subchannel in TRANSIENT_FAILURE state and there is
* more than one subchannel, report CONNECTING </li>
* <li>If there is at least one subchannel in IDLE state, overall state is IDLE</li>
* <li>Otherwise, overall state is TRANSIENT_FAILURE</li>
* </ol>
*/
@Override
protected void updateOverallBalancingState() {
checkState(!getChildLbStates().isEmpty(), "no subchannel has been created");
if (this.currentConnectivityState == SHUTDOWN) {
// Ignore changes that happen after shutdown is called
logger.log(XdsLogLevel.DEBUG, "UpdateOverallBalancingState called after shutdown");
return;
}
// Calculate the current overall state to report
int numIdle = 0;
int numReady = 0;
int numConnecting = 0;
int numTF = 0;
forloop:
for (ChildLbState childLbState : getChildLbStates()) {
ConnectivityState state = childLbState.getCurrentState();
switch (state) {
case READY:
numReady++;
break forloop;
case CONNECTING:
numConnecting++;
break;
case IDLE:
numIdle++;
break;
case TRANSIENT_FAILURE:
numTF++;
break;
default:
// ignore it
}
}
ConnectivityState overallState;
if (numReady > 0) {
overallState = READY;
} else if (numTF >= 2) {
overallState = TRANSIENT_FAILURE;
} else if (numConnecting > 0) {
overallState = CONNECTING;
} else if (numTF == 1 && getChildLbStates().size() > 1) {
overallState = CONNECTING;
} else if (numIdle > 0) {
overallState = IDLE;
} else {
overallState = TRANSIENT_FAILURE;
}
RingHashPicker picker = new RingHashPicker(syncContext, ring, getImmutableChildMap());
getHelper().updateBalancingState(overallState, picker);
this.currentConnectivityState = overallState;
}
@Override
protected boolean reconnectOnIdle() {
return false;
}
@Override
protected boolean reactivateChildOnReuse() {
return false;
}
@Override
protected ChildLbState createChildLbState(Object key, Object policyConfig,
SubchannelPicker initialPicker, ResolvedAddresses resolvedAddresses) {
return new RingHashChildLbState((Endpoint)key,
getChildAddresses(key, resolvedAddresses, null));
}
private Status validateAddrList(List<EquivalentAddressGroup> addrList) {
if (addrList.isEmpty()) {
Status unavailableStatus = Status.UNAVAILABLE.withDescription("Ring hash lb error: EDS "
@ -197,7 +266,7 @@ final class RingHashLoadBalancer extends LoadBalancer {
if (weight < 0) {
Status unavailableStatus = Status.UNAVAILABLE.withDescription(
String.format("Ring hash lb error: EDS resolution was successful, but returned a "
String.format("Ring hash lb error: EDS resolution was successful, but returned a "
+ "negative weight for %s.", stripAttrs(eag)));
handleNameResolutionError(unavailableStatus);
return unavailableStatus;
@ -252,10 +321,10 @@ final class RingHashLoadBalancer extends LoadBalancer {
double currentHashes = 0.0;
double targetHashes = 0.0;
for (Map.Entry<EquivalentAddressGroup, Long> entry : serverWeights.entrySet()) {
EquivalentAddressGroup addrKey = entry.getKey();
Endpoint endpoint = new Endpoint(entry.getKey());
double normalizedWeight = (double) entry.getValue() / totalWeight;
// TODO(chengyuanzhang): is using the list of socket address correct?
StringBuilder sb = new StringBuilder(addrKey.getAddresses().toString());
// Per GRFC A61 use the first address for the hash
StringBuilder sb = new StringBuilder(entry.getKey().getAddresses().get(0).toString());
sb.append('_');
int lengthWithoutCounter = sb.length();
targetHashes += scale * normalizedWeight;
@ -263,7 +332,7 @@ final class RingHashLoadBalancer extends LoadBalancer {
while (currentHashes < targetHashes) {
sb.append(i);
long hash = hashFunc.hashAsciiString(sb.toString());
ring.add(new RingEntry(hash, addrKey));
ring.add(new RingEntry(hash, endpoint));
i++;
currentHashes++;
sb.setLength(lengthWithoutCounter);
@ -273,159 +342,14 @@ final class RingHashLoadBalancer extends LoadBalancer {
return Collections.unmodifiableList(ring);
}
@Override
public void handleNameResolutionError(Status error) {
if (currentState != READY) {
helper.updateBalancingState(
TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error)));
@SuppressWarnings("ReferenceEquality")
public static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
if (eag.getAttributes() == Attributes.EMPTY) {
return eag;
}
}
@Override
public void shutdown() {
logger.log(XdsLogLevel.INFO, "Shutdown");
for (Subchannel subchannel : subchannels.values()) {
shutdownSubchannel(subchannel);
}
subchannels.clear();
}
/**
* Updates the overall balancing state by aggregating the connectivity states of all subchannels.
*
* <p>Aggregation rules (in order of dominance):
* <ol>
* <li>If there is at least one subchannel in READY state, overall state is READY</li>
* <li>If there are <em>2 or more</em> subchannels in TRANSIENT_FAILURE, overall state is
* TRANSIENT_FAILURE</li>
* <li>If there is at least one subchannel in CONNECTING state, overall state is
* CONNECTING</li>
* <li> If there is one subchannel in TRANSIENT_FAILURE state and there is
* more than one subchannel, report CONNECTING </li>
* <li>If there is at least one subchannel in IDLE state, overall state is IDLE</li>
* <li>Otherwise, overall state is TRANSIENT_FAILURE</li>
* </ol>
*/
private void updateBalancingState() {
checkState(!subchannels.isEmpty(), "no subchannel has been created");
boolean startConnectionAttempt = false;
int numIdle = 0;
int numReady = 0;
int numConnecting = 0;
int numTransientFailure = 0;
for (Subchannel subchannel : subchannels.values()) {
ConnectivityState state = getSubchannelStateInfoRef(subchannel).value.getState();
if (state == READY) {
numReady++;
break;
} else if (state == TRANSIENT_FAILURE) {
numTransientFailure++;
} else if (state == CONNECTING ) {
numConnecting++;
} else if (state == IDLE) {
numIdle++;
}
}
ConnectivityState overallState;
if (numReady > 0) {
overallState = READY;
} else if (numTransientFailure >= 2) {
overallState = TRANSIENT_FAILURE;
startConnectionAttempt = (numConnecting == 0);
} else if (numConnecting > 0) {
overallState = CONNECTING;
} else if (numTransientFailure == 1 && subchannels.size() > 1) {
overallState = CONNECTING;
startConnectionAttempt = true;
} else if (numIdle > 0) {
overallState = IDLE;
} else {
overallState = TRANSIENT_FAILURE;
startConnectionAttempt = true;
}
RingHashPicker picker = new RingHashPicker(syncContext, ring, subchannels);
// TODO(chengyuanzhang): avoid unnecessary reprocess caused by duplicated server addr updates
helper.updateBalancingState(overallState, picker);
currentState = overallState;
// While the ring_hash policy is reporting TRANSIENT_FAILURE, it will
// not be getting any pick requests from the priority policy.
// However, because the ring_hash policy does not attempt to
// reconnect to subchannels unless it is getting pick requests,
// it will need special handling to ensure that it will eventually
// recover from TRANSIENT_FAILURE state once the problem is resolved.
// Specifically, it will make sure that it is attempting to connect to
// at least one subchannel at any given time. After a given subchannel
// fails a connection attempt, it will move on to the next subchannel
// in the ring. It will keep doing this until one of the subchannels
// successfully connects, at which point it will report READY and stop
// proactively trying to connect. The policy will remain in
// TRANSIENT_FAILURE until at least one subchannel becomes connected,
// even if subchannels are in state CONNECTING during that time.
//
// Note that we do the same thing when the policy is in state
// CONNECTING, just to ensure that we don't remain in CONNECTING state
// indefinitely if there are no new picks coming in.
if (startConnectionAttempt) {
if (!connectionAttemptIterator.hasNext()) {
connectionAttemptIterator = subchannels.values().iterator();
}
connectionAttemptIterator.next().requestConnection();
}
}
private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) {
return;
}
if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) {
helper.refreshNameResolution();
}
updateConnectivityState(subchannel, stateInfo);
updateBalancingState();
}
private void updateConnectivityState(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
Ref<ConnectivityStateInfo> subchannelStateRef = getSubchannelStateInfoRef(subchannel);
ConnectivityState previousConnectivityState = subchannelStateRef.value.getState();
// Don't proactively reconnect if the subchannel enters IDLE, even if previously was connected.
// If the subchannel was previously in TRANSIENT_FAILURE, it is considered to stay in
// TRANSIENT_FAILURE until it becomes READY.
if (previousConnectivityState == TRANSIENT_FAILURE) {
if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) {
return;
}
}
subchannelStateRef.value = stateInfo;
}
private static void shutdownSubchannel(Subchannel subchannel) {
subchannel.shutdown();
getSubchannelStateInfoRef(subchannel).value = ConnectivityStateInfo.forNonError(SHUTDOWN);
}
/**
* Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and
* remove all attributes. The values are the original EAGs.
*/
private static Map<EquivalentAddressGroup, EquivalentAddressGroup> stripAttrs(
List<EquivalentAddressGroup> groupList) {
Map<EquivalentAddressGroup, EquivalentAddressGroup> addrs =
new HashMap<>(groupList.size() * 2);
for (EquivalentAddressGroup group : groupList) {
addrs.put(stripAttrs(group), group);
}
return addrs;
}
private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
return new EquivalentAddressGroup(eag.getAddresses());
}
private static Ref<ConnectivityStateInfo> getSubchannelStateInfoRef(
Subchannel subchannel) {
return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO");
}
private static final class RingHashPicker extends SubchannelPicker {
private final SynchronizationContext syncContext;
private final List<RingEntry> ring;
@ -433,38 +357,31 @@ final class RingHashLoadBalancer extends LoadBalancer {
// freeze picker's view of subchannel's connectivity state.
// TODO(chengyuanzhang): can be more performance-friendly with
// IdentityHashMap<Subchannel, ConnectivityStateInfo> and RingEntry contains Subchannel.
private final Map<EquivalentAddressGroup, SubchannelView> pickableSubchannels; // read-only
private final Map<Endpoint, SubchannelView> pickableSubchannels; // read-only
private RingHashPicker(
SynchronizationContext syncContext, List<RingEntry> ring,
Map<EquivalentAddressGroup, Subchannel> subchannels) {
ImmutableMap<Object, ChildLbState> subchannels) {
this.syncContext = syncContext;
this.ring = ring;
pickableSubchannels = new HashMap<>(subchannels.size());
for (Map.Entry<EquivalentAddressGroup, Subchannel> entry : subchannels.entrySet()) {
Subchannel subchannel = entry.getValue();
ConnectivityStateInfo stateInfo = subchannel.getAttributes().get(STATE_INFO).value;
pickableSubchannels.put(entry.getKey(), new SubchannelView(subchannel, stateInfo));
for (Map.Entry<Object, ChildLbState> entry : subchannels.entrySet()) {
RingHashChildLbState childLbState = (RingHashChildLbState) entry.getValue();
pickableSubchannels.put((Endpoint)entry.getKey(),
new SubchannelView(childLbState, childLbState.getCurrentState()));
}
}
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY);
if (requestHash == null) {
return PickResult.withError(RPC_HASH_NOT_FOUND);
// Find the ring entry with hash next to (clockwise) the RPC's hash (binary search).
private int getTargetIndex(Long requestHash) {
if (ring.size() <= 1) {
return 0;
}
// Find the ring entry with hash next to (clockwise) the RPC's hash.
int low = 0;
int high = ring.size();
int mid;
while (true) {
mid = (low + high) / 2;
if (mid == ring.size()) {
mid = 0;
break;
}
int high = ring.size() - 1;
int mid = (low + high) / 2;
do {
long midVal = ring.get(mid).hash;
long midValL = mid == 0 ? 0 : ring.get(mid - 1).hash;
if (requestHash <= midVal && requestHash > midValL) {
@ -475,79 +392,61 @@ final class RingHashLoadBalancer extends LoadBalancer {
} else {
high = mid - 1;
}
if (low > high) {
mid = 0;
break;
}
mid = (low + high) / 2;
} while (mid < ring.size() && low <= high);
return mid;
}
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY);
if (requestHash == null) {
return PickResult.withError(RPC_HASH_NOT_FOUND);
}
// Try finding a READY subchannel. Starting from the ring entry next to the RPC's hash.
// If the one of the first two subchannels is not in TRANSIENT_FAILURE, return result
// based on that subchannel. Otherwise, fail the pick unless a READY subchannel is found.
// Meanwhile, trigger connection for the channel and status:
// For the first subchannel that is in IDLE or TRANSIENT_FAILURE;
// And for the second subchannel that is in IDLE or TRANSIENT_FAILURE;
// And for each of the following subchannels that is in TRANSIENT_FAILURE or IDLE,
// stop until we find the first subchannel that is in CONNECTING or IDLE status.
boolean foundFirstNonFailed = false; // true if having subchannel(s) in CONNECTING or IDLE
Subchannel firstSubchannel = null;
Subchannel secondSubchannel = null;
int targetIndex = getTargetIndex(requestHash);
// Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, we ignore
// all TF subchannels and find the first ring entry in READY, CONNECTING or IDLE. If
// CONNECTING or IDLE we return a pick with no results. Additionally, if that entry is in
// IDLE, we initiate a connection.
for (int i = 0; i < ring.size(); i++) {
int index = (mid + i) % ring.size();
EquivalentAddressGroup addrKey = ring.get(index).addrKey;
SubchannelView subchannel = pickableSubchannels.get(addrKey);
if (subchannel.stateInfo.getState() == READY) {
return PickResult.withSubchannel(subchannel.subchannel);
int index = (targetIndex + i) % ring.size();
SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey);
RingHashChildLbState childLbState = subchannelView.childLbState;
if (subchannelView.connectivityState == READY) {
return childLbState.getCurrentPicker().pickSubchannel(args);
}
// RPCs can be buffered if any of the first two subchannels is pending. Otherwise, RPCs
// RPCs can be buffered if the next subchannel is pending (per A62). Otherwise, RPCs
// are failed unless there is a READY connection.
if (firstSubchannel == null) {
firstSubchannel = subchannel.subchannel;
PickResult maybeBuffer = pickSubchannelsNonReady(subchannel);
if (maybeBuffer != null) {
return maybeBuffer;
}
} else if (subchannel.subchannel != firstSubchannel && secondSubchannel == null) {
secondSubchannel = subchannel.subchannel;
PickResult maybeBuffer = pickSubchannelsNonReady(subchannel);
if (maybeBuffer != null) {
return maybeBuffer;
}
} else if (subchannel.subchannel != firstSubchannel
&& subchannel.subchannel != secondSubchannel) {
if (!foundFirstNonFailed) {
pickSubchannelsNonReady(subchannel);
if (subchannel.stateInfo.getState() != TRANSIENT_FAILURE) {
foundFirstNonFailed = true;
}
if (subchannelView.connectivityState == CONNECTING) {
return PickResult.withNoResult();
}
if (subchannelView.connectivityState == IDLE || childLbState.isDeactivated()) {
if (childLbState.isDeactivated()) {
childLbState.activate();
} else {
syncContext.execute(() -> childLbState.getLb().requestConnection());
}
return PickResult.withNoResult(); // Indicates that this should be retried after backoff
}
}
// Fail the pick with error status of the original subchannel hit by hash.
SubchannelView originalSubchannel = pickableSubchannels.get(ring.get(mid).addrKey);
return PickResult.withError(originalSubchannel.stateInfo.getStatus());
// return the pick from the original subchannel hit by hash, which is probably an error
RingHashChildLbState originalSubchannel =
pickableSubchannels.get(ring.get(targetIndex).addrKey).childLbState;
return originalSubchannel.getCurrentPicker().pickSubchannel(args);
}
@Nullable
private PickResult pickSubchannelsNonReady(SubchannelView subchannel) {
if (subchannel.stateInfo.getState() == TRANSIENT_FAILURE
|| subchannel.stateInfo.getState() == IDLE ) {
final Subchannel finalSubchannel = subchannel.subchannel;
syncContext.execute(new Runnable() {
@Override
public void run() {
finalSubchannel.requestConnection();
}
});
}
if (subchannel.stateInfo.getState() == CONNECTING
|| subchannel.stateInfo.getState() == IDLE) {
return PickResult.withNoResult();
} else {
return null;
}
}
}
@Override
protected SubchannelPicker getSubchannelPicker(Map<Object, SubchannelPicker> childPickers) {
throw new UnsupportedOperationException("Not used by RingHash");
}
/**
@ -555,20 +454,20 @@ final class RingHashLoadBalancer extends LoadBalancer {
* state changes.
*/
private static final class SubchannelView {
private final Subchannel subchannel;
private final ConnectivityStateInfo stateInfo;
private final RingHashChildLbState childLbState;
private final ConnectivityState connectivityState;
private SubchannelView(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
this.subchannel = subchannel;
this.stateInfo = stateInfo;
private SubchannelView(RingHashChildLbState childLbState, ConnectivityState state) {
this.childLbState = childLbState;
this.connectivityState = state;
}
}
private static final class RingEntry implements Comparable<RingEntry> {
private final long hash;
private final EquivalentAddressGroup addrKey;
private final Endpoint addrKey;
private RingEntry(long hash, EquivalentAddressGroup addrKey) {
private RingEntry(long hash, Endpoint addrKey) {
this.hash = hash;
this.addrKey = addrKey;
}
@ -579,17 +478,6 @@ final class RingHashLoadBalancer extends LoadBalancer {
}
}
/**
* A lighter weight Reference than AtomicReference.
*/
private static final class Ref<T> {
T value;
Ref(T value) {
this.value = value;
}
}
/**
* Configures the ring property. The larger the ring is (that is, the more hashes there are
* for each provided host) the better the request distribution will reflect the desired weights.
@ -614,4 +502,58 @@ final class RingHashLoadBalancer extends LoadBalancer {
.toString();
}
}
}
static Set<EquivalentAddressGroup> getStrippedChildEags(Collection<ChildLbState> states) {
return states.stream()
.map(ChildLbState::getEag)
.map(RingHashLoadBalancer::stripAttrs)
.collect(Collectors.toSet());
}
@Override
protected Collection<ChildLbState> getChildLbStates() {
return super.getChildLbStates();
}
@Override
protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) {
return super.getChildLbStateEag(eag);
}
class RingHashChildLbState extends MultiChildLoadBalancer.ChildLbState {
public RingHashChildLbState(Endpoint key, ResolvedAddresses resolvedAddresses) {
super(key, pickFirstLbProvider, null, EMPTY_PICKER, resolvedAddresses, true);
}
@Override
protected void reactivate(LoadBalancerProvider policyProvider) {
if (!isDeactivated()) {
return;
}
currentConnectivityState = CONNECTING;
getLb().switchTo(pickFirstLbProvider);
markReactivated();
getLb().acceptResolvedAddresses(this.getResolvedAddresses()); // Time to get a subchannel
logger.log(XdsLogLevel.DEBUG, "Child balancer {0} reactivated", getKey());
}
public void activate() {
reactivate(pickFirstLbProvider);
}
// Need to expose this to the LB class
@Override
protected void shutdown() {
super.shutdown();
}
// Need to expose this to the LB class
@Override
protected GracefulSwitchLoadBalancer getLb() {
return super.getLb();
}
}
}

View File

@ -97,7 +97,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
@Override
protected ChildLbState createChildLbState(Object key, Object policyConfig,
SubchannelPicker initialPicker) {
SubchannelPicker initialPicker, ResolvedAddresses unused) {
ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig,
initialPicker);
return childLbState;
@ -115,13 +115,31 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
config =
(WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
Status addressAcceptanceStatus = super.acceptResolvedAddresses(resolvedAddresses);
if (weightUpdateTimer != null && weightUpdateTimer.isPending()) {
weightUpdateTimer.cancel();
AcceptResolvedAddressRetVal acceptRetVal;
try {
resolvingAddresses = true;
acceptRetVal = acceptResolvedAddressesInternal(resolvedAddresses);
if (!acceptRetVal.status.isOk()) {
return acceptRetVal.status;
}
if (weightUpdateTimer != null && weightUpdateTimer.isPending()) {
weightUpdateTimer.cancel();
}
updateWeightTask.run();
createAndApplyOrcaListeners();
// Must update channel picker before return so that new RPCs will not be routed to deleted
// clusters and resolver can remove them in service config.
updateOverallBalancingState();
shutdownRemoved(acceptRetVal.removedChildren);
} finally {
resolvingAddresses = false;
}
updateWeightTask.run();
afterAcceptAddresses();
return addressAcceptanceStatus;
return acceptRetVal.status;
}
@Override
@ -228,7 +246,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
}
private void afterAcceptAddresses() {
private void createAndApplyOrcaListeners() {
for (ChildLbState child : getChildLbStates()) {
WeightedChildLbState wChild = (WeightedChildLbState) child;
for (WrrSubchannel weightedSubchannel : wChild.subchannels) {

View File

@ -106,7 +106,7 @@ public class LeastRequestLoadBalancerTest {
@Captor
private ArgumentCaptor<CreateSubchannelArgs> createArgsCaptor;
private final TestHelper testHelperInstance = new TestHelper();
private Helper helper = mock(Helper.class, delegatesTo(testHelperInstance));
private final Helper helper = mock(Helper.class, delegatesTo(testHelperInstance));
@Mock
private ThreadSafeRandom mockRandom;
@ -522,7 +522,6 @@ public class LeastRequestLoadBalancerTest {
loadBalancer.handleNameResolutionError(error);
loadBalancer.setResolvingAddresses(false);
verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs);
assertNull(pickResult.getSubchannel());
assertEquals(error, pickResult.getStatus());

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.ConnectivityState.CONNECTING;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
@ -59,6 +60,7 @@ import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancer
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
@ -78,7 +80,9 @@ import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
@ -176,7 +180,7 @@ public class WeightedRoundRobinLoadBalancerTest {
.forNonError(ConnectivityState.READY));
Subchannel connectingSubchannel = it.next();
getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.CONNECTING));
.forNonError(CONNECTING));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2);
@ -477,7 +481,7 @@ public class WeightedRoundRobinLoadBalancerTest {
.setAttributes(affinity).build()));
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class));
verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue().getClass().getName())
.isEqualTo("io.grpc.util.RoundRobinLoadBalancer$EmptyPicker");
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
@ -554,7 +558,7 @@ public class WeightedRoundRobinLoadBalancerTest {
.forNonError(ConnectivityState.READY));
Subchannel connectingSubchannel = it.next();
getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.CONNECTING));
.forNonError(CONNECTING));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2);
@ -1063,6 +1067,24 @@ public class WeightedRoundRobinLoadBalancerTest {
assertThat(sequence.get()).isEqualTo(9);
}
@Test
public void removingAddressShutsdownSubchannel() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
final Subchannel subchannel2 = subchannels.get(Collections.singletonList(servers.get(2)));
InOrder inOrder = Mockito.inOrder(helper, subchannel2);
// send LB only the first 2 addresses
List<EquivalentAddressGroup> svs2 = Arrays.asList(servers.get(0), servers.get(1));
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(svs2).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any());
inOrder.verify(subchannel2).shutdown();
}
private static final class VerifyingScheduler {
private final StaticStrideScheduler delegate;
private final int max;