diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java index 91b7a93151..69238c39dd 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java @@ -84,21 +84,18 @@ class ClusterManagerLoadBalancer extends LoadBalancer { } else { childLbStates.get(name).reactivate(childPolicyProvider); } - final LoadBalancer childLb = childLbStates.get(name).lb; - final ResolvedAddresses childAddresses = + LoadBalancer childLb = childLbStates.get(name).lb; + ResolvedAddresses childAddresses = resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build(); - syncContext.execute(new Runnable() { - @Override - public void run() { - childLb.handleResolvedAddresses(childAddresses); - } - }); + childLb.handleResolvedAddresses(childAddresses); } for (String name : childLbStates.keySet()) { if (!newChildPolicies.containsKey(name)) { childLbStates.get(name).deactivate(); } } + // Must update channel picker before return so that new RPCs will not be routed to deleted + // clusters and resolver can remove them in service config. updateOverallBalancingState(); } @@ -245,12 +242,20 @@ class ClusterManagerLoadBalancer extends LoadBalancer { private final class ChildLbStateHelper extends ForwardingLoadBalancerHelper { @Override - public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { - currentState = newState; - currentPicker = newPicker; - if (!deactivated) { - updateOverallBalancingState(); - } + public void updateBalancingState(final ConnectivityState newState, + final SubchannelPicker newPicker) { + syncContext.execute(new Runnable() { + @Override + public void run() { + currentState = newState; + currentPicker = newPicker; + // Subchannel picker and state are saved, but will only be propagated to the channel + // when the child instance exits deactivated state. + if (!deactivated) { + updateOverallBalancingState(); + } + } + }); } @Override diff --git a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java index cace127b35..6925836b1a 100644 --- a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java @@ -251,43 +251,42 @@ final class PriorityLoadBalancer extends LoadBalancer { * already exists. */ void updateResolvedAddresses() { - final ResolvedAddresses addresses = resolvedAddresses; - syncContext.execute( - new Runnable() { - @Override - public void run() { - PriorityLbConfig config = (PriorityLbConfig) addresses.getLoadBalancingPolicyConfig(); - PolicySelection childPolicySelection = config.childConfigs.get(priority); - LoadBalancerProvider lbProvider = childPolicySelection.getProvider(); - String newPolicy = lbProvider.getPolicyName(); - if (!newPolicy.equals(policy)) { - policy = newPolicy; - lb.switchTo(lbProvider); - } - lb.handleResolvedAddresses( - addresses - .toBuilder() - .setAddresses(AddressFilter.filter(addresses.getAddresses(), priority)) - .setLoadBalancingPolicyConfig(childPolicySelection.getConfig()) - .build()); - } - }); + PriorityLbConfig config = + (PriorityLbConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + PolicySelection childPolicySelection = config.childConfigs.get(priority); + LoadBalancerProvider lbProvider = childPolicySelection.getProvider(); + String newPolicy = lbProvider.getPolicyName(); + if (!newPolicy.equals(policy)) { + policy = newPolicy; + lb.switchTo(lbProvider); + } + lb.handleResolvedAddresses( + resolvedAddresses.toBuilder() + .setAddresses(AddressFilter.filter(resolvedAddresses.getAddresses(), priority)) + .setLoadBalancingPolicyConfig(childPolicySelection.getConfig()) + .build()); } final class ChildHelper extends ForwardingLoadBalancerHelper { @Override - public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { - connectivityState = newState; - picker = newPicker; - if (deletionTimer != null && deletionTimer.isPending()) { - return; - } - if (failOverTimer.isPending()) { - if (newState.equals(READY) || newState.equals(TRANSIENT_FAILURE)) { - failOverTimer.cancel(); + public void updateBalancingState(final ConnectivityState newState, + final SubchannelPicker newPicker) { + syncContext.execute(new Runnable() { + @Override + public void run() { + connectivityState = newState; + picker = newPicker; + if (deletionTimer != null && deletionTimer.isPending()) { + return; + } + if (failOverTimer.isPending()) { + if (newState.equals(READY) || newState.equals(TRANSIENT_FAILURE)) { + failOverTimer.cancel(); + } + } + tryNextPriority(true); } - } - tryNextPriority(true); + }); } @Override diff --git a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java index 741c59a6ff..75c11740a5 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java @@ -28,6 +28,7 @@ import io.grpc.ConnectivityState; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker; @@ -48,11 +49,13 @@ final class WeightedTargetLoadBalancer extends LoadBalancer { private final Map childBalancers = new HashMap<>(); private final Map childHelpers = new HashMap<>(); private final Helper helper; + private final SynchronizationContext syncContext; private Map targets = ImmutableMap.of(); WeightedTargetLoadBalancer(Helper helper) { - this.helper = helper; + this.helper = checkNotNull(helper, "helper"); + this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); logger = XdsLogger.withLogId( InternalLogId.allocate("weighted-target-lb", helper.getAuthority())); logger.log(XdsLogLevel.INFO, "Created"); @@ -63,10 +66,8 @@ final class WeightedTargetLoadBalancer extends LoadBalancer { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); Object lbConfig = resolvedAddresses.getLoadBalancingPolicyConfig(); checkNotNull(lbConfig, "missing weighted_target lb config"); - WeightedTargetConfig weightedTargetConfig = (WeightedTargetConfig) lbConfig; Map newTargets = weightedTargetConfig.targets; - for (String targetName : newTargets.keySet()) { WeightedPolicySelection weightedChildLbConfig = newTargets.get(targetName); if (!targets.containsKey(targetName)) { @@ -81,9 +82,7 @@ final class WeightedTargetLoadBalancer extends LoadBalancer { .switchTo(weightedChildLbConfig.policySelection.getProvider()); } } - targets = newTargets; - for (String targetName : targets.keySet()) { childBalancers.get(targetName).handleResolvedAddresses( resolvedAddresses.toBuilder() @@ -101,6 +100,7 @@ final class WeightedTargetLoadBalancer extends LoadBalancer { } childBalancers.keySet().retainAll(targets.keySet()); childHelpers.keySet().retainAll(targets.keySet()); + updateOverallBalancingState(); } @Override @@ -180,10 +180,16 @@ final class WeightedTargetLoadBalancer extends LoadBalancer { SubchannelPicker currentPicker = BUFFER_PICKER; @Override - public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { - currentState = newState; - currentPicker = newPicker; - updateOverallBalancingState(); + public void updateBalancingState(final ConnectivityState newState, + final SubchannelPicker newPicker) { + syncContext.execute(new Runnable() { + @Override + public void run() { + currentState = newState; + currentPicker = newPicker; + updateOverallBalancingState(); + } + }); } @Override diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index 2dc47fe0c1..d7edcc8110 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -41,6 +41,7 @@ import io.grpc.Metadata; import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.SynchronizationContext; import io.grpc.internal.ObjectPool; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; @@ -86,6 +87,13 @@ public class ClusterImplLoadBalancerTest { private static final String CLUSTER = "cluster-foo.googleapis.com"; private static final String EDS_SERVICE_NAME = "service.googleapis.com"; private static final String LRS_SERVER_NAME = ""; + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); private final Locality locality = new Locality("test-region", "test-zone", "test-subzone"); private final PolicySelection roundRobin = @@ -583,6 +591,12 @@ public class ClusterImplLoadBalancerTest { } private final class FakeLbHelper extends LoadBalancer.Helper { + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + @Override public void updateBalancingState( @Nonnull ConnectivityState newState, @Nonnull SubchannelPicker newPicker) { diff --git a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java index 55ecfcd41d..f7f3d6c77b 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java @@ -23,18 +23,18 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.xds.XdsSubchannelPickers.BUFFER_PICKER; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.grpc.Attributes; -import io.grpc.ChannelLogger; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; @@ -45,6 +45,7 @@ import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; @@ -68,6 +69,13 @@ import org.mockito.MockitoAnnotations; @RunWith(JUnit4.class) public class WeightedTargetLoadBalancerTest { + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); private final List childBalancers = new ArrayList<>(); private final List childHelpers = new ArrayList<>(); @@ -143,8 +151,6 @@ public class WeightedTargetLoadBalancerTest { @Mock private Helper helper; - @Mock - private ChannelLogger channelLogger; private LoadBalancer weightedTargetLb; private int fooLbCreated; @@ -153,8 +159,7 @@ public class WeightedTargetLoadBalancerTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - - doReturn(channelLogger).when(helper).getChannelLogger(); + when(helper.getSynchronizationContext()).thenReturn(syncContext); lbRegistry.register(fooLbProvider); lbRegistry.register(barLbProvider); @@ -198,7 +203,7 @@ public class WeightedTargetLoadBalancerTest { .setAttributes(Attributes.newBuilder().set(fakeKey, fakeValue).build()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) .build()); - + verify(helper).updateBalancingState(eq(CONNECTING), eq(BUFFER_PICKER)); assertThat(childBalancers).hasSize(4); assertThat(childHelpers).hasSize(4); assertThat(fooLbCreated).isEqualTo(2); @@ -235,7 +240,7 @@ public class WeightedTargetLoadBalancerTest { .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(newTargets)) .build()); - + verify(helper, atLeast(2)).updateBalancingState(eq(CONNECTING), eq(BUFFER_PICKER)); assertThat(childBalancers).hasSize(5); assertThat(childHelpers).hasSize(5); assertThat(fooLbCreated).isEqualTo(3); // One more foo LB created for target4 @@ -277,6 +282,7 @@ public class WeightedTargetLoadBalancerTest { .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) .build()); + verify(helper).updateBalancingState(eq(CONNECTING), eq(BUFFER_PICKER)); // Error after child balancers created. weightedTargetLb.handleNameResolutionError(Status.ABORTED); @@ -303,6 +309,7 @@ public class WeightedTargetLoadBalancerTest { .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) .build()); + verify(helper).updateBalancingState(eq(CONNECTING), eq(BUFFER_PICKER)); // Subchannels to be created for each child balancer. final SubchannelPicker[] subchannelPickers = new SubchannelPicker[]{ @@ -316,7 +323,7 @@ public class WeightedTargetLoadBalancerTest { childHelpers.get(1).updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.ABORTED)); verify(helper, never()).updateBalancingState( eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); - verify(helper).updateBalancingState(eq(CONNECTING), eq(BUFFER_PICKER)); + verify(helper, times(2)).updateBalancingState(eq(CONNECTING), eq(BUFFER_PICKER)); // Another child balancer goes to READY. childHelpers.get(2).updateBalancingState(READY, subchannelPickers[2]);