lb:Implement LeastRequestLB as a petiole policy and restore RR and WRR (#10584)

* Change LeastRequest, Round Robin and WeightedRoundRobin into petiole policies
This commit is contained in:
Larry Safran 2023-10-16 16:40:20 -07:00 committed by GitHub
parent fc03f2be9d
commit 0d39bf5018
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1548 additions and 1045 deletions

View File

@ -128,6 +128,9 @@ public final class EquivalentAddressGroup {
*/
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof EquivalentAddressGroup)) {
return false;
}

View File

@ -161,7 +161,7 @@ import org.mockito.stubbing.Answer;
/** Unit tests for {@link ManagedChannelImpl}. */
@RunWith(JUnit4.class)
// TODO(creamsoup) remove backward compatible check when fully migrated
@SuppressWarnings("deprecation")
@SuppressWarnings({"deprecation"})
public class ManagedChannelImplTest {
private static final int DEFAULT_PORT = 447;

View File

@ -21,9 +21,11 @@ import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.ChannelLogger;
import io.grpc.ClientStreamTracer;
import io.grpc.EquivalentAddressGroup;
import io.grpc.InternalLogId;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
@ -143,6 +145,14 @@ public final class TestUtils {
return captor;
}
@SuppressWarnings("ReferenceEquality")
public static final EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
if (eag.getAttributes() == Attributes.EMPTY) {
return eag;
}
return new EquivalentAddressGroup(eag.getAddresses());
}
private TestUtils() {
}

View File

@ -1,5 +1,6 @@
plugins {
id "java-library"
id "java-test-fixtures"
id "maven-publish"
id "me.champeau.jmh"
@ -19,11 +20,18 @@ dependencies {
implementation libraries.animalsniffer.annotations,
libraries.guava
testImplementation testFixtures(project(':grpc-api')),
testImplementation libraries.guava.testlib,
testFixtures(project(':grpc-api')),
testFixtures(project(':grpc-core')),
project(':grpc-testing')
testImplementation libraries.guava.testlib
testFixturesApi project(':grpc-core')
testFixturesImplementation libraries.guava,
libraries.junit,
libraries.mockito.core,
testFixtures(project(':grpc-api')),
testFixtures(project(':grpc-core')),
project(':grpc-testing')
jmh project(':grpc-testing')
signature libraries.signature.java

View File

@ -16,25 +16,32 @@
package io.grpc.util;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.grpc.ConnectivityState;
import io.grpc.EquivalentAddressGroup;
import io.grpc.Internal;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.SynchronizationContext.ScheduledHandle;
import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import io.grpc.internal.PickFirstLoadBalancerProvider;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
@ -46,23 +53,26 @@ import javax.annotation.Nullable;
@Internal
public abstract class MultiChildLoadBalancer extends LoadBalancer {
@VisibleForTesting
public static final int DELAYED_CHILD_DELETION_TIME_MINUTES = 15;
private static final Logger logger = Logger.getLogger(MultiChildLoadBalancer.class.getName());
private final Map<Object, ChildLbState> childLbStates = new HashMap<>();
private final Map<Object, ChildLbState> childLbStates = new LinkedHashMap<>();
private final Helper helper;
protected final SynchronizationContext syncContext;
private final ScheduledExecutorService timeService;
// Set to true if currently in the process of handling resolved addresses.
private boolean resolvingAddresses;
@VisibleForTesting
protected boolean resolvingAddresses;
protected final PickFirstLoadBalancerProvider pickFirstLbProvider =
new PickFirstLoadBalancerProvider();
protected ConnectivityState currentConnectivityState;
protected MultiChildLoadBalancer(Helper helper) {
this.helper = checkNotNull(helper, "helper");
this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
logger.log(Level.FINE, "Created");
}
protected abstract SubchannelPicker getSubchannelPicker(
Map<Object, SubchannelPicker> childPickers);
protected SubchannelPicker getInitialPicker() {
return EMPTY_PICKER;
}
@ -71,12 +81,67 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
return new FixedResultPicker(PickResult.withError(error));
}
protected abstract Map<Object, PolicySelection> getPolicySelectionMap(
ResolvedAddresses resolvedAddresses);
/**
* Generally, the only reason to override this is to expose it to a test of a LB in a different
* package.
*/
@VisibleForTesting
protected Collection<ChildLbState> getChildLbStates() {
return childLbStates.values();
}
protected abstract SubchannelPicker getSubchannelPicker(
Map<Object, SubchannelPicker> childPickers);
/**
* Generally, the only reason to override this is to expose it to a test of a LB in a
* different package.
*/
protected ChildLbState getChildLbState(Object key) {
if (key == null) {
return null;
}
if (key instanceof EquivalentAddressGroup) {
key = new Endpoint((EquivalentAddressGroup) key);
}
return childLbStates.get(key);
}
/**
* Generally, the only reason to override this is to expose it to a test of a LB in a different
* package.
*/
protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) {
return getChildLbState(new Endpoint(eag));
}
/**
* Override to utilize parsing of the policy configuration or alternative helper/lb generation.
*/
protected Map<Object, ChildLbState> createChildLbMap(ResolvedAddresses resolvedAddresses) {
Map<Object, ChildLbState> childLbMap = new HashMap<>();
List<EquivalentAddressGroup> addresses = resolvedAddresses.getAddresses();
for (EquivalentAddressGroup eag : addresses) {
Endpoint endpoint = new Endpoint(eag); // keys need to be just addresses
ChildLbState existingChildLbState = childLbStates.get(endpoint);
if (existingChildLbState != null) {
childLbMap.put(endpoint, existingChildLbState);
} else {
childLbMap.put(endpoint, createChildLbState(endpoint, null, getInitialPicker()));
}
}
return childLbMap;
}
/**
* Override to create an instance of a subclass.
*/
protected ChildLbState createChildLbState(Object key, Object policyConfig,
SubchannelPicker initialPicker) {
return new ChildLbState(key, pickFirstLbProvider, policyConfig, initialPicker);
}
/**
* Override to completely replace the default logic or to do additional activities.
*/
@Override
public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
try {
@ -87,25 +152,71 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
}
}
/**
* Override this if your keys are not of type Endpoint.
* @param key Key to identify the ChildLbState
* @param resolvedAddresses list of addresses which include attributes
* @param childConfig a load balancing policy config. This field is optional.
* @return a fully loaded ResolvedAddresses object for the specified key
*/
protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses,
Object childConfig) {
if (key instanceof EquivalentAddressGroup) {
key = new Endpoint((EquivalentAddressGroup) key);
}
checkArgument(key instanceof Endpoint, "key is wrong type");
// Retrieve the non-stripped version
EquivalentAddressGroup eagToUse = null;
for (EquivalentAddressGroup currEag : resolvedAddresses.getAddresses()) {
if (key.equals(new Endpoint(currEag))) {
eagToUse = currEag;
break;
}
}
checkNotNull(eagToUse, key + " no longer present in load balancer children");
return resolvedAddresses.toBuilder()
.setAddresses(Collections.singletonList(eagToUse))
.setLoadBalancingPolicyConfig(childConfig)
.build();
}
private boolean acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) {
logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses);
Map<Object, PolicySelection> newChildPolicies = getPolicySelectionMap(resolvedAddresses);
for (Map.Entry<Object, PolicySelection> entry : newChildPolicies.entrySet()) {
Map<Object, ChildLbState> newChildren = createChildLbMap(resolvedAddresses);
if (newChildren.isEmpty()) {
handleNameResolutionError(Status.UNAVAILABLE.withDescription(
"NameResolver returned no usable address. " + resolvedAddresses));
return false;
}
// Do adds and updates
for (Map.Entry<Object, ChildLbState> entry : newChildren.entrySet()) {
final Object key = entry.getKey();
LoadBalancerProvider childPolicyProvider = entry.getValue().getProvider();
LoadBalancerProvider childPolicyProvider = entry.getValue().getPolicyProvider();
Object childConfig = entry.getValue().getConfig();
if (!childLbStates.containsKey(key)) {
childLbStates.put(key, new ChildLbState(key, childPolicyProvider, getInitialPicker()));
childLbStates.put(key, entry.getValue());
} else {
childLbStates.get(key).reactivate(childPolicyProvider);
// Reuse the existing one
ChildLbState existingChildLbState = childLbStates.get(key);
if (existingChildLbState.isDeactivated()) {
existingChildLbState.reactivate(childPolicyProvider);
}
}
LoadBalancer childLb = childLbStates.get(key).lb;
ResolvedAddresses childAddresses =
resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build();
childLb.handleResolvedAddresses(childAddresses);
ResolvedAddresses childAddresses = getChildAddresses(key, resolvedAddresses, childConfig);
childLbStates.get(key).setResolvedAddresses(childAddresses); // update child state
childLb.handleResolvedAddresses(childAddresses); // update child LB
}
for (Object key : childLbStates.keySet()) {
if (!newChildPolicies.containsKey(key)) {
// Do removals
for (Object key : ImmutableList.copyOf(childLbStates.keySet())) {
if (!newChildren.containsKey(key)) {
childLbStates.get(key).deactivate();
}
}
@ -117,19 +228,23 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
@Override
public void handleNameResolutionError(Status error) {
logger.log(Level.WARNING, "Received name resolution error: {0}", error);
boolean gotoTransientFailure = true;
for (ChildLbState state : childLbStates.values()) {
if (!state.deactivated) {
gotoTransientFailure = false;
state.lb.handleNameResolutionError(error);
}
}
if (gotoTransientFailure) {
helper.updateBalancingState(TRANSIENT_FAILURE, getErrorPicker(error));
if (currentConnectivityState != READY) {
updateHelperBalancingState(TRANSIENT_FAILURE, getErrorPicker(error));
}
}
protected void handleNameResolutionError(ChildLbState child, Status error) {
child.lb.handleNameResolutionError(error);
}
/**
* If true, then when a subchannel state changes to idle, the corresponding child will
* have requestConnection called on its LB.
*/
protected boolean reconnectOnIdle() {
return true;
}
@Override
public void shutdown() {
logger.log(Level.INFO, "Shutdown");
@ -139,10 +254,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
childLbStates.clear();
}
private void updateOverallBalancingState() {
protected void updateOverallBalancingState() {
ConnectivityState overallState = null;
final Map<Object, SubchannelPicker> childPickers = new HashMap<>();
for (ChildLbState childLbState : childLbStates.values()) {
for (ChildLbState childLbState : getChildLbStates()) {
if (childLbState.deactivated) {
continue;
}
@ -151,11 +266,17 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
}
if (overallState != null) {
helper.updateBalancingState(overallState, getSubchannelPicker(childPickers));
currentConnectivityState = overallState;
}
}
protected final void updateHelperBalancingState(ConnectivityState newState,
SubchannelPicker newPicker) {
helper.updateBalancingState(newState, newPicker);
}
@Nullable
private static ConnectivityState aggregateState(
protected static ConnectivityState aggregateState(
@Nullable ConnectivityState overallState, ConnectivityState childState) {
if (overallState == null) {
return childState;
@ -172,70 +293,155 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
return overallState;
}
private final class ChildLbState {
protected Helper getHelper() {
return helper;
}
protected void removeChild(Object key) {
childLbStates.remove(key);
}
/**
* Filters out non-ready and deactivated child load balancers (subchannels).
*/
protected List<ChildLbState> getReadyChildren() {
List<ChildLbState> activeChildren = new ArrayList<>();
for (ChildLbState child : getChildLbStates()) {
if (!child.isDeactivated() && child.getCurrentState() == READY) {
activeChildren.add(child);
}
}
return activeChildren;
}
/**
* This represents the state of load balancer children. Each endpoint (represented by an
* EquivalentAddressGroup or EDS string) will have a separate ChildLbState which in turn will
* define a GracefulSwitchLoadBalancer. When the GracefulSwitchLoadBalancer is activated, a
* single PickFirstLoadBalancer will be created which will then create a subchannel and start
* trying to connect to it.
*
* <p>A ChildLbStateHelper is the glue between ChildLbState and the helpers associated with the
* petiole policy above and the PickFirstLoadBalancer's helper below.
*
* <p>If you wish to store additional state information related to each subchannel, then extend
* this class.
*/
public class ChildLbState {
private final Object key;
private ResolvedAddresses resolvedAddresses;
private final Object config;
private final GracefulSwitchLoadBalancer lb;
private LoadBalancerProvider policyProvider;
private ConnectivityState currentState = CONNECTING;
private SubchannelPicker currentPicker;
private boolean deactivated;
@Nullable
ScheduledHandle deletionTimer;
ChildLbState(Object key, LoadBalancerProvider policyProvider, SubchannelPicker initialPicker) {
public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig,
SubchannelPicker initialPicker) {
this.key = key;
this.policyProvider = policyProvider;
lb = new GracefulSwitchLoadBalancer(new ChildLbStateHelper());
lb.switchTo(policyProvider);
currentPicker = initialPicker;
config = childConfig;
}
void deactivate() {
@Override
public String toString() {
return "Address = " + key
+ ", state = " + currentState
+ ", picker type: " + currentPicker.getClass()
+ ", lb: " + lb.delegate().getClass()
+ (deactivated ? ", deactivated" : "");
}
public Object getKey() {
return key;
}
Object getConfig() {
return config;
}
public LoadBalancerProvider getPolicyProvider() {
return policyProvider;
}
protected Subchannel getSubchannels(PickSubchannelArgs args) {
if (getCurrentPicker() == null) {
return null;
}
return getCurrentPicker().pickSubchannel(args).getSubchannel();
}
public ConnectivityState getCurrentState() {
return currentState;
}
public SubchannelPicker getCurrentPicker() {
return currentPicker;
}
public EquivalentAddressGroup getEag() {
if (resolvedAddresses == null || resolvedAddresses.getAddresses().isEmpty()) {
return null;
}
return resolvedAddresses.getAddresses().get(0);
}
public boolean isDeactivated() {
return deactivated;
}
protected void setDeactivated() {
deactivated = true;
}
protected void setResolvedAddresses(ResolvedAddresses newAddresses) {
checkNotNull(newAddresses, "Missing address list for child");
resolvedAddresses = newAddresses;
}
protected void deactivate() {
if (deactivated) {
return;
}
class DeletionTask implements Runnable {
@Override
public void run() {
shutdown();
childLbStates.remove(key);
}
}
deletionTimer =
syncContext.schedule(
new DeletionTask(),
DELAYED_CHILD_DELETION_TIME_MINUTES,
TimeUnit.MINUTES,
timeService);
shutdown();
childLbStates.remove(key);
deactivated = true;
logger.log(Level.FINE, "Child balancer {0} deactivated", key);
}
void reactivate(LoadBalancerProvider policyProvider) {
if (deletionTimer != null && deletionTimer.isPending()) {
deletionTimer.cancel();
deactivated = false;
logger.log(Level.FINE, "Child balancer {0} reactivated", key);
}
protected void reactivate(LoadBalancerProvider policyProvider) {
if (!this.policyProvider.getPolicyName().equals(policyProvider.getPolicyName())) {
Object[] objects = {
key, this.policyProvider.getPolicyName(),policyProvider.getPolicyName()};
logger.log(Level.FINE, "Child balancer {0} switching policy from {1} to {2}", objects);
lb.switchTo(policyProvider);
this.policyProvider = policyProvider;
} else {
logger.log(Level.FINE, "Child balancer {0} reactivated", key);
lb.acceptResolvedAddresses(resolvedAddresses);
}
deactivated = false;
}
void shutdown() {
if (deletionTimer != null && deletionTimer.isPending()) {
deletionTimer.cancel();
}
protected void shutdown() {
lb.shutdown();
this.currentState = SHUTDOWN;
logger.log(Level.FINE, "Child balancer {0} deleted", key);
}
/**
* ChildLbStateHelper is the glue between ChildLbState and the helpers associated with the
* petiole policy above and the PickFirstLoadBalancer's helper below.
*
* <p>The ChildLbState updates happen during updateBalancingState. Otherwise, it is doing
* simple forwarding.
*/
private final class ChildLbStateHelper extends ForwardingLoadBalancerHelper {
@Override
@ -251,6 +457,9 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
currentState = newState;
currentPicker = newPicker;
if (!deactivated && !resolvingAddresses) {
if (newState == IDLE && reconnectOnIdle()) {
lb.requestConnection();
}
updateOverallBalancingState();
}
}
@ -261,4 +470,58 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
}
}
}
/**
* Endpoint is an optimization to quickly lookup and compare EquivalentAddressGroup address sets.
* Ignores the attributes, orders the addresses in a deterministic manner and converts each
* address into a string for easy comparison. Also caches the hashcode.
* Is used as a key for ChildLbState for most load balancers (ClusterManagerLB uses a String).
*/
protected static class Endpoint {
final String[] addrs;
final int hashCode;
Endpoint(EquivalentAddressGroup eag) {
checkNotNull(eag, "eag");
addrs = new String[eag.getAddresses().size()];
int i = 0;
for (SocketAddress address : eag.getAddresses()) {
addrs[i] = address.toString();
}
Arrays.sort(addrs);
hashCode = Arrays.hashCode(addrs);
}
@Override
public int hashCode() {
return hashCode;
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null) {
return false;
}
if (!(other instanceof Endpoint)) {
return false;
}
Endpoint o = (Endpoint) other;
if (o.hashCode != hashCode || o.addrs.length != addrs.length) {
return false;
}
return Arrays.equals(o.addrs, this.addrs);
}
@Override
public String toString() {
return Arrays.toString(addrs);
}
}
}

View File

@ -16,11 +16,9 @@
package io.grpc.util;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import com.google.common.annotations.VisibleForTesting;
@ -37,13 +35,10 @@ import io.grpc.NameResolver;
import io.grpc.Status;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import javax.annotation.Nonnull;
@ -52,131 +47,22 @@ import javax.annotation.Nonnull;
* EquivalentAddressGroup}s from the {@link NameResolver}.
*/
@Internal
public class RoundRobinLoadBalancer extends LoadBalancer {
public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
@VisibleForTesting
static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO =
Attributes.Key.create("state-info");
private final Helper helper;
private final Map<EquivalentAddressGroup, Subchannel> subchannels =
new HashMap<>();
private final Random random;
private ConnectivityState currentState;
protected RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK);
public RoundRobinLoadBalancer(Helper helper) {
this.helper = checkNotNull(helper, "helper");
super(helper);
this.random = new Random();
}
@Override
public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
if (resolvedAddresses.getAddresses().isEmpty()) {
handleNameResolutionError(Status.UNAVAILABLE.withDescription(
"NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses()
+ ", attrs=" + resolvedAddresses.getAttributes()));
return false;
}
List<EquivalentAddressGroup> servers = resolvedAddresses.getAddresses();
Set<EquivalentAddressGroup> currentAddrs = subchannels.keySet();
Map<EquivalentAddressGroup, EquivalentAddressGroup> latestAddrs = stripAttrs(servers);
Set<EquivalentAddressGroup> removedAddrs = setsDifference(currentAddrs, latestAddrs.keySet());
for (Map.Entry<EquivalentAddressGroup, EquivalentAddressGroup> latestEntry :
latestAddrs.entrySet()) {
EquivalentAddressGroup strippedAddressGroup = latestEntry.getKey();
EquivalentAddressGroup originalAddressGroup = latestEntry.getValue();
Subchannel existingSubchannel = subchannels.get(strippedAddressGroup);
if (existingSubchannel != null) {
// EAG's Attributes may have changed.
existingSubchannel.updateAddresses(Collections.singletonList(originalAddressGroup));
continue;
}
// Create new subchannels for new addresses.
// NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel
// doesn't need them. They're describing the resolved server list but we're not taking
// any action based on this information.
Attributes.Builder subchannelAttrs = Attributes.newBuilder()
// NB(lukaszx0): because attributes are immutable we can't set new value for the key
// after creation but since we can mutate the values we leverage that and set
// AtomicReference which will allow mutating state info for given channel.
.set(STATE_INFO,
new Ref<>(ConnectivityStateInfo.forNonError(IDLE)));
final Subchannel subchannel = checkNotNull(
helper.createSubchannel(CreateSubchannelArgs.newBuilder()
.setAddresses(originalAddressGroup)
.setAttributes(subchannelAttrs.build())
.build()),
"subchannel");
subchannel.start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo state) {
processSubchannelState(subchannel, state);
}
});
subchannels.put(strippedAddressGroup, subchannel);
subchannel.requestConnection();
}
ArrayList<Subchannel> removedSubchannels = new ArrayList<>();
for (EquivalentAddressGroup addressGroup : removedAddrs) {
removedSubchannels.add(subchannels.remove(addressGroup));
}
// Update the picker before shutting down the subchannels, to reduce the chance of the race
// between picking a subchannel and shutting it down.
updateBalancingState();
// Shutdown removed subchannels
for (Subchannel removedSubchannel : removedSubchannels) {
shutdownSubchannel(removedSubchannel);
}
return true;
}
@Override
public void handleNameResolutionError(Status error) {
if (currentState != READY) {
updateBalancingState(TRANSIENT_FAILURE, new EmptyPicker(error));
}
}
private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) {
return;
}
if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) {
helper.refreshNameResolution();
}
if (stateInfo.getState() == IDLE) {
subchannel.requestConnection();
}
Ref<ConnectivityStateInfo> subchannelStateRef = getSubchannelStateInfoRef(subchannel);
if (subchannelStateRef.value.getState().equals(TRANSIENT_FAILURE)) {
if (stateInfo.getState().equals(CONNECTING) || stateInfo.getState().equals(IDLE)) {
return;
}
}
subchannelStateRef.value = stateInfo;
updateBalancingState();
}
private void shutdownSubchannel(Subchannel subchannel) {
subchannel.shutdown();
getSubchannelStateInfoRef(subchannel).value =
ConnectivityStateInfo.forNonError(SHUTDOWN);
}
@Override
public void shutdown() {
for (Subchannel subchannel : getSubchannels()) {
shutdownSubchannel(subchannel);
}
subchannels.clear();
protected SubchannelPicker getSubchannelPicker(Map<Object, SubchannelPicker> childPickers) {
throw new UnsupportedOperationException(); // local updateOverallBalancingState doesn't use this
}
private static final Status EMPTY_OK = Status.OK.withDescription("no subchannels ready");
@ -184,102 +70,54 @@ public class RoundRobinLoadBalancer extends LoadBalancer {
/**
* Updates picker with the list of active subchannels (state == READY).
*/
@SuppressWarnings("ReferenceEquality")
private void updateBalancingState() {
List<Subchannel> activeList = filterNonFailingSubchannels(getSubchannels());
@Override
protected void updateOverallBalancingState() {
List<ChildLbState> activeList = getReadyChildren();
if (activeList.isEmpty()) {
// No READY subchannels, determine aggregate state and error status
// No READY subchannels
// RRLB will request connection immediately on subchannel IDLE.
boolean isConnecting = false;
Status aggStatus = EMPTY_OK;
for (Subchannel subchannel : getSubchannels()) {
ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).value;
// This subchannel IDLE is not because of channel IDLE_TIMEOUT,
// in which case LB is already shutdown.
// RRLB will request connection immediately on subchannel IDLE.
if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) {
for (ChildLbState childLbState : getChildLbStates()) {
ConnectivityState state = childLbState.getCurrentState();
if (state == CONNECTING || state == IDLE) {
isConnecting = true;
}
if (aggStatus == EMPTY_OK || !aggStatus.isOk()) {
aggStatus = stateInfo.getStatus();
break;
}
}
updateBalancingState(isConnecting ? CONNECTING : TRANSIENT_FAILURE,
// If all subchannels are TRANSIENT_FAILURE, return the Status associated with
// an arbitrary subchannel, otherwise return OK.
new EmptyPicker(aggStatus));
if (isConnecting) {
updateBalancingState(CONNECTING, new EmptyPicker(Status.OK));
} else {
updateBalancingState(TRANSIENT_FAILURE, createReadyPicker(getChildLbStates()));
}
} else {
updateBalancingState(READY, createReadyPicker(activeList));
}
}
private void updateBalancingState(ConnectivityState state, RoundRobinPicker picker) {
if (state != currentState || !picker.isEquivalentTo(currentPicker)) {
helper.updateBalancingState(state, picker);
currentState = state;
if (state != currentConnectivityState || !picker.isEquivalentTo(currentPicker)) {
getHelper().updateBalancingState(state, picker);
currentConnectivityState = state;
currentPicker = picker;
}
}
protected RoundRobinPicker createReadyPicker(List<Subchannel> activeList) {
protected RoundRobinPicker createReadyPicker(Collection<ChildLbState> children) {
// initialize the Picker to a random start index to ensure that a high frequency of Picker
// churn does not skew subchannel selection.
int startIndex = random.nextInt(activeList.size());
return new ReadyPicker(activeList, startIndex);
}
int startIndex = random.nextInt(children.size());
/**
* Filters out non-ready subchannels.
*/
private static List<Subchannel> filterNonFailingSubchannels(
Collection<Subchannel> subchannels) {
List<Subchannel> readySubchannels = new ArrayList<>(subchannels.size());
for (Subchannel subchannel : subchannels) {
if (isReady(subchannel)) {
readySubchannels.add(subchannel);
}
List<SubchannelPicker> pickerList = new ArrayList<>();
for (ChildLbState child : children) {
SubchannelPicker picker = child.getCurrentPicker();
pickerList.add(picker);
}
return readySubchannels;
return new ReadyPicker(pickerList, startIndex);
}
/**
* Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and
* remove all attributes. The values are the original EAGs.
*/
private static Map<EquivalentAddressGroup, EquivalentAddressGroup> stripAttrs(
List<EquivalentAddressGroup> groupList) {
Map<EquivalentAddressGroup, EquivalentAddressGroup> addrs = new HashMap<>(groupList.size() * 2);
for (EquivalentAddressGroup group : groupList) {
addrs.put(stripAttrs(group), group);
}
return addrs;
}
private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
return new EquivalentAddressGroup(eag.getAddresses());
}
@VisibleForTesting
protected Collection<Subchannel> getSubchannels() {
return subchannels.values();
}
private static Ref<ConnectivityStateInfo> getSubchannelStateInfoRef(
Subchannel subchannel) {
return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO");
}
// package-private to avoid synthetic access
static boolean isReady(Subchannel subchannel) {
return getSubchannelStateInfoRef(subchannel).value.getState() == READY;
}
private static <T> Set<T> setsDifference(Set<T> a, Set<T> b) {
Set<T> aCopy = new HashSet<>(a);
aCopy.removeAll(b);
return aCopy;
}
// Only subclasses are ReadyPicker or EmptyPicker
public abstract static class RoundRobinPicker extends SubchannelPicker {
public abstract boolean isEquivalentTo(RoundRobinPicker picker);
}
@ -289,40 +127,42 @@ public class RoundRobinLoadBalancer extends LoadBalancer {
private static final AtomicIntegerFieldUpdater<ReadyPicker> indexUpdater =
AtomicIntegerFieldUpdater.newUpdater(ReadyPicker.class, "index");
private final List<Subchannel> list; // non-empty
private final List<SubchannelPicker> subchannelPickers; // non-empty
@SuppressWarnings("unused")
private volatile int index;
public ReadyPicker(List<Subchannel> list, int startIndex) {
public ReadyPicker(List<SubchannelPicker> list, int startIndex) {
Preconditions.checkArgument(!list.isEmpty(), "empty list");
this.list = list;
this.subchannelPickers = list;
this.index = startIndex - 1;
}
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
return PickResult.withSubchannel(nextSubchannel());
return subchannelPickers.get(nextIndex()).pickSubchannel(args);
}
@Override
public String toString() {
return MoreObjects.toStringHelper(ReadyPicker.class).add("list", list).toString();
return MoreObjects.toStringHelper(ReadyPicker.class)
.add("subchannelPickers", subchannelPickers)
.toString();
}
private Subchannel nextSubchannel() {
int size = list.size();
private int nextIndex() {
int size = subchannelPickers.size();
int i = indexUpdater.incrementAndGet(this);
if (i >= size) {
int oldi = i;
i %= size;
indexUpdater.compareAndSet(this, oldi, i);
}
return list.get(i);
return i;
}
@VisibleForTesting
List<Subchannel> getList() {
return list;
List<SubchannelPicker> getSubchannelPickers() {
return subchannelPickers;
}
@Override
@ -333,7 +173,8 @@ public class RoundRobinLoadBalancer extends LoadBalancer {
ReadyPicker other = (ReadyPicker) picker;
// the lists cannot contain duplicate subchannels
return other == this
|| (list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list));
|| (subchannelPickers.size() == other.subchannelPickers.size() && new HashSet<>(
subchannelPickers).containsAll(other.subchannelPickers));
}
}

View File

@ -512,7 +512,7 @@ public class OutlierDetectionLoadBalancerTest {
loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers));
generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 8);
generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 12);
// Move forward in time to a point where the detection timer has fired.
forwardTime(config);
@ -546,7 +546,7 @@ public class OutlierDetectionLoadBalancerTest {
assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0)));
// Now we produce more load, but the subchannel start working and is no longer an outlier.
generateLoad(ImmutableMap.of(), 8);
generateLoad(ImmutableMap.of(), 12);
// Move forward in time to a point where the detection timer has fired.
fakeClock.forwardTime(config.maxEjectionTimeNanos + 1, TimeUnit.NANOSECONDS);

View File

@ -22,26 +22,23 @@ import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.util.RoundRobinLoadBalancer.STATE_INFO;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import io.grpc.Attributes;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
@ -53,18 +50,20 @@ import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.ResolvedAddresses;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.Status;
import io.grpc.internal.TestUtils;
import io.grpc.util.MultiChildLoadBalancer.ChildLbState;
import io.grpc.util.RoundRobinLoadBalancer.EmptyPicker;
import io.grpc.util.RoundRobinLoadBalancer.ReadyPicker;
import io.grpc.util.RoundRobinLoadBalancer.Ref;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
@ -75,10 +74,8 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
import org.mockito.stubbing.Answer;
/** Unit test for {@link RoundRobinLoadBalancer}. */
@RunWith(JUnit4.class)
@ -89,9 +86,8 @@ public class RoundRobinLoadBalancerTest {
private RoundRobinLoadBalancer loadBalancer;
private final List<EquivalentAddressGroup> servers = Lists.newArrayList();
private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels = Maps.newLinkedHashMap();
private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners =
Maps.newLinkedHashMap();
private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels =
new ConcurrentHashMap<>();
private final Attributes affinity =
Attributes.newBuilder().set(MAJOR_KEY, "I got the keys").build();
@ -101,8 +97,8 @@ public class RoundRobinLoadBalancerTest {
private ArgumentCaptor<ConnectivityState> stateCaptor;
@Captor
private ArgumentCaptor<CreateSubchannelArgs> createArgsCaptor;
@Mock
private Helper mockHelper;
private TestHelper testHelperInst = new TestHelper();
private Helper mockHelper = mock(Helper.class, delegatesTo(testHelperInst));
@Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown().
private PickSubchannelArgs mockArgs;
@ -113,34 +109,16 @@ public class RoundRobinLoadBalancerTest {
SocketAddress addr = new FakeSocketAddress("server" + i);
EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
servers.add(eag);
Subchannel sc = mock(Subchannel.class);
subchannels.put(Arrays.asList(eag), sc);
}
when(mockHelper.createSubchannel(any(CreateSubchannelArgs.class)))
.then(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0];
final Subchannel subchannel = subchannels.get(args.getAddresses());
when(subchannel.getAllAddresses()).thenReturn(args.getAddresses());
when(subchannel.getAttributes()).thenReturn(args.getAttributes());
doAnswer(
new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
subchannelStateListeners.put(
subchannel, (SubchannelStateListener) invocation.getArguments()[0]);
return null;
}
}).when(subchannel).start(any(SubchannelStateListener.class));
return subchannel;
}
});
loadBalancer = new RoundRobinLoadBalancer(mockHelper);
}
private boolean acceptAddresses(List<EquivalentAddressGroup> eagList, Attributes attrs) {
return loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(eagList).setAttributes(attrs).build());
}
@After
public void tearDown() throws Exception {
verifyNoMoreInteractions(mockArgs);
@ -148,10 +126,9 @@ public class RoundRobinLoadBalancerTest {
@Test
public void pickAfterResolved() throws Exception {
final Subchannel readySubchannel = subchannels.values().iterator().next();
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build());
boolean addressesAccepted = acceptAddresses(servers, affinity);
assertThat(addressesAccepted).isTrue();
final Subchannel readySubchannel = subchannels.values().iterator().next();
deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture());
@ -178,10 +155,6 @@ public class RoundRobinLoadBalancerTest {
@Test
public void pickAfterResolvedUpdatedHosts() throws Exception {
Subchannel removedSubchannel = mock(Subchannel.class);
Subchannel oldSubchannel = mock(Subchannel.class);
Subchannel newSubchannel = mock(Subchannel.class);
Attributes.Key<String> key = Attributes.Key.create("check-that-it-is-propagated");
FakeSocketAddress removedAddr = new FakeSocketAddress("removed");
EquivalentAddressGroup removedEag = new EquivalentAddressGroup(removedAddr);
@ -193,6 +166,13 @@ public class RoundRobinLoadBalancerTest {
EquivalentAddressGroup newEag = new EquivalentAddressGroup(
newAddr, Attributes.newBuilder().set(key, "newattr").build());
Subchannel removedSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder()
.setAddresses(removedEag).build());
Subchannel oldSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder()
.setAddresses(oldEag1).build());
Subchannel newSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder()
.setAddresses(newEag).build());
subchannels.put(Collections.singletonList(removedEag), removedSubchannel);
subchannels.put(Collections.singletonList(oldEag1), oldSubchannel);
subchannels.put(Collections.singletonList(newEag), newSubchannel);
@ -201,9 +181,7 @@ public class RoundRobinLoadBalancerTest {
InOrder inOrder = inOrder(mockHelper);
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity)
.build());
boolean addressesAccepted = acceptAddresses(currentServers, affinity);
assertThat(addressesAccepted).isTrue();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
@ -218,8 +196,11 @@ public class RoundRobinLoadBalancerTest {
verify(removedSubchannel, times(1)).requestConnection();
verify(oldSubchannel, times(1)).requestConnection();
assertThat(loadBalancer.getSubchannels()).containsExactly(removedSubchannel,
oldSubchannel);
assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2);
assertThat(loadBalancer.getChildLbStateEag(removedEag).getCurrentPicker().pickSubchannel(null)
.getSubchannel()).isEqualTo(removedSubchannel);
assertThat(loadBalancer.getChildLbStateEag(oldEag1).getCurrentPicker().pickSubchannel(null)
.getSubchannel()).isEqualTo(oldSubchannel);
// This time with Attributes
List<EquivalentAddressGroup> latestServers = Lists.newArrayList(oldEag2, newEag);
@ -232,13 +213,15 @@ public class RoundRobinLoadBalancerTest {
verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2));
verify(removedSubchannel, times(1)).shutdown();
deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(SHUTDOWN));
deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY));
assertThat(loadBalancer.getSubchannels()).containsExactly(oldSubchannel,
newSubchannel);
assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2);
assertThat(loadBalancer.getChildLbStateEag(newEag).getCurrentPicker()
.pickSubchannel(null).getSubchannel()).isEqualTo(newSubchannel);
assertThat(loadBalancer.getChildLbStateEag(oldEag2).getCurrentPicker()
.pickSubchannel(null).getSubchannel()).isEqualTo(oldSubchannel);
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verify(mockHelper, times(6)).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture());
picker = pickerCaptor.getValue();
@ -250,29 +233,26 @@ public class RoundRobinLoadBalancerTest {
@Test
public void pickAfterStateChange() throws Exception {
InOrder inOrder = inOrder(mockHelper);
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
.build());
boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY);
assertThat(addressesAccepted).isTrue();
Subchannel subchannel = loadBalancer.getSubchannels().iterator().next();
Ref<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
STATE_INFO);
// TODO figure out if this method testing the right things
ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next();
Subchannel subchannel = childLbState.getCurrentPicker().pickSubchannel(null).getSubchannel();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE));
assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING);
deliverSubchannelState(subchannel,
ConnectivityStateInfo.forNonError(READY));
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class);
assertThat(subchannelStateInfo.value).isEqualTo(
ConnectivityStateInfo.forNonError(READY));
assertThat(childLbState.getCurrentState()).isEqualTo(READY);
Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯");
deliverSubchannelState(subchannel,
ConnectivityStateInfo.forTransientFailure(error));
assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE);
assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error);
assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
inOrder.verify(mockHelper).refreshNameResolution();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class);
@ -280,8 +260,7 @@ public class RoundRobinLoadBalancerTest {
deliverSubchannelState(subchannel,
ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(mockHelper).refreshNameResolution();
assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE);
assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error);
assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
verify(subchannel, times(2)).requestConnection();
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
@ -291,15 +270,14 @@ public class RoundRobinLoadBalancerTest {
@Test
public void ignoreShutdownSubchannelStateChange() {
InOrder inOrder = inOrder(mockHelper);
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
.build());
boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY);
assertThat(addressesAccepted).isTrue();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
loadBalancer.shutdown();
for (Subchannel sc : loadBalancer.getSubchannels()) {
verify(sc).shutdown();
for (ChildLbState child : loadBalancer.getChildLbStates()) {
Subchannel sc = child.getCurrentPicker().pickSubchannel(null).getSubchannel();
verify(child).shutdown();
// When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered
// back to the subchannel state listener.
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(SHUTDOWN));
@ -311,36 +289,34 @@ public class RoundRobinLoadBalancerTest {
@Test
public void stayTransientFailureUntilReady() {
InOrder inOrder = inOrder(mockHelper);
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
.build());
boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY);
assertThat(addressesAccepted).isTrue();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
Map<ChildLbState, Subchannel> childToSubChannelMap = new HashMap<>();
// Simulate state transitions for each subchannel individually.
for (Subchannel sc : loadBalancer.getSubchannels()) {
for ( ChildLbState child : loadBalancer.getChildLbStates()) {
Subchannel sc = child.getSubchannels(mockArgs);
childToSubChannelMap.put(child, sc);
Status error = Status.UNKNOWN.withDescription("connection broken");
deliverSubchannelState(
sc,
ConnectivityStateInfo.forTransientFailure(error));
assertEquals(TRANSIENT_FAILURE, child.getCurrentState());
inOrder.verify(mockHelper).refreshNameResolution();
deliverSubchannelState(
sc,
ConnectivityStateInfo.forNonError(CONNECTING));
Ref<ConnectivityStateInfo> scStateInfo = sc.getAttributes().get(
STATE_INFO);
assertThat(scStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE);
assertThat(scStateInfo.value.getStatus()).isEqualTo(error);
assertEquals(TRANSIENT_FAILURE, child.getCurrentState());
}
inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(EmptyPicker.class));
inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(ReadyPicker.class));
inOrder.verifyNoMoreInteractions();
Subchannel subchannel = loadBalancer.getSubchannels().iterator().next();
ChildLbState child = loadBalancer.getChildLbStates().iterator().next();
Subchannel subchannel = childToSubChannelMap.get(child);
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
Ref<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
STATE_INFO);
assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(READY));
assertThat(child.getCurrentState()).isEqualTo(READY);
inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class));
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
@ -350,16 +326,15 @@ public class RoundRobinLoadBalancerTest {
@Test
public void refreshNameResolutionWhenSubchannelConnectionBroken() {
InOrder inOrder = inOrder(mockHelper);
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
.build());
boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY);
assertThat(addressesAccepted).isTrue();
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
// Simulate state transitions for each subchannel individually.
for (Subchannel sc : loadBalancer.getSubchannels()) {
for (ChildLbState child : loadBalancer.getChildLbStates()) {
Subchannel sc = child.getSubchannels(mockArgs);
verify(sc).requestConnection();
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING));
Status error = Status.UNKNOWN.withDescription("connection broken");
@ -383,11 +358,12 @@ public class RoundRobinLoadBalancerTest {
Subchannel subchannel1 = mock(Subchannel.class);
Subchannel subchannel2 = mock(Subchannel.class);
ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList(
Lists.newArrayList(subchannel, subchannel1, subchannel2)),
0 /* startIndex */);
ArrayList<SubchannelPicker> pickers = Lists.newArrayList(
TestUtils.pickerOf(subchannel), TestUtils.pickerOf(subchannel1),
TestUtils.pickerOf(subchannel2));
assertThat(picker.getList()).containsExactly(subchannel, subchannel1, subchannel2);
ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList(pickers),
0 /* startIndex */);
assertEquals(subchannel, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(subchannel1, picker.pickSubchannel(mockArgs).getSubchannel());
@ -399,7 +375,7 @@ public class RoundRobinLoadBalancerTest {
public void pickerEmptyList() throws Exception {
SubchannelPicker picker = new EmptyPicker(Status.UNKNOWN);
assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel());
assertNull(picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(Status.UNKNOWN,
picker.pickSubchannel(mockArgs).getStatus());
}
@ -417,12 +393,13 @@ public class RoundRobinLoadBalancerTest {
@Test
public void nameResolutionErrorWithActiveChannels() throws Exception {
boolean addressesAccepted = acceptAddresses(servers, affinity);
final Subchannel readySubchannel = subchannels.values().iterator().next();
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build());
assertThat(addressesAccepted).isTrue();
deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
loadBalancer.resolvingAddresses = true;
loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError"));
loadBalancer.resolvingAddresses = false;
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verify(mockHelper, times(2))
@ -443,15 +420,14 @@ public class RoundRobinLoadBalancerTest {
@Test
public void subchannelStateIsolation() throws Exception {
boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY);
assertThat(addressesAccepted).isTrue();
Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
Subchannel sc3 = subchannelIterator.next();
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
.build());
assertThat(addressesAccepted).isTrue();
verify(sc1, times(1)).requestConnection();
verify(sc2, times(1)).requestConnection();
verify(sc3, times(1)).requestConnection();
@ -491,7 +467,7 @@ public class RoundRobinLoadBalancerTest {
public void readyPicker_emptyList() {
// ready picker list must be non-empty
try {
new ReadyPicker(Collections.<Subchannel>emptyList(), 0);
new ReadyPicker(Collections.emptyList(), 0);
fail();
} catch (IllegalArgumentException expected) {
}
@ -503,9 +479,10 @@ public class RoundRobinLoadBalancerTest {
EmptyPicker emptyOk2 = new EmptyPicker(Status.OK.withDescription("different OK"));
EmptyPicker emptyErr = new EmptyPicker(Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"));
acceptAddresses(servers, Attributes.EMPTY); // create subchannels
Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
SubchannelPicker sc1 = TestUtils.pickerOf(subchannelIterator.next());
SubchannelPicker sc2 = TestUtils.pickerOf(subchannelIterator.next());
ReadyPicker ready1 = new ReadyPicker(Arrays.asList(sc1, sc2), 0);
ReadyPicker ready2 = new ReadyPicker(Arrays.asList(sc1), 0);
ReadyPicker ready3 = new ReadyPicker(Arrays.asList(sc2, sc1), 1);
@ -526,18 +503,26 @@ public class RoundRobinLoadBalancerTest {
public void emptyAddresses() {
assertThat(loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder()
.setAddresses(Collections.<EquivalentAddressGroup>emptyList())
.setAddresses(Collections.emptyList())
.setAttributes(affinity)
.build())).isFalse();
}
private static List<Subchannel> getList(SubchannelPicker picker) {
return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() :
Collections.<Subchannel>emptyList();
private List<Subchannel> getList(SubchannelPicker picker) {
if (picker instanceof ReadyPicker) {
List<Subchannel> subchannelList = new ArrayList<>();
for (SubchannelPicker childPicker : ((ReadyPicker) picker).getSubchannelPickers()) {
subchannelList.add(childPicker.pickSubchannel(mockArgs).getSubchannel());
}
return subchannelList;
} else {
return new ArrayList<>();
}
}
private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) {
subchannelStateListeners.get(subchannel).onSubchannelState(newState);
testHelperInst.deliverSubchannelState(subchannel, newState);
}
private static class FakeSocketAddress extends SocketAddress {
@ -552,4 +537,12 @@ public class RoundRobinLoadBalancerTest {
return "FakeSocketAddress-" + name;
}
}
private class TestHelper extends AbstractTestHelper {
@Override
public Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap() {
return subchannels;
}
}
}

View File

@ -0,0 +1,196 @@
/*
* Copyright 2023 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.util;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.Mockito.mock;
import com.google.common.collect.Maps;
import io.grpc.Attributes;
import io.grpc.Channel;
import io.grpc.ChannelLogger;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* A real class that can be used as a delegate of a mock Helper to provide more real representation
* and track the subchannels as is needed with petiole policies where the subchannels are no
* longer direct children of the loadbalancer.
* <br>
* To use it replace <br>
* \@mock Helper mockHelper<br>
* with<br>
* <p>Helper mockHelper = mock(Helper.class, delegatesTo(new TestHelper()));</p>
* <br>
* TestHelper will need to define accessors for the maps that information is store within as
* those maps need to be defined in the Test class.
*/
public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper {
private final Map<Subchannel, Subchannel> mockToRealSubChannelMap = new HashMap<>();
private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners =
Maps.newLinkedHashMap();
public abstract Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap();
public Map<Subchannel, Subchannel> getMockToRealSubChannelMap() {
return mockToRealSubChannelMap;
}
public Subchannel getRealForMockSubChannel(Subchannel mock) {
Subchannel realSc = getMockToRealSubChannelMap().get(mock);
if (realSc == null) {
realSc = mock;
}
return realSc;
}
public Map<Subchannel, SubchannelStateListener> getSubchannelStateListeners() {
return subchannelStateListeners;
}
public void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) {
Subchannel realSc = getMockToRealSubChannelMap().get(subchannel);
if (realSc == null) {
realSc = subchannel;
}
SubchannelStateListener listener = getSubchannelStateListeners().get(realSc);
if (listener == null) {
throw new IllegalArgumentException("subchannel does not have a matching listener");
}
listener.onSubchannelState(newState);
}
@Override
public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
// do nothing, should have been done in the wrapper helpers
}
@Override
protected Helper delegate() {
throw new UnsupportedOperationException("This helper class is only for use in this test");
}
@Override
public Subchannel createSubchannel(CreateSubchannelArgs args) {
Subchannel subchannel = getSubchannelMap().get(args.getAddresses());
if (subchannel == null) {
TestSubchannel delegate = new TestSubchannel(args);
subchannel = mock(Subchannel.class, delegatesTo(delegate));
getSubchannelMap().put(args.getAddresses(), subchannel);
getMockToRealSubChannelMap().put(subchannel, delegate);
}
return subchannel;
}
@Override
public void refreshNameResolution() {
// no-op
}
public void setChannel(Subchannel subchannel, Channel channel) {
((TestSubchannel)subchannel).channel = channel;
}
@Override
public String toString() {
return "Test Helper";
}
private class TestSubchannel extends ForwardingSubchannel {
CreateSubchannelArgs args;
Channel channel;
public TestSubchannel(CreateSubchannelArgs args) {
this.args = args;
}
@Override
protected Subchannel delegate() {
throw new UnsupportedOperationException("Only to be used in tests");
}
@Override
public List<EquivalentAddressGroup> getAllAddresses() {
return args.getAddresses();
}
@Override
public Attributes getAttributes() {
return args.getAttributes();
}
@Override
public void requestConnection() {
// Ignore, we will manually update state
}
@Override
public void updateAddresses(List<EquivalentAddressGroup> addrs) {
if (args.getAddresses().equals(addrs)) {
return; // no changes so it's a no-op
}
List<EquivalentAddressGroup> oldAddrs = args.getAddresses();
Subchannel oldTarget = getSubchannelMap().get(oldAddrs);
this.args = args.toBuilder().setAddresses(addrs).build();
getSubchannelMap().put(addrs, oldTarget);
getSubchannelMap().remove(oldAddrs);
}
@Override
public void start(SubchannelStateListener listener) {
getSubchannelStateListeners().put(this, listener);
}
@Override
public void shutdown() {
getSubchannelStateListeners().remove(this);
for (EquivalentAddressGroup eag : getAllAddresses()) {
getSubchannelMap().remove(Collections.singletonList(eag));
}
}
@Override
public Channel asChannel() {
return channel;
}
@Override
public ChannelLogger getChannelLogger() {
return mock(ChannelLogger.class);
}
@Override
public String toString() {
return "Mock Subchannel" + args.toString();
}
}
}

View File

@ -58,7 +58,8 @@ dependencies {
def nettyDependency = implementation project(':grpc-netty')
testImplementation project(':grpc-rls')
testImplementation testFixtures(project(':grpc-core'))
testImplementation testFixtures(project(':grpc-core')),
testFixtures(project(':grpc-util'))
annotationProcessor libraries.auto.value
// At runtime use the epoll included in grpc-netty-shaded

View File

@ -16,36 +16,77 @@
package io.grpc.xds;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import io.grpc.InternalLogId;
import io.grpc.LoadBalancerProvider;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.SynchronizationContext.ScheduledHandle;
import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import io.grpc.util.MultiChildLoadBalancer;
import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig;
import io.grpc.xds.XdsLogger.XdsLogLevel;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
/**
* The top-level load balancing policy.
* The top-level load balancing policy for use in XDS.
* This policy does not immediately delete its children. Instead, it marks them deactivated
* and starts a timer for deletion. If a subsequent address update restores the child, then it is
* simply reactivated instead of built from scratch. This is necessary because XDS can frequently
* remove and then add back a server as machines are rebooted or repurposed for load management.
*
* <p>Note that this LB does not automatically reconnect children who go into IDLE status
*/
class ClusterManagerLoadBalancer extends MultiChildLoadBalancer {
// 15 minutes is long enough for a reboot and the services to restart while not so long that
// many children are waiting for cleanup.
@VisibleForTesting
public static final int DELAYED_CHILD_DELETION_TIME_MINUTES = 15;
protected final SynchronizationContext syncContext;
private final ScheduledExecutorService timeService;
private final XdsLogger logger;
ClusterManagerLoadBalancer(Helper helper) {
super(helper);
this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
logger = XdsLogger.withLogId(
InternalLogId.allocate("cluster_manager-lb", helper.getAuthority()));
logger.log(XdsLogLevel.INFO, "Created");
}
@Override
protected Map<Object, PolicySelection> getPolicySelectionMap(
ResolvedAddresses resolvedAddresses) {
protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses,
Object childConfig) {
return resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build();
}
@Override
protected Map<Object, ChildLbState> createChildLbMap(ResolvedAddresses resolvedAddresses) {
ClusterManagerConfig config = (ClusterManagerConfig)
resolvedAddresses.getLoadBalancingPolicyConfig();
Map<Object, PolicySelection> newChildPolicies = new HashMap<>(config.childPolicies);
Map<Object, ChildLbState> newChildPolicies = new HashMap<>();
if (config != null) {
for (Entry<String, PolicySelection> entry : config.childPolicies.entrySet()) {
ChildLbState child = getChildLbState(entry.getKey());
if (child == null) {
child = new ClusterManagerLbState(entry.getKey(),
entry.getValue().getProvider(), entry.getValue().getConfig(), getInitialPicker());
}
newChildPolicies.put(entry.getKey(), child);
}
}
logger.log(
XdsLogLevel.INFO,
"Received cluster_manager lb config: child names={0}", newChildPolicies.keySet());
@ -75,4 +116,84 @@ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer {
}
};
}
@Override
public void handleNameResolutionError(Status error) {
logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error);
boolean gotoTransientFailure = true;
for (ChildLbState state : getChildLbStates()) {
if (!state.isDeactivated()) {
gotoTransientFailure = false;
handleNameResolutionError(state, error);
}
}
if (gotoTransientFailure) {
getHelper().updateBalancingState(TRANSIENT_FAILURE, getErrorPicker(error));
}
}
@Override
protected boolean reconnectOnIdle() {
return false;
}
/**
* This differs from the base class in the use of the deletion timer. When it is deactivated,
* rather than immediately calling shutdown it starts a timer. If shutdown or reactivate
* are called before the timer fires, the timer is canceled. Otherwise, time timer calls shutdown
* and removes the child from the petiole policy when it is triggered.
*/
private class ClusterManagerLbState extends ChildLbState {
@Nullable
ScheduledHandle deletionTimer;
public ClusterManagerLbState(Object key, LoadBalancerProvider policyProvider,
Object childConfig, SubchannelPicker initialPicker) {
super(key, policyProvider, childConfig, initialPicker);
}
@Override
protected void shutdown() {
if (deletionTimer != null && deletionTimer.isPending()) {
deletionTimer.cancel();
}
super.shutdown();
}
@Override
protected void reactivate(LoadBalancerProvider policyProvider) {
if (deletionTimer != null && deletionTimer.isPending()) {
deletionTimer.cancel();
logger.log(XdsLogLevel.DEBUG, "Child balancer {0} reactivated", getKey());
}
super.reactivate(policyProvider);
}
@Override
protected void deactivate() {
if (isDeactivated()) {
return;
}
class DeletionTask implements Runnable {
@Override
public void run() {
shutdown();
removeChild(getKey());
}
}
deletionTimer =
syncContext.schedule(
new DeletionTask(),
DELAYED_CHILD_DELETION_TIME_MINUTES,
TimeUnit.MINUTES,
timeService);
setDeactivated();
logger.log(XdsLogLevel.DEBUG, "Child balancer {0} deactivated", getKey());
}
}
}

View File

@ -21,7 +21,6 @@ import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.xds.LeastRequestLoadBalancerProvider.DEFAULT_CHOICE_COUNT;
import static io.grpc.xds.LeastRequestLoadBalancerProvider.MAX_CHOICE_COUNT;
@ -35,20 +34,17 @@ import io.grpc.Attributes;
import io.grpc.ClientStreamTracer;
import io.grpc.ClientStreamTracer.StreamInfo;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.util.MultiChildLoadBalancer;
import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nonnull;
@ -60,21 +56,13 @@ import javax.annotation.Nonnull;
* The default sampling amount of two is also known as
* the "power of two choices" (P2C).
*/
final class LeastRequestLoadBalancer extends LoadBalancer {
@VisibleForTesting
static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO =
Attributes.Key.create("state-info");
@VisibleForTesting
static final Attributes.Key<AtomicInteger> IN_FLIGHTS =
Attributes.Key.create("in-flights");
final class LeastRequestLoadBalancer extends MultiChildLoadBalancer {
private static final Status EMPTY_OK = Status.OK.withDescription("no subchannels ready");
private static final EmptyPicker EMPTY_LR_PICKER = new EmptyPicker(EMPTY_OK);
private final Helper helper;
private final ThreadSafeRandom random;
private final Map<EquivalentAddressGroup, Subchannel> subchannels =
new HashMap<>();
private ConnectivityState currentState;
private LeastRequestPicker currentPicker = new EmptyPicker(EMPTY_OK);
private LeastRequestPicker currentPicker = EMPTY_LR_PICKER;
private int choiceCount = DEFAULT_CHOICE_COUNT;
LeastRequestLoadBalancer(Helper helper) {
@ -83,255 +71,167 @@ final class LeastRequestLoadBalancer extends LoadBalancer {
@VisibleForTesting
LeastRequestLoadBalancer(Helper helper, ThreadSafeRandom random) {
this.helper = checkNotNull(helper, "helper");
super(helper);
this.random = checkNotNull(random, "random");
}
@Override
protected SubchannelPicker getSubchannelPicker(Map<Object, SubchannelPicker> childPickers) {
throw new UnsupportedOperationException(
"LeastRequestLoadBalancer uses its ChildLbStates, not these child pickers directly");
}
@Override
public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
if (resolvedAddresses.getAddresses().isEmpty()) {
handleNameResolutionError(Status.UNAVAILABLE.withDescription(
"NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses()
+ ", attrs=" + resolvedAddresses.getAttributes()));
return false;
}
// Need to update choiceCount before calling super so that the updateBalancingState call has the
// new value. However, if the update fails we need to revert it.
int oldChoiceCount = choiceCount;
LeastRequestConfig config =
(LeastRequestConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
// Config may be null if least_request is used outside xDS
if (config != null) {
choiceCount = config.choiceCount;
}
List<EquivalentAddressGroup> servers = resolvedAddresses.getAddresses();
Set<EquivalentAddressGroup> currentAddrs = subchannels.keySet();
Map<EquivalentAddressGroup, EquivalentAddressGroup> latestAddrs = stripAttrs(servers);
Set<EquivalentAddressGroup> removedAddrs = setsDifference(currentAddrs, latestAddrs.keySet());
boolean successfulUpdate = super.acceptResolvedAddresses(resolvedAddresses);
for (Map.Entry<EquivalentAddressGroup, EquivalentAddressGroup> latestEntry :
latestAddrs.entrySet()) {
EquivalentAddressGroup strippedAddressGroup = latestEntry.getKey();
EquivalentAddressGroup originalAddressGroup = latestEntry.getValue();
Subchannel existingSubchannel = subchannels.get(strippedAddressGroup);
if (existingSubchannel != null) {
// EAG's Attributes may have changed.
existingSubchannel.updateAddresses(Collections.singletonList(originalAddressGroup));
continue;
}
// Create new subchannels for new addresses.
Attributes.Builder subchannelAttrs = Attributes.newBuilder()
.set(STATE_INFO, new Ref<>(ConnectivityStateInfo.forNonError(IDLE)))
// Used to track the in flight requests on this particular subchannel
.set(IN_FLIGHTS, new AtomicInteger(0));
final Subchannel subchannel = checkNotNull(
helper.createSubchannel(CreateSubchannelArgs.newBuilder()
.setAddresses(originalAddressGroup)
.setAttributes(subchannelAttrs.build())
.build()),
"subchannel");
subchannel.start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo state) {
processSubchannelState(subchannel, state);
}
});
subchannels.put(strippedAddressGroup, subchannel);
subchannel.requestConnection();
if (!successfulUpdate) {
choiceCount = oldChoiceCount;
}
ArrayList<Subchannel> removedSubchannels = new ArrayList<>();
for (EquivalentAddressGroup addressGroup : removedAddrs) {
removedSubchannels.add(subchannels.remove(addressGroup));
}
// Update the picker before shutting down the subchannels, to reduce the chance of the race
// between picking a subchannel and shutting it down.
updateBalancingState();
// Shutdown removed subchannels
for (Subchannel removedSubchannel : removedSubchannels) {
shutdownSubchannel(removedSubchannel);
}
return true;
return successfulUpdate;
}
@Override
public void handleNameResolutionError(Status error) {
if (currentState != READY) {
updateBalancingState(TRANSIENT_FAILURE, new EmptyPicker(error));
}
protected SubchannelPicker getErrorPicker(Status error) {
return new EmptyPicker(error);
}
private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) {
return;
}
if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) {
helper.refreshNameResolution();
}
if (stateInfo.getState() == IDLE) {
subchannel.requestConnection();
}
Ref<ConnectivityStateInfo> subchannelStateRef = getSubchannelStateInfoRef(subchannel);
if (subchannelStateRef.value.getState().equals(TRANSIENT_FAILURE)) {
if (stateInfo.getState().equals(CONNECTING) || stateInfo.getState().equals(IDLE)) {
return;
}
}
subchannelStateRef.value = stateInfo;
updateBalancingState();
}
private void shutdownSubchannel(Subchannel subchannel) {
subchannel.shutdown();
getSubchannelStateInfoRef(subchannel).value =
ConnectivityStateInfo.forNonError(SHUTDOWN);
}
@Override
public void shutdown() {
for (Subchannel subchannel : getSubchannels()) {
shutdownSubchannel(subchannel);
}
subchannels.clear();
}
private static final Status EMPTY_OK = Status.OK.withDescription("no subchannels ready");
/**
* Updates picker with the list of active subchannels (state == READY).
*
* <p>
* If no active subchannels exist, but some are in TRANSIENT_FAILURE then returns a picker
* with all of the children in TF so that the application code will get an error from a varying
* random one when it tries to get a subchannel.
* </p>
*/
@SuppressWarnings("ReferenceEquality")
private void updateBalancingState() {
List<Subchannel> activeList = filterNonFailingSubchannels(getSubchannels());
@Override
protected void updateOverallBalancingState() {
List<ChildLbState> activeList = getReadyChildren();
if (activeList.isEmpty()) {
// No READY subchannels, determine aggregate state and error status
boolean isConnecting = false;
Status aggStatus = EMPTY_OK;
for (Subchannel subchannel : getSubchannels()) {
ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).value;
// This subchannel IDLE is not because of channel IDLE_TIMEOUT,
// in which case LB is already shutdown.
// LRLB will request connection immediately on subchannel IDLE.
if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) {
List<ChildLbState> childrenInTf = new ArrayList<>();
for (ChildLbState childLbState : getChildLbStates()) {
ConnectivityState state = childLbState.getCurrentState();
if (state == CONNECTING || state == IDLE) {
isConnecting = true;
}
if (aggStatus == EMPTY_OK || !aggStatus.isOk()) {
aggStatus = stateInfo.getStatus();
} else if (state == TRANSIENT_FAILURE) {
childrenInTf.add(childLbState);
}
}
updateBalancingState(isConnecting ? CONNECTING : TRANSIENT_FAILURE,
// If all subchannels are TRANSIENT_FAILURE, return the Status associated with
// an arbitrary subchannel, otherwise return OK.
new EmptyPicker(aggStatus));
if (isConnecting) {
updateBalancingState(CONNECTING, EMPTY_LR_PICKER);
} else {
// Give it all the failing children and let it randomly pick among them
updateBalancingState(TRANSIENT_FAILURE,
new ReadyPicker(childrenInTf, choiceCount, random));
}
} else {
updateBalancingState(READY, new ReadyPicker(activeList, choiceCount, random));
}
}
@Override
protected ChildLbState createChildLbState(Object key, Object policyConfig,
SubchannelPicker initialPicker) {
return new LeastRequestLbState(key, pickFirstLbProvider, policyConfig, initialPicker);
}
private void updateBalancingState(ConnectivityState state, LeastRequestPicker picker) {
if (state != currentState || !picker.isEquivalentTo(currentPicker)) {
helper.updateBalancingState(state, picker);
currentState = state;
if (state != currentConnectivityState || !picker.isEquivalentTo(currentPicker)) {
super.updateHelperBalancingState(state, picker);
currentConnectivityState = state;
currentPicker = picker;
}
}
/**
* Filters out non-ready subchannels.
* This should ONLY be used by tests.
*/
private static List<Subchannel> filterNonFailingSubchannels(
Collection<Subchannel> subchannels) {
List<Subchannel> readySubchannels = new ArrayList<>(subchannels.size());
for (Subchannel subchannel : subchannels) {
if (isReady(subchannel)) {
readySubchannels.add(subchannel);
}
}
return readySubchannels;
@VisibleForTesting
void setResolvingAddresses(boolean newValue) {
super.resolvingAddresses = newValue;
}
/**
* Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and
* remove all attributes. The values are the original EAGs.
*/
private static Map<EquivalentAddressGroup, EquivalentAddressGroup> stripAttrs(
List<EquivalentAddressGroup> groupList) {
Map<EquivalentAddressGroup, EquivalentAddressGroup> addrs = new HashMap<>(groupList.size() * 2);
for (EquivalentAddressGroup group : groupList) {
addrs.put(stripAttrs(group), group);
}
return addrs;
// Expose for tests in this package.
@Override
protected Collection<ChildLbState> getChildLbStates() {
return super.getChildLbStates();
}
private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
return new EquivalentAddressGroup(eag.getAddresses());
// Expose for tests in this package.
@Override
protected ChildLbState getChildLbState(Object key) {
return super.getChildLbState(key);
}
// Expose for tests in this package.
private static AtomicInteger getInFlights(ChildLbState childLbState) {
return ((LeastRequestLbState)childLbState).activeRequests;
}
@VisibleForTesting
Collection<Subchannel> getSubchannels() {
return subchannels.values();
}
private static Ref<ConnectivityStateInfo> getSubchannelStateInfoRef(
Subchannel subchannel) {
return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO");
}
private static AtomicInteger getInFlights(Subchannel subchannel) {
return checkNotNull(subchannel.getAttributes().get(IN_FLIGHTS), "IN_FLIGHTS");
}
// package-private to avoid synthetic access
static boolean isReady(Subchannel subchannel) {
return getSubchannelStateInfoRef(subchannel).value.getState() == READY;
}
private static <T> Set<T> setsDifference(Set<T> a, Set<T> b) {
Set<T> aCopy = new HashSet<>(a);
aCopy.removeAll(b);
return aCopy;
}
// Only subclasses are ReadyPicker or EmptyPicker
private abstract static class LeastRequestPicker extends SubchannelPicker {
abstract static class LeastRequestPicker extends SubchannelPicker {
abstract boolean isEquivalentTo(LeastRequestPicker picker);
}
@VisibleForTesting
static final class ReadyPicker extends LeastRequestPicker {
private final List<Subchannel> list; // non-empty
private final List<ChildLbState> childLbStates; // non-empty
private final int choiceCount;
private final ThreadSafeRandom random;
ReadyPicker(List<Subchannel> list, int choiceCount, ThreadSafeRandom random) {
checkArgument(!list.isEmpty(), "empty list");
this.list = list;
ReadyPicker(List<ChildLbState> childLbStates, int choiceCount, ThreadSafeRandom random) {
checkArgument(!childLbStates.isEmpty(), "empty list");
this.childLbStates = childLbStates;
this.choiceCount = choiceCount;
this.random = checkNotNull(random, "random");
}
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
final Subchannel subchannel = nextSubchannel();
final OutstandingRequestsTracingFactory factory =
new OutstandingRequestsTracingFactory(getInFlights(subchannel));
return PickResult.withSubchannel(subchannel, factory);
final ChildLbState childLbState = nextChildToUse();
PickResult childResult = childLbState.getCurrentPicker().pickSubchannel(args);
if (!childResult.getStatus().isOk() || childResult.getSubchannel() == null) {
return childResult;
}
if (childResult.getStreamTracerFactory() != null) {
// Already wrapped, so just use the current picker for selected child
return childResult;
} else {
// Wrap the subchannel
OutstandingRequestsTracingFactory factory =
new OutstandingRequestsTracingFactory(getInFlights(childLbState));
return PickResult.withSubchannel(childResult.getSubchannel(), factory);
}
}
@Override
public String toString() {
return MoreObjects.toStringHelper(ReadyPicker.class)
.add("list", list)
.add("list", childLbStates)
.add("choiceCount", choiceCount)
.toString();
}
private Subchannel nextSubchannel() {
Subchannel candidate = list.get(random.nextInt(list.size()));
private ChildLbState nextChildToUse() {
ChildLbState candidate = childLbStates.get(random.nextInt(childLbStates.size()));
for (int i = 0; i < choiceCount - 1; ++i) {
Subchannel sampled = list.get(random.nextInt(list.size()));
ChildLbState sampled = childLbStates.get(random.nextInt(childLbStates.size()));
if (getInFlights(sampled).get() < getInFlights(candidate).get()) {
candidate = sampled;
}
@ -340,10 +240,11 @@ final class LeastRequestLoadBalancer extends LoadBalancer {
}
@VisibleForTesting
List<Subchannel> getList() {
return list;
List<ChildLbState> getChildLbStates() {
return childLbStates;
}
@VisibleForTesting
@Override
boolean isEquivalentTo(LeastRequestPicker picker) {
if (!(picker instanceof ReadyPicker)) {
@ -352,7 +253,8 @@ final class LeastRequestLoadBalancer extends LoadBalancer {
ReadyPicker other = (ReadyPicker) picker;
// the lists cannot contain duplicate subchannels
return other == this
|| ((list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list))
|| ((childLbStates.size() == other.childLbStates.size() && new HashSet<>(
childLbStates).containsAll(other.childLbStates))
&& choiceCount == other.choiceCount);
}
}
@ -381,16 +283,10 @@ final class LeastRequestLoadBalancer extends LoadBalancer {
public String toString() {
return MoreObjects.toStringHelper(EmptyPicker.class).add("status", status).toString();
}
}
/**
* A lighter weight Reference than AtomicReference.
*/
static final class Ref<T> {
T value;
Ref(T value) {
this.value = value;
@VisibleForTesting
Status getStatus() {
return status;
}
}
@ -435,4 +331,17 @@ final class LeastRequestLoadBalancer extends LoadBalancer {
.toString();
}
}
protected class LeastRequestLbState extends ChildLbState {
private final AtomicInteger activeRequests = new AtomicInteger(0);
public LeastRequestLbState(Object key, LoadBalancerProvider policyProvider,
Object childConfig, SubchannelPicker initialPicker) {
super(key, policyProvider, childConfig, initialPicker);
}
int getActiveRequests() {
return activeRequests.get();
}
}
}

View File

@ -17,17 +17,20 @@
package io.grpc.xds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkElementIndex;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.Deadline.Ticker;
import io.grpc.EquivalentAddressGroup;
import io.grpc.ExperimentalApi;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider;
import io.grpc.NameResolver;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
@ -40,11 +43,13 @@ import io.grpc.xds.orca.OrcaOobUtil;
import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
import io.grpc.xds.orca.OrcaPerRequestUtil;
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
@ -90,6 +95,14 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, random);
}
@Override
protected ChildLbState createChildLbState(Object key, Object policyConfig,
SubchannelPicker initialPicker) {
ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig,
initialPicker);
return childLbState;
}
@Override
public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) {
@ -111,90 +124,30 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
@Override
public RoundRobinPicker createReadyPicker(List<Subchannel> activeList) {
return new WeightedRoundRobinPicker(activeList, config.enableOobLoadReport,
config.errorUtilizationPenalty);
}
private final class UpdateWeightTask implements Runnable {
@Override
public void run() {
if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
((WeightedRoundRobinPicker) currentPicker).updateWeight();
}
weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
TimeUnit.NANOSECONDS, timeService);
}
}
private void afterAcceptAddresses() {
for (Subchannel subchannel : getSubchannels()) {
WrrSubchannel weightedSubchannel = (WrrSubchannel) subchannel;
if (config.enableOobLoadReport) {
OrcaOobUtil.setListener(weightedSubchannel,
weightedSubchannel.new OrcaReportListener(config.errorUtilizationPenalty),
OrcaOobUtil.OrcaReportingConfig.newBuilder()
.setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS)
.build());
} else {
OrcaOobUtil.setListener(weightedSubchannel, null, null);
}
}
public RoundRobinPicker createReadyPicker(Collection<ChildLbState> activeList) {
return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
config.enableOobLoadReport, config.errorUtilizationPenalty);
}
// Expose for tests in this package.
@Override
public void shutdown() {
if (weightUpdateTimer != null) {
weightUpdateTimer.cancel();
}
super.shutdown();
}
private static final class WrrHelper extends ForwardingLoadBalancerHelper {
private final Helper delegate;
private WeightedRoundRobinLoadBalancer wrr;
WrrHelper(Helper helper) {
this.delegate = helper;
}
void setLoadBalancer(WeightedRoundRobinLoadBalancer lb) {
this.wrr = lb;
}
@Override
protected Helper delegate() {
return delegate;
}
@Override
public Subchannel createSubchannel(CreateSubchannelArgs args) {
return wrr.new WrrSubchannel(delegate().createSubchannel(args));
}
protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) {
return super.getChildLbStateEag(eag);
}
@VisibleForTesting
final class WrrSubchannel extends ForwardingSubchannel {
private final Subchannel delegate;
final class WeightedChildLbState extends ChildLbState {
private final Set<WrrSubchannel> subchannels = new HashSet<>();
private volatile long lastUpdated;
private volatile long nonEmptySince;
private volatile double weight;
private volatile double weight = 0;
WrrSubchannel(Subchannel delegate) {
this.delegate = checkNotNull(delegate, "delegate");
}
private OrcaReportListener orcaReportListener;
@Override
public void start(SubchannelStateListener listener) {
delegate().start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo newState) {
if (newState.getState().equals(ConnectivityState.READY)) {
nonEmptySince = infTime;
}
listener.onSubchannelState(newState);
}
});
public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig,
SubchannelPicker initialPicker) {
super(key, policyProvider, childConfig, initialPicker);
}
private double getWeight() {
@ -213,9 +166,21 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
}
@Override
protected Subchannel delegate() {
return delegate;
public void addSubchannel(WrrSubchannel wrrSubchannel) {
subchannels.add(wrrSubchannel);
}
public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty) {
if (orcaReportListener != null
&& orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty) {
return orcaReportListener;
}
orcaReportListener = new OrcaReportListener(errorUtilizationPenalty);
return orcaReportListener;
}
public void removeSubchannel(WrrSubchannel wrrSubchannel) {
subchannels.remove(wrrSubchannel);
}
final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener {
@ -251,23 +216,124 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
}
private final class UpdateWeightTask implements Runnable {
@Override
public void run() {
if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
((WeightedRoundRobinPicker) currentPicker).updateWeight();
}
weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
TimeUnit.NANOSECONDS, timeService);
}
}
private void afterAcceptAddresses() {
for (ChildLbState child : getChildLbStates()) {
WeightedChildLbState wChild = (WeightedChildLbState) child;
for (WrrSubchannel weightedSubchannel : wChild.subchannels) {
if (config.enableOobLoadReport) {
OrcaOobUtil.setListener(weightedSubchannel,
wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty),
OrcaOobUtil.OrcaReportingConfig.newBuilder()
.setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS)
.build());
} else {
OrcaOobUtil.setListener(weightedSubchannel, null, null);
}
}
}
}
@Override
public void shutdown() {
if (weightUpdateTimer != null) {
weightUpdateTimer.cancel();
}
super.shutdown();
}
private static final class WrrHelper extends ForwardingLoadBalancerHelper {
private final Helper delegate;
private WeightedRoundRobinLoadBalancer wrr;
WrrHelper(Helper helper) {
this.delegate = helper;
}
void setLoadBalancer(WeightedRoundRobinLoadBalancer lb) {
this.wrr = lb;
}
@Override
protected Helper delegate() {
return delegate;
}
@Override
public Subchannel createSubchannel(CreateSubchannelArgs args) {
checkElementIndex(0, args.getAddresses().size(), "Empty address group");
WeightedChildLbState childLbState =
(WeightedChildLbState) wrr.getChildLbStateEag(args.getAddresses().get(0));
return wrr.new WrrSubchannel(delegate().createSubchannel(args), childLbState);
}
}
@VisibleForTesting
final class WrrSubchannel extends ForwardingSubchannel {
private final Subchannel delegate;
private final WeightedChildLbState owner;
WrrSubchannel(Subchannel delegate, WeightedChildLbState owner) {
this.delegate = checkNotNull(delegate, "delegate");
this.owner = checkNotNull(owner, "owner");
}
@Override
public void start(SubchannelStateListener listener) {
owner.addSubchannel(this);
delegate().start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo newState) {
if (newState.getState().equals(ConnectivityState.READY)) {
owner.nonEmptySince = infTime;
}
listener.onSubchannelState(newState);
}
});
}
@Override
protected Subchannel delegate() {
return delegate;
}
@Override
public void shutdown() {
super.shutdown();
owner.removeSubchannel(this);
}
}
@VisibleForTesting
final class WeightedRoundRobinPicker extends RoundRobinPicker {
private final List<Subchannel> list;
private final List<ChildLbState> children;
private final Map<Subchannel, OrcaPerRequestReportListener> subchannelToReportListenerMap =
new HashMap<>();
private final boolean enableOobLoadReport;
private final float errorUtilizationPenalty;
private volatile StaticStrideScheduler scheduler;
WeightedRoundRobinPicker(List<Subchannel> list, boolean enableOobLoadReport,
WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
float errorUtilizationPenalty) {
checkNotNull(list, "list");
Preconditions.checkArgument(!list.isEmpty(), "empty list");
this.list = list;
for (Subchannel subchannel : list) {
this.subchannelToReportListenerMap.put(subchannel,
((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty));
checkNotNull(children, "children");
Preconditions.checkArgument(!children.isEmpty(), "empty child list");
this.children = children;
for (ChildLbState child : children) {
WeightedChildLbState wChild = (WeightedChildLbState) child;
for (WrrSubchannel subchannel : wChild.subchannels) {
this.subchannelToReportListenerMap
.put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
}
}
this.enableOobLoadReport = enableOobLoadReport;
this.errorUtilizationPenalty = errorUtilizationPenalty;
@ -276,22 +342,24 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
Subchannel subchannel = list.get(scheduler.pick());
ChildLbState childLbState = children.get(scheduler.pick());
WeightedChildLbState wChild = (WeightedChildLbState) childLbState;
PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args);
Subchannel subchannel = pickResult.getSubchannel();
if (!enableOobLoadReport) {
return PickResult.withSubchannel(subchannel,
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
subchannelToReportListenerMap.getOrDefault(subchannel,
((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty))));
wChild.getOrCreateOrcaListener(errorUtilizationPenalty))));
} else {
return PickResult.withSubchannel(subchannel);
}
}
private void updateWeight() {
float[] newWeights = new float[list.size()];
for (int i = 0; i < list.size(); i++) {
WrrSubchannel subchannel = (WrrSubchannel) list.get(i);
double newWeight = subchannel.getWeight();
float[] newWeights = new float[children.size()];
for (int i = 0; i < children.size(); i++) {
double newWeight = ((WeightedChildLbState)children.get(i)).getWeight();
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
}
this.scheduler = new StaticStrideScheduler(newWeights, sequence);
@ -302,12 +370,12 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
.add("enableOobLoadReport", enableOobLoadReport)
.add("errorUtilizationPenalty", errorUtilizationPenalty)
.add("list", list).toString();
.add("list", children).toString();
}
@VisibleForTesting
List<Subchannel> getList() {
return list;
List<ChildLbState> getChildren() {
return children;
}
@Override
@ -322,7 +390,8 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
// the lists cannot contain duplicate subchannels
return enableOobLoadReport == other.enableOobLoadReport
&& Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0
&& list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list);
&& children.size() == other.children.size() && new HashSet<>(
children).containsAll(other.children);
}
}
@ -504,11 +573,13 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
@SuppressWarnings("UnusedReturnValue")
Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) {
this.blackoutPeriodNanos = blackoutPeriodNanos;
return this;
}
@SuppressWarnings("UnusedReturnValue")
Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) {
this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
return this;

View File

@ -202,7 +202,9 @@ public final class OrcaOobUtil {
*/
public static void setListener(Subchannel subchannel, OrcaOobReportListener listener,
OrcaReportingConfig config) {
SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY);
Attributes attributes = subchannel.getAttributes();
SubchannelImpl orcaSubchannel =
(attributes == null) ? null : attributes.get(ORCA_REPORTING_STATE_KEY);
if (orcaSubchannel == null) {
throw new IllegalArgumentException("Subchannel does not have orca Out-Of-Band stream enabled."
+ " Try to use a subchannel created by OrcaOobUtil.OrcaHelper.");
@ -241,7 +243,9 @@ public final class OrcaOobUtil {
public Subchannel createSubchannel(CreateSubchannelArgs args) {
syncContext.throwIfNotInThisSynchronizationContext();
Subchannel subchannel = super.createSubchannel(args);
SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY);
Attributes attributes = subchannel.getAttributes();
SubchannelImpl orcaSubchannel =
(attributes == null) ? null : attributes.get(ORCA_REPORTING_STATE_KEY);
OrcaReportingState orcaState;
if (orcaSubchannel == null) {
// Only the first load balancing policy requesting ORCA reports instantiates an

View File

@ -22,17 +22,15 @@ import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.xds.LeastRequestLoadBalancer.IN_FLIGHTS;
import static io.grpc.xds.LeastRequestLoadBalancer.STATE_INFO;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@ -60,10 +58,13 @@ import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.util.AbstractTestHelper;
import io.grpc.util.MultiChildLoadBalancer.ChildLbState;
import io.grpc.xds.LeastRequestLoadBalancer.EmptyPicker;
import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig;
import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestLbState;
import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestPicker;
import io.grpc.xds.LeastRequestLoadBalancer.ReadyPicker;
import io.grpc.xds.LeastRequestLoadBalancer.Ref;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
@ -71,6 +72,7 @@ import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
@ -81,10 +83,8 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
import org.mockito.stubbing.Answer;
/** Unit test for {@link LeastRequestLoadBalancer}. */
@RunWith(JUnit4.class)
@ -96,8 +96,6 @@ public class LeastRequestLoadBalancerTest {
private LeastRequestLoadBalancer loadBalancer;
private final List<EquivalentAddressGroup> servers = Lists.newArrayList();
private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels = Maps.newLinkedHashMap();
private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners =
Maps.newLinkedHashMap();
private final Attributes affinity =
Attributes.newBuilder().set(MAJOR_KEY, "I got the keys").build();
@ -107,8 +105,9 @@ public class LeastRequestLoadBalancerTest {
private ArgumentCaptor<ConnectivityState> stateCaptor;
@Captor
private ArgumentCaptor<CreateSubchannelArgs> createArgsCaptor;
@Mock
private Helper mockHelper;
private final TestHelper testHelperInstance = new TestHelper();
private Helper helper = mock(Helper.class, delegatesTo(testHelperInstance));
@Mock
private ThreadSafeRandom mockRandom;
@ -121,31 +120,9 @@ public class LeastRequestLoadBalancerTest {
SocketAddress addr = new FakeSocketAddress("server" + i);
EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
servers.add(eag);
Subchannel sc = mock(Subchannel.class);
subchannels.put(Arrays.asList(eag), sc);
}
when(mockHelper.createSubchannel(any(CreateSubchannelArgs.class)))
.then(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0];
final Subchannel subchannel = subchannels.get(args.getAddresses());
when(subchannel.getAllAddresses()).thenReturn(args.getAddresses());
when(subchannel.getAttributes()).thenReturn(args.getAttributes());
doAnswer(
new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
subchannelStateListeners.put(
subchannel, (SubchannelStateListener) invocation.getArguments()[0]);
return null;
}
}).when(subchannel).start(any(SubchannelStateListener.class));
return subchannel;
}
});
loadBalancer = new LeastRequestLoadBalancer(mockHelper, mockRandom);
loadBalancer = new LeastRequestLoadBalancer(helper, mockRandom);
}
@After
@ -156,13 +133,13 @@ public class LeastRequestLoadBalancerTest {
@Test
public void pickAfterResolved() throws Exception {
final Subchannel readySubchannel = subchannels.values().iterator().next();
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build());
assertThat(addressesAccepted).isTrue();
final Subchannel readySubchannel = subchannels.values().iterator().next();
deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture());
verify(helper, times(3)).createSubchannel(createArgsCaptor.capture());
List<List<EquivalentAddressGroup>> capturedAddrs = new ArrayList<>();
for (CreateSubchannelArgs arg : createArgsCaptor.getAllValues()) {
capturedAddrs.add(arg.getAddresses());
@ -174,22 +151,18 @@ public class LeastRequestLoadBalancerTest {
verify(subchannel, never()).shutdown();
}
verify(mockHelper, times(2))
verify(helper, times(2))
.updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
assertEquals(CONNECTING, stateCaptor.getAllValues().get(0));
assertEquals(READY, stateCaptor.getAllValues().get(1));
assertThat(getList(pickerCaptor.getValue())).containsExactly(readySubchannel);
verifyNoMoreInteractions(mockHelper);
verifyNoMoreInteractions(helper);
}
@Test
public void pickAfterResolvedUpdatedHosts() throws Exception {
Subchannel removedSubchannel = mock(Subchannel.class);
Subchannel oldSubchannel = mock(Subchannel.class);
Subchannel newSubchannel = mock(Subchannel.class);
Attributes.Key<String> key = Attributes.Key.create("check-that-it-is-propagated");
FakeSocketAddress removedAddr = new FakeSocketAddress("removed");
EquivalentAddressGroup removedEag = new EquivalentAddressGroup(removedAddr);
@ -201,33 +174,33 @@ public class LeastRequestLoadBalancerTest {
EquivalentAddressGroup newEag = new EquivalentAddressGroup(
newAddr, Attributes.newBuilder().set(key, "newattr").build());
subchannels.put(Collections.singletonList(removedEag), removedSubchannel);
subchannels.put(Collections.singletonList(oldEag1), oldSubchannel);
subchannels.put(Collections.singletonList(newEag), newSubchannel);
List<EquivalentAddressGroup> currentServers = Lists.newArrayList(removedEag, oldEag1);
InOrder inOrder = inOrder(mockHelper);
InOrder inOrder = inOrder(helper);
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity)
.build());
assertThat(addressesAccepted).isTrue();
Subchannel removedSubchannel = getSubchannel(removedEag);
Subchannel oldSubchannel = getSubchannel(oldEag1);
SubchannelStateListener removedListener =
testHelperInstance.getSubchannelStateListeners()
.get(testHelperInstance.getRealForMockSubChannel(removedSubchannel));
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(READY));
deliverSubchannelState(oldSubchannel, ConnectivityStateInfo.forNonError(READY));
inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture());
inOrder.verify(helper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture());
SubchannelPicker picker = pickerCaptor.getValue();
assertThat(getList(picker)).containsExactly(removedSubchannel, oldSubchannel);
verify(removedSubchannel, times(1)).requestConnection();
verify(oldSubchannel, times(1)).requestConnection();
assertThat(loadBalancer.getSubchannels()).containsExactly(removedSubchannel,
oldSubchannel);
assertThat(getChildEags(loadBalancer)).containsExactly(removedEag, oldEag1);
// This time with Attributes
List<EquivalentAddressGroup> latestServers = Lists.newArrayList(oldEag2, newEag);
@ -236,81 +209,105 @@ public class LeastRequestLoadBalancerTest {
ResolvedAddresses.newBuilder().setAddresses(latestServers).setAttributes(affinity).build());
assertThat(addressesAccepted).isTrue();
Subchannel newSubchannel = getSubchannel(newEag);
verify(newSubchannel, times(1)).requestConnection();
verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2));
verify(removedSubchannel, times(1)).shutdown();
deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(SHUTDOWN));
removedListener.onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN));
deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY));
assertThat(loadBalancer.getSubchannels()).containsExactly(oldSubchannel,
newSubchannel);
assertThat(getChildEags(loadBalancer)).containsExactly(oldEag2, newEag);
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture());
verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(helper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture());
picker = pickerCaptor.getValue();
assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel);
assertThat(getList(pickerCaptor.getValue())).containsExactly(oldSubchannel, newSubchannel);
verifyNoMoreInteractions(mockHelper);
verifyNoMoreInteractions(helper);
}
private Subchannel getSubchannel(EquivalentAddressGroup removedEag) {
return subchannels.get(Collections.singletonList(removedEag));
}
private Subchannel getSubchannel(ChildLbState childLbState) {
return subchannels.get(Collections.singletonList(childLbState.getEag()));
}
private static List<Object> getChildEags(LeastRequestLoadBalancer loadBalancer) {
return loadBalancer.getChildLbStates().stream()
.map(ChildLbState::getEag)
// .map(EquivalentAddressGroup::getAddresses)
.collect(Collectors.toList());
}
private List<Subchannel> getSubchannels(LeastRequestLoadBalancer lb) {
return lb.getChildLbStates().stream()
.map(this::getSubchannel)
.collect(Collectors.toList());
}
private LeastRequestLbState getChildLbState(PickResult pickResult) {
EquivalentAddressGroup eag = pickResult.getSubchannel().getAddresses();
return (LeastRequestLbState) loadBalancer.getChildLbState(eag);
}
@Test
public void pickAfterStateChange() throws Exception {
InOrder inOrder = inOrder(mockHelper);
InOrder inOrder = inOrder(helper);
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
.build());
assertThat(addressesAccepted).isTrue();
Subchannel subchannel = loadBalancer.getSubchannels().iterator().next();
Ref<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
STATE_INFO);
ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next();
Subchannel subchannel = getSubchannel(childLbState);
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING);
deliverSubchannelState(subchannel,
ConnectivityStateInfo.forNonError(READY));
inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class);
assertThat(subchannelStateInfo.value).isEqualTo(
ConnectivityStateInfo.forNonError(READY));
assertThat(childLbState.getCurrentState()).isEqualTo(READY);
Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯");
deliverSubchannelState(subchannel,
ConnectivityStateInfo.forTransientFailure(error));
assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE);
assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error);
inOrder.verify(mockHelper).refreshNameResolution();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
assertThat(childLbState.getCurrentPicker().toString()).contains(error.toString());
inOrder.verify(helper).refreshNameResolution();
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class);
deliverSubchannelState(subchannel,
ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(mockHelper).refreshNameResolution();
assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE);
assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error);
inOrder.verify(helper).refreshNameResolution();
assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
assertThat(childLbState.getCurrentPicker().toString()).contains(error.toString());
verify(subchannel, times(2)).requestConnection();
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verifyNoMoreInteractions(mockHelper);
verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verifyNoMoreInteractions(helper);
}
@Test
public void pickAfterConfigChange() {
final LeastRequestConfig oldConfig = new LeastRequestConfig(4);
final LeastRequestConfig newConfig = new LeastRequestConfig(6);
final Subchannel readySubchannel = subchannels.values().iterator().next();
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity)
.setLoadBalancingPolicyConfig(oldConfig).build());
assertThat(addressesAccepted).isTrue();
final Subchannel readySubchannel = subchannels.values().iterator().next();
deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verify(mockHelper, times(2))
verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verify(helper, times(2))
.updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture());
// At this point it should use a ReadyPicker with oldConfig
// At this point it should use a ReadyPicker with oldConfig and 1 ready subchannel
pickerCaptor.getValue().pickSubchannel(mockArgs);
verify(mockRandom, times(oldConfig.choiceCount)).nextInt(1);
@ -318,26 +315,26 @@ public class LeastRequestLoadBalancerTest {
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity)
.setLoadBalancingPolicyConfig(newConfig).build());
assertThat(addressesAccepted).isTrue();
verify(mockHelper, times(3))
verify(helper, times(3))
.updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture());
// At this point it should use a ReadyPicker with newConfig
pickerCaptor.getValue().pickSubchannel(mockArgs);
verify(mockRandom, times(oldConfig.choiceCount + newConfig.choiceCount)).nextInt(1);
verifyNoMoreInteractions(mockHelper);
verifyNoMoreInteractions(helper);
}
@Test
public void ignoreShutdownSubchannelStateChange() {
InOrder inOrder = inOrder(mockHelper);
InOrder inOrder = inOrder(helper);
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
.build());
assertThat(addressesAccepted).isTrue();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
loadBalancer.shutdown();
for (Subchannel sc : loadBalancer.getSubchannels()) {
for (Subchannel sc : getSubchannels(loadBalancer)) {
verify(sc).shutdown();
// When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered
// back to the subchannel state listener.
@ -349,71 +346,101 @@ public class LeastRequestLoadBalancerTest {
@Test
public void stayTransientFailureUntilReady() {
InOrder inOrder = inOrder(mockHelper);
InOrder inOrder = inOrder(helper);
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
.build());
assertThat(addressesAccepted).isTrue();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
// Simulate state transitions for each subchannel individually.
for (Subchannel sc : loadBalancer.getSubchannels()) {
for (ChildLbState childLbState : loadBalancer.getChildLbStates()) {
Subchannel sc = getSubchannel(childLbState);
Status error = Status.UNKNOWN.withDescription("connection broken");
deliverSubchannelState(
sc,
ConnectivityStateInfo.forTransientFailure(error));
inOrder.verify(mockHelper).refreshNameResolution();
deliverSubchannelState(
sc,
ConnectivityStateInfo.forNonError(CONNECTING));
Ref<ConnectivityStateInfo> scStateInfo = sc.getAttributes().get(
STATE_INFO);
assertThat(scStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE);
assertThat(scStateInfo.value.getStatus()).isEqualTo(error);
deliverSubchannelState(sc, ConnectivityStateInfo.forTransientFailure(error));
inOrder.verify(helper).refreshNameResolution();
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING));
assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
}
inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(EmptyPicker.class));
inOrder.verify(helper)
.updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
assertThat(getStatusString((LeastRequestPicker)pickerCaptor.getValue()))
.contains("Status{code=UNKNOWN, description=connection broken");
inOrder.verifyNoMoreInteractions();
Subchannel subchannel = loadBalancer.getSubchannels().iterator().next();
ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next();
Subchannel subchannel = getSubchannel(childLbState);
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
Ref<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
STATE_INFO);
assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(READY));
inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class));
assertThat(childLbState.getCurrentState()).isEqualTo(READY);
inOrder.verify(helper).updateBalancingState(eq(READY), isA(ReadyPicker.class));
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verifyNoMoreInteractions(mockHelper);
verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verifyNoMoreInteractions(helper);
}
private String getStatusString(LeastRequestPicker picker) {
if (picker == null) {
return "";
}
if (picker instanceof EmptyPicker) {
if (((EmptyPicker) picker).getStatus() == null) {
return "";
}
return ((EmptyPicker) picker).getStatus().toString();
} else if (picker instanceof ReadyPicker) {
List<ChildLbState> childLbStates = ((ReadyPicker)picker).getChildLbStates();
if (childLbStates == null || childLbStates.isEmpty()) {
return "";
};
// Note that this is dependent on PickFirst's picker toString retaining the representation
// of the status, but since it is a test and we don't want to expose this value it seems
// a reasonable tradeoff
String pickerStr = childLbStates.get(0).getCurrentPicker().toString();
int beg = pickerStr.indexOf(", status=Status{");
if (beg < 0) {
return "";
}
int end = pickerStr.indexOf('}', beg);
if (end < 0) {
return "";
}
return pickerStr.substring(beg + ", status=".length(), end + 1);
}
throw new IllegalArgumentException("Unrecognized picker: " + picker);
}
@Test
public void refreshNameResolutionWhenSubchannelConnectionBroken() {
InOrder inOrder = inOrder(mockHelper);
InOrder inOrder = inOrder(helper);
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
.build());
assertThat(addressesAccepted).isTrue();
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
// Simulate state transitions for each subchannel individually.
for (Subchannel sc : loadBalancer.getSubchannels()) {
for (Subchannel sc : getSubchannels(loadBalancer)) {
verify(sc).requestConnection();
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING));
Status error = Status.UNKNOWN.withDescription("connection broken");
deliverSubchannelState(sc, ConnectivityStateInfo.forTransientFailure(error));
inOrder.verify(mockHelper).refreshNameResolution();
inOrder.verify(helper).refreshNameResolution();
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(READY));
inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class));
inOrder.verify(helper).updateBalancingState(eq(READY), isA(ReadyPicker.class));
// Simulate receiving go-away so READY subchannels transit to IDLE.
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(mockHelper).refreshNameResolution();
inOrder.verify(helper).refreshNameResolution();
verify(sc, times(2)).requestConnection();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
}
verifyNoMoreInteractions(mockHelper);
verifyNoMoreInteractions(helper);
}
@Test
@ -426,68 +453,64 @@ public class LeastRequestLoadBalancerTest {
.build());
assertThat(addressesAccepted).isTrue();
assertEquals(3, loadBalancer.getSubchannels().size());
assertEquals(3, loadBalancer.getChildLbStates().size());
List<Subchannel> subchannels = Lists.newArrayList(loadBalancer.getSubchannels());
List<ChildLbState> childLbStates = Lists.newArrayList(loadBalancer.getChildLbStates());
// Make sure all inFlight counters have started at 0
assertEquals(0,
subchannels.get(0).getAttributes().get(IN_FLIGHTS).get());
assertEquals(0,
subchannels.get(1).getAttributes().get(IN_FLIGHTS).get());
assertEquals(0,
subchannels.get(2).getAttributes().get(IN_FLIGHTS).get());
for (int i = 0; i < 3; i++) {
assertEquals("counter for child " + i, 0,
((LeastRequestLbState) childLbStates.get(i)).getActiveRequests());
}
for (Subchannel sc : subchannels) {
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(READY));
for (ChildLbState cs : childLbStates) {
deliverSubchannelState(getSubchannel(cs), ConnectivityStateInfo.forNonError(READY));
}
// Capture the active ReadyPicker once all subchannels are READY
verify(mockHelper, times(4))
verify(helper, times(4))
.updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class);
ReadyPicker picker = (ReadyPicker) pickerCaptor.getValue();
assertThat(picker.getList()).containsExactlyElementsIn(subchannels);
assertThat(picker.getChildLbStates()).containsExactlyElementsIn(childLbStates);
// Make random return 0, then 2 for the sample indexes.
when(mockRandom.nextInt(subchannels.size())).thenReturn(0, 2);
when(mockRandom.nextInt(childLbStates.size())).thenReturn(0, 2);
PickResult pickResult1 = picker.pickSubchannel(mockArgs);
verify(mockRandom, times(choiceCount)).nextInt(subchannels.size());
assertEquals(subchannels.get(0), pickResult1.getSubchannel());
verify(mockRandom, times(choiceCount)).nextInt(childLbStates.size());
assertEquals(childLbStates.get(0), getChildLbState(pickResult1));
// This simulates sending the actual RPC on the picked channel
ClientStreamTracer streamTracer1 =
pickResult1.getStreamTracerFactory()
.newClientStreamTracer(StreamInfo.newBuilder().build(), new Metadata());
streamTracer1.streamCreated(Attributes.EMPTY, new Metadata());
assertEquals(1,
pickResult1.getSubchannel().getAttributes().get(IN_FLIGHTS).get());
assertEquals(1, getChildLbState(pickResult1).getActiveRequests());
// For the second pick it should pick the one with lower inFlight.
when(mockRandom.nextInt(subchannels.size())).thenReturn(0, 2);
when(mockRandom.nextInt(childLbStates.size())).thenReturn(0, 2);
PickResult pickResult2 = picker.pickSubchannel(mockArgs);
// Since this is the second pick we expect the total random samples to be choiceCount * 2
verify(mockRandom, times(choiceCount * 2)).nextInt(subchannels.size());
assertEquals(subchannels.get(2), pickResult2.getSubchannel());
verify(mockRandom, times(choiceCount * 2)).nextInt(childLbStates.size());
assertEquals(childLbStates.get(2), getChildLbState(pickResult2));
// For the third pick we unavoidably pick subchannel with index 1.
when(mockRandom.nextInt(subchannels.size())).thenReturn(1, 1);
when(mockRandom.nextInt(childLbStates.size())).thenReturn(1, 1);
PickResult pickResult3 = picker.pickSubchannel(mockArgs);
verify(mockRandom, times(choiceCount * 3)).nextInt(subchannels.size());
assertEquals(subchannels.get(1), pickResult3.getSubchannel());
verify(mockRandom, times(choiceCount * 3)).nextInt(childLbStates.size());
assertEquals(childLbStates.get(1), getChildLbState(pickResult3));
// Finally ensure a finished RPC decreases inFlight
streamTracer1.streamClosed(Status.OK);
assertEquals(0,
pickResult1.getSubchannel().getAttributes().get(IN_FLIGHTS).get());
assertEquals(0, getChildLbState(pickResult1).getActiveRequests());
}
@Test
public void pickerEmptyList() throws Exception {
SubchannelPicker picker = new EmptyPicker(Status.UNKNOWN);
assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel());
assertNull(picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(Status.UNKNOWN,
picker.pickSubchannel(mockArgs).getStatus());
}
@ -495,28 +518,37 @@ public class LeastRequestLoadBalancerTest {
@Test
public void nameResolutionErrorWithNoChannels() throws Exception {
Status error = Status.NOT_FOUND.withDescription("nameResolutionError");
loadBalancer.setResolvingAddresses(true);
loadBalancer.handleNameResolutionError(error);
verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
loadBalancer.setResolvingAddresses(false);
verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs);
assertNull(pickResult.getSubchannel());
assertEquals(error, pickResult.getStatus());
verifyNoMoreInteractions(mockHelper);
verifyNoMoreInteractions(helper);
}
@Test
public void nameResolutionErrorWithActiveChannels() throws Exception {
int choiceCount = 8;
final Subchannel readySubchannel = subchannels.values().iterator().next();
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder()
.setLoadBalancingPolicyConfig(new LeastRequestConfig(choiceCount))
.setAddresses(servers).setAttributes(affinity).build());
assertThat(addressesAccepted).isTrue();
deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError"));
final Subchannel readySubchannel = subchannels.values().iterator().next();
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verify(mockHelper, times(2))
deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
// TODO This test assumes that existing subchannels are left unchanged while the logic we have
// is to tell all of the children that there was a nameResolutionError. This seems to me to
// make more sense, just ignore a bad update.
loadBalancer.setResolvingAddresses(true);
loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError"));
loadBalancer.setResolvingAddresses(false);
verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verify(helper, times(2))
.updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
Iterator<ConnectivityState> stateIterator = stateCaptor.getAllValues().iterator();
@ -531,20 +563,20 @@ public class LeastRequestLoadBalancerTest {
LoadBalancer.PickResult pickResult2 = pickerCaptor.getValue().pickSubchannel(mockArgs);
verify(mockRandom, times(choiceCount * 2)).nextInt(1);
assertEquals(readySubchannel, pickResult2.getSubchannel());
verifyNoMoreInteractions(mockHelper);
verifyNoMoreInteractions(helper);
}
@Test
public void subchannelStateIsolation() throws Exception {
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
.build());
assertThat(addressesAccepted).isTrue();
Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
Subchannel sc3 = subchannelIterator.next();
boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
.build());
assertThat(addressesAccepted).isTrue();
verify(sc1, times(1)).requestConnection();
verify(sc2, times(1)).requestConnection();
verify(sc3, times(1)).requestConnection();
@ -555,7 +587,7 @@ public class LeastRequestLoadBalancerTest {
deliverSubchannelState(sc2, ConnectivityStateInfo.forNonError(IDLE));
deliverSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
verify(mockHelper, times(6))
verify(helper, times(6))
.updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
Iterator<ConnectivityState> stateIterator = stateCaptor.getAllValues().iterator();
Iterator<SubchannelPicker> pickers = pickerCaptor.getAllValues().iterator();
@ -584,7 +616,7 @@ public class LeastRequestLoadBalancerTest {
public void readyPicker_emptyList() {
try {
// ready picker list must be non-empty
new ReadyPicker(Collections.<Subchannel>emptyList(), 2, mockRandom);
new ReadyPicker(Collections.emptyList(), 2, mockRandom);
fail();
} catch (IllegalArgumentException expected) {
}
@ -596,15 +628,19 @@ public class LeastRequestLoadBalancerTest {
EmptyPicker emptyOk2 = new EmptyPicker(Status.OK.withDescription("different OK"));
EmptyPicker emptyErr = new EmptyPicker(Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"));
Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
ReadyPicker ready1 = new ReadyPicker(Arrays.asList(sc1, sc2), 2, mockRandom);
ReadyPicker ready2 = new ReadyPicker(Arrays.asList(sc1), 2, mockRandom);
ReadyPicker ready3 = new ReadyPicker(Arrays.asList(sc2, sc1), 2, mockRandom);
ReadyPicker ready4 = new ReadyPicker(Arrays.asList(sc1, sc2), 2, mockRandom);
ReadyPicker ready5 = new ReadyPicker(Arrays.asList(sc2, sc1), 2, mockRandom);
ReadyPicker ready6 = new ReadyPicker(Arrays.asList(sc2, sc1), 8, mockRandom);
loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build());
Iterator<ChildLbState> iterator = loadBalancer.getChildLbStates().iterator();
ChildLbState child1 = iterator.next();
ChildLbState child2 = iterator.next();
ReadyPicker ready1 = new ReadyPicker(Arrays.asList(child1, child2), 2, mockRandom);
ReadyPicker ready2 = new ReadyPicker(Arrays.asList(child1), 2, mockRandom);
ReadyPicker ready3 = new ReadyPicker(Arrays.asList(child2, child1), 2, mockRandom);
ReadyPicker ready4 = new ReadyPicker(Arrays.asList(child1, child2), 2, mockRandom);
ReadyPicker ready5 = new ReadyPicker(Arrays.asList(child2, child1), 2, mockRandom);
ReadyPicker ready6 = new ReadyPicker(Arrays.asList(child2, child1), 8, mockRandom);
assertTrue(emptyOk1.isEquivalentTo(emptyOk2));
assertFalse(emptyOk1.isEquivalentTo(emptyErr));
@ -623,16 +659,22 @@ public class LeastRequestLoadBalancerTest {
ResolvedAddresses.newBuilder()
.setAddresses(Collections.<EquivalentAddressGroup>emptyList())
.setAttributes(affinity)
.build())).isFalse();
.build()))
.isFalse();
}
private static List<Subchannel> getList(SubchannelPicker picker) {
return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() :
Collections.<Subchannel>emptyList();
private List<Subchannel> getList(SubchannelPicker picker) {
if (picker instanceof ReadyPicker) {
return ((ReadyPicker) picker).getChildLbStates().stream()
.map(this::getSubchannel)
.collect(Collectors.toList());
} else {
return Collections.emptyList();
}
}
private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) {
subchannelStateListeners.get(subchannel).onSubchannelState(newState);
testHelperInstance.deliverSubchannelState(subchannel, newState);
}
private static class FakeSocketAddress extends SocketAddress {
@ -647,4 +689,12 @@ public class LeastRequestLoadBalancerTest {
return "FakeSocketAddress-" + name;
}
}
private class TestHelper extends AbstractTestHelper {
@Override
public Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap() {
return subchannels;
}
}
}

View File

@ -17,11 +17,10 @@
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -35,7 +34,6 @@ import com.google.common.collect.Maps;
import com.google.protobuf.Duration;
import io.grpc.Attributes;
import io.grpc.Channel;
import io.grpc.ChannelLogger;
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
@ -50,12 +48,15 @@ import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock;
import io.grpc.internal.TestUtils;
import io.grpc.services.InternalCallMetricRecorder;
import io.grpc.services.MetricReport;
import io.grpc.util.AbstractTestHelper;
import io.grpc.util.MultiChildLoadBalancer.ChildLbState;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.StaticStrideScheduler;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedChildLbState;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.HashMap;
@ -67,6 +68,7 @@ import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Before;
@ -87,8 +89,8 @@ public class WeightedRoundRobinLoadBalancerTest {
@Rule
public final MockitoRule mockito = MockitoJUnit.rule();
@Mock
Helper helper;
private final TestHelper testHelperInstance = new TestHelper();
private Helper helper = mock(Helper.class, delegatesTo(testHelperInstance));
@Mock
private LoadBalancer.PickSubchannelArgs mockArgs;
@ -99,9 +101,8 @@ public class WeightedRoundRobinLoadBalancerTest {
private ArgumentCaptor<SubchannelPicker> pickerCaptor2;
private final List<EquivalentAddressGroup> servers = Lists.newArrayList();
private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels = Maps.newLinkedHashMap();
private final Map<Subchannel, Subchannel> mockToRealSubChannelMap = new HashMap<>();
private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners =
Maps.newLinkedHashMap();
@ -134,7 +135,8 @@ public class WeightedRoundRobinLoadBalancerTest {
SocketAddress addr = new FakeSocketAddress("server" + i);
EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
servers.add(eag);
Subchannel sc = mock(Subchannel.class);
Subchannel sc = helper.createSubchannel(CreateSubchannelArgs.newBuilder().setAddresses(eag)
.build());
Channel channel = mock(Channel.class);
when(channel.newCall(any(), any())).then(
new Answer<ClientCall<OrcaLoadReportRequest, OrcaLoadReport>>() {
@ -147,35 +149,13 @@ public class WeightedRoundRobinLoadBalancerTest {
return clientCall;
}
});
when(sc.asChannel()).thenReturn(channel);
testHelperInstance.setChannel(mockToRealSubChannelMap.get(sc), channel);
subchannels.put(Arrays.asList(eag), sc);
}
when(helper.getSynchronizationContext()).thenReturn(syncContext);
when(helper.getScheduledExecutorService()).thenReturn(
fakeClock.getScheduledExecutorService());
when(helper.createSubchannel(any(CreateSubchannelArgs.class)))
.then(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0];
final Subchannel subchannel = subchannels.get(args.getAddresses());
when(subchannel.getAllAddresses()).thenReturn(args.getAddresses());
when(subchannel.getAttributes()).thenReturn(args.getAttributes());
when(subchannel.getChannelLogger()).thenReturn(mock(ChannelLogger.class));
doAnswer(
new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
subchannelStateListeners.put(
subchannel, (SubchannelStateListener) invocation.getArguments()[0]);
return null;
}
}).when(subchannel).start(any(SubchannelStateListener.class));
return subchannel;
}
});
wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(),
new FakeRandom(0));
verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
}
@Test
@ -183,44 +163,44 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel connectingSubchannel = it.next();
subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.CONNECTING));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2);
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0);
assertThat(weightedPicker.getList().size()).isEqualTo(1);
assertThat(weightedPicker.getChildren().size()).isEqualTo(1);
weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
assertThat(weightedPicker.getList().size()).isEqualTo(2);
assertThat(weightedPicker.getChildren().size()).isEqualTo(2);
String weightedPickerStr = weightedPicker.toString();
assertThat(weightedPickerStr).contains("enableOobLoadReport=false");
assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0");
assertThat(weightedPickerStr).contains("list=");
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
assertThat(weightedPicker.pickSubchannel(mockArgs)
.getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(weightedChild1.getEag());
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder()
.setWeightUpdatePeriodNanos(500_000_000L) //.5s
@ -238,35 +218,44 @@ public class WeightedRoundRobinLoadBalancerTest {
verifyNoMoreInteractions(mockArgs);
}
/**
* Picks subchannel using mockArgs, gets its EAG, and then strips the Attrs to make a key.
*/
private EquivalentAddressGroup getAddressesFromPick(WeightedRoundRobinPicker weightedPicker) {
return TestUtils.stripAttrs(
weightedPicker.pickSubchannel(mockArgs).getSubchannel().getAddresses());
}
@Test
public void enableOobLoadReportConfig() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class));
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.9, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
PickResult pickResult = weightedPicker.pickSubchannel(mockArgs);
assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(getAddresses(pickResult))
.isEqualTo(weightedChild1.getEag());
assertThat(pickResult.getStreamTracerFactory()).isNotNull(); // verify per-request listener
assertThat(oobCalls.isEmpty()).isTrue();
@ -280,7 +269,8 @@ public class WeightedRoundRobinLoadBalancerTest {
eq(ConnectivityState.READY), pickerCaptor2.capture());
weightedPicker = (WeightedRoundRobinPicker) pickerCaptor2.getAllValues().get(2);
pickResult = weightedPicker.pickSubchannel(mockArgs);
assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(getAddresses(pickResult))
.isEqualTo(weightedChild1.getEag());
assertThat(pickResult.getStreamTracerFactory()).isNull();
OrcaLoadReportRequest golden = OrcaLoadReportRequest.newBuilder().setReportInterval(
Duration.newBuilder().setSeconds(20).setNanos(30000000).build()).build();
@ -295,46 +285,52 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel3 = it.next();
subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel3).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(3)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2);
weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
r1);
weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
r2);
weightedSubchannel3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
r3);
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2);
weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r1);
weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r2);
weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r3);
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
Map<Subchannel, Integer> pickCount = new HashMap<>();
Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 10000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
EquivalentAddressGroup result = getAddressesFromPick(weightedPicker);
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(3);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 10000.0 - subchannel1PickRatio))
.isLessThan(0.0002);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 10000.0 - subchannel2PickRatio ))
.isLessThan(0.0002);
assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 10000.0 - subchannel3PickRatio ))
.isLessThan(0.0002);
assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - subchannel1PickRatio))
.isAtMost(0.0002);
assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - subchannel2PickRatio ))
.isAtMost(0.0002);
assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 10000.0 - subchannel3PickRatio ))
.isAtMost(0.0002);
}
private SubchannelStateListener getSubchannelStateListener(Subchannel mockSubChannel) {
return subchannelStateListeners.get(mockToRealSubChannelMap.get(mockSubChannel));
}
private static ChildLbState getChild(WeightedRoundRobinPicker picker, int index) {
return picker.getChildren().get(index);
}
@Test
@ -472,14 +468,14 @@ public class WeightedRoundRobinLoadBalancerTest {
assertThat(wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(null)
.setAttributes(affinity).build())).isFalse();
verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class));
verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verify(helper).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any());
assertThat(fakeClock.getPendingTasks()).isEmpty();
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class));
verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue().getClass().getName())
@ -492,51 +488,51 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1);
Map<Subchannel, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 10000; i++) {
EquivalentAddressGroup result = getAddressesFromPick(weightedPicker);
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
// within blackout period, fallback to simple round robin
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5)).isLessThan(0.002);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5)).isLessThan(0.002);
assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 0.5)).isLessThan(0.002);
assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 0.5)).isLessThan(0.002);
assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1);
pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
for (int i = 0; i < 10000; i++) {
EquivalentAddressGroup result = getAddressesFromPick(weightedPicker);
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
// after blackout period
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3))
assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 2.0 / 3))
.isLessThan(0.002);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3))
assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 1.0 / 3))
.isLessThan(0.002);
}
@ -545,39 +541,39 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel connectingSubchannel = it.next();
subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.CONNECTING));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2);
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0);
assertThat(weightedPicker.getList().size()).isEqualTo(1);
assertThat(weightedPicker.getChildren().size()).isEqualTo(1);
weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
assertThat(weightedPicker.getList().size()).isEqualTo(2);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
assertThat(weightedPicker.getChildren().size()).isEqualTo(2);
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
assertThat(weightedPicker.pickSubchannel(mockArgs)
.getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(getAddressesFromPick(weightedPicker))
.isEqualTo(weightedChild1.getEag());
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder()
.setWeightUpdatePeriodNanos(500_000_000L) //.5s
@ -586,17 +582,18 @@ public class WeightedRoundRobinLoadBalancerTest {
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
//timer fires, new weight updated
assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1);
assertThat(weightedPicker.pickSubchannel(mockArgs)
.getSubchannel()).isEqualTo(weightedSubchannel2);
assertThat(getAddressesFromPick(weightedPicker))
.isEqualTo(weightedChild2.getEag());
assertThat(getAddressesFromPick(weightedPicker))
.isEqualTo(weightedChild1.getEag());
}
@Test
@ -604,52 +601,52 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
Map<Subchannel, Integer> pickCount = new HashMap<>();
Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
EquivalentAddressGroup result = getAddressesFromPick(weightedPicker);
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3))
assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3))
.isLessThan(0.002);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3))
assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3))
.isLessThan(0.002);
// weight expired, fallback to simple round robin
assertThat(fakeClock.forwardTime(300, TimeUnit.SECONDS)).isEqualTo(1);
pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
EquivalentAddressGroup result = getAddressesFromPick(weightedPicker);
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5))
assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 0.5))
.isLessThan(0.002);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5))
assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 0.5))
.isLessThan(0.002);
}
@ -658,107 +655,113 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
Map<WrrSubchannel, Integer> qpsByChannel = ImmutableMap.of(weightedSubchannel1, 2,
weightedSubchannel2, 1);
Map<Subchannel, Integer> pickCount = new HashMap<>();
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
Map<EquivalentAddressGroup, Integer> qpsByChannel = ImmutableMap.of(weightedChild1.getEag(), 2,
weightedChild2.getEag(), 1);
Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
PickResult pickResult = weightedPicker.pickSubchannel(mockArgs);
pickCount.put(pickResult.getSubchannel(),
pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1);
EquivalentAddressGroup addresses = getAddresses(pickResult);
pickCount.merge(addresses, 1, Integer::sum);
assertThat(pickResult.getStreamTracerFactory()).isNotNull();
WrrSubchannel subchannel = (WrrSubchannel)pickResult.getSubchannel();
subchannel.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses);
childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, qpsByChannel.get(subchannel), 0,
0.1, 0, 0.1, qpsByChannel.get(addresses), 0,
new HashMap<>(), new HashMap<>(), new HashMap<>()));
}
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 1.0 / 2))
assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 1.0 / 2))
.isAtMost(0.1);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 2))
assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 2))
.isAtMost(0.1);
// Identical to above except forwards time after each pick
pickCount.clear();
for (int i = 0; i < 1000; i++) {
PickResult pickResult = weightedPicker.pickSubchannel(mockArgs);
pickCount.put(pickResult.getSubchannel(),
pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1);
EquivalentAddressGroup addresses = getAddresses(pickResult);
pickCount.merge(addresses, 1, Integer::sum);
assertThat(pickResult.getStreamTracerFactory()).isNotNull();
WrrSubchannel subchannel = (WrrSubchannel) pickResult.getSubchannel();
subchannel.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses);
childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, qpsByChannel.get(subchannel), 0,
0.1, 0, 0.1, qpsByChannel.get(addresses), 0,
new HashMap<>(), new HashMap<>(), new HashMap<>()));
fakeClock.forwardTime(50, TimeUnit.MILLISECONDS);
}
assertThat(pickCount.size()).isEqualTo(2);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3))
assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3))
.isAtMost(0.1);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3))
assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3))
.isAtMost(0.1);
}
private static EquivalentAddressGroup getAddresses(PickResult pickResult) {
return TestUtils.stripAttrs(pickResult.getSubchannel().getAddresses());
}
@Test
public void unknownWeightIsAvgWeight() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class)); // 3 from setup plus 3 from the execute
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
getSubchannelStateListener(readySubchannel1)
.onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
getSubchannelStateListener(readySubchannel2)
.onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
Subchannel readySubchannel3 = it.next();
subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
getSubchannelStateListener(readySubchannel3)
.onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
verify(helper, times(3)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2);
weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2);
weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
Map<Subchannel, Integer> pickCount = new HashMap<>();
Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
pickCount.merge(result.getAddresses(), 1, Integer::sum);
}
assertThat(pickCount.size()).isEqualTo(3);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9))
assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 4.0 / 9))
.isLessThan(0.002);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9))
assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 2.0 / 9))
.isLessThan(0.002);
// subchannel3's weight is average of subchannel1 and subchannel2
assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9))
assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 1000.0 - 3.0 / 9))
.isLessThan(0.002);
}
@ -767,33 +770,33 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
CyclicBarrier barrier = new CyclicBarrier(2);
Map<Subchannel, AtomicInteger> pickCount = new ConcurrentHashMap<>();
pickCount.put(weightedSubchannel1, new AtomicInteger(0));
pickCount.put(weightedSubchannel2, new AtomicInteger(0));
Map<EquivalentAddressGroup, AtomicInteger> pickCount = new ConcurrentHashMap<>();
pickCount.put(weightedChild1.getEag(), new AtomicInteger(0));
pickCount.put(weightedChild2.getEag(), new AtomicInteger(0));
new Thread(new Runnable() {
@Override
public void run() {
@ -802,7 +805,7 @@ public class WeightedRoundRobinLoadBalancerTest {
barrier.await();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.get(result).addAndGet(1);
pickCount.get(result.getAddresses()).addAndGet(1);
}
barrier.await();
} catch (Exception ex) {
@ -813,15 +816,15 @@ public class WeightedRoundRobinLoadBalancerTest {
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
barrier.await();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
EquivalentAddressGroup result = getAddresses(weightedPicker.pickSubchannel(mockArgs));
pickCount.get(result).addAndGet(1);
}
barrier.await();
assertThat(pickCount.size()).isEqualTo(2);
// after blackout period
assertThat(Math.abs(pickCount.get(weightedSubchannel1).get() / 2000.0 - 2.0 / 3))
assertThat(Math.abs(pickCount.get(weightedChild1.getEag()).get() / 2000.0 - 2.0 / 3))
.isLessThan(0.002);
assertThat(Math.abs(pickCount.get(weightedSubchannel2).get() / 2000.0 - 1.0 / 3))
assertThat(Math.abs(pickCount.get(weightedChild2.getEag()).get() / 2000.0 - 1.0 / 3))
.isLessThan(0.002);
}
@ -1104,4 +1107,34 @@ public class WeightedRoundRobinLoadBalancerTest {
return nextInt;
}
}
private class TestHelper extends AbstractTestHelper {
@Override
public Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap() {
return subchannels;
}
@Override
public Map<Subchannel, Subchannel> getMockToRealSubChannelMap() {
return mockToRealSubChannelMap;
}
@Override
public Map<Subchannel, SubchannelStateListener> getSubchannelStateListeners() {
return subchannelStateListeners;
}
@Override
public SynchronizationContext getSynchronizationContext() {
return syncContext;
}
@Override
public ScheduledExecutorService getScheduledExecutorService() {
return fakeClock.getScheduledExecutorService();
}
}
}