services: fix HealthCheckingLoadBalancer.shutdown() (#5887)

The issue: HealthCheckingLoadBalancer.shutdown() calls
hcState.onSubchannelState(SHUTDOWN) which removes that hcState from
helper.hcStates. Therefore, if more than one Subchannels are present,
ConcurrentModificationException will be thrown.

This is an alternative approach from #5848 that was reverted in #5875. Thanks to #5883, HealthCheckingLoadBalancer.shutdown() no longer has to fake SHUTDOWN notifications, and can completely rely on Subchannels' real SHUTDOWN notifications for triggering the clean-up.
This commit is contained in:
Kun Zhang 2019-06-14 16:47:17 -07:00 committed by GitHub
parent fda406b0ff
commit ff33ecd339
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 35 deletions

View File

@ -98,7 +98,6 @@ final class HealthCheckingLoadBalancerFactory extends Factory {
private final SynchronizationContext syncContext; private final SynchronizationContext syncContext;
@Nullable String healthCheckedService; @Nullable String healthCheckedService;
private boolean balancerShutdown;
final HashSet<HealthCheckState> hcStates = new HashSet<>(); final HashSet<HealthCheckState> hcStates = new HashSet<>();
@ -190,19 +189,6 @@ final class HealthCheckingLoadBalancerFactory extends Factory {
super.handleResolvedAddresses(resolvedAddresses); super.handleResolvedAddresses(resolvedAddresses);
} }
@Override
public void shutdown() {
super.shutdown();
helper.balancerShutdown = true;
for (HealthCheckState hcState : helper.hcStates) {
// ManagedChannel will stop calling onSubchannelState() after shutdown() is called,
// which is required by LoadBalancer API semantics. We need to deliver the final SHUTDOWN
// signal to health checkers so that they can cancel the streams.
hcState.onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN));
}
helper.hcStates.clear();
}
@Override @Override
public String toString() { public String toString() {
return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString();
@ -341,7 +327,7 @@ final class HealthCheckingLoadBalancerFactory extends Factory {
private void gotoState(ConnectivityStateInfo newState) { private void gotoState(ConnectivityStateInfo newState) {
checkState(subchannel != null, "init() not called"); checkState(subchannel != null, "init() not called");
if (!helperImpl.balancerShutdown && !Objects.equal(concludedState, newState)) { if (!Objects.equal(concludedState, newState)) {
concludedState = newState; concludedState = newState;
stateListener.onSubchannelState(concludedState); stateListener.onSubchannelState(concludedState);
} }

View File

@ -29,6 +29,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same; import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
@ -95,6 +96,8 @@ import org.mockito.InOrder;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.hamcrest.MockitoHamcrest; import org.mockito.hamcrest.MockitoHamcrest;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** Tests for {@link HealthCheckingLoadBalancerFactory}. */ /** Tests for {@link HealthCheckingLoadBalancerFactory}. */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
@ -258,6 +261,7 @@ public class HealthCheckingLoadBalancerFactoryTest {
verify(origHelper, atLeast(0)).getScheduledExecutorService(); verify(origHelper, atLeast(0)).getScheduledExecutorService();
verifyNoMoreInteractions(origHelper); verifyNoMoreInteractions(origHelper);
verifyNoMoreInteractions(origLb); verifyNoMoreInteractions(origLb);
Subchannel[] wrappedSubchannels = new Subchannel[NUM_SUBCHANNELS];
// Simulate that the orignal LB creates Subchannels // Simulate that the orignal LB creates Subchannels
for (int i = 0; i < NUM_SUBCHANNELS; i++) { for (int i = 0; i < NUM_SUBCHANNELS; i++) {
@ -265,8 +269,8 @@ public class HealthCheckingLoadBalancerFactoryTest {
String subchannelAttrValue = "eag attr " + i; String subchannelAttrValue = "eag attr " + i;
Attributes attrs = Attributes.newBuilder() Attributes attrs = Attributes.newBuilder()
.set(SUBCHANNEL_ATTR_KEY, subchannelAttrValue).build(); .set(SUBCHANNEL_ATTR_KEY, subchannelAttrValue).build();
// We don't wrap Subchannels, thus origLb gets the original Subchannels. wrappedSubchannels[i] = createSubchannel(i, attrs);
assertThat(unwrap(createSubchannel(i, attrs))).isSameInstanceAs(subchannels[i]); assertThat(unwrap(wrappedSubchannels[i])).isSameInstanceAs(subchannels[i]);
verify(origHelper, times(i + 1)).createSubchannel(createArgsCaptor.capture()); verify(origHelper, times(i + 1)).createSubchannel(createArgsCaptor.capture());
assertThat(createArgsCaptor.getValue().getAddresses()).isEqualTo(eagLists[i]); assertThat(createArgsCaptor.getValue().getAddresses()).isEqualTo(eagLists[i]);
assertThat(createArgsCaptor.getValue().getAttributes().get(SUBCHANNEL_ATTR_KEY)) assertThat(createArgsCaptor.getValue().getAttributes().get(SUBCHANNEL_ATTR_KEY))
@ -340,9 +344,17 @@ public class HealthCheckingLoadBalancerFactoryTest {
assertThat(serverCall.cancelled).isFalse(); assertThat(serverCall.cancelled).isFalse();
verifyNoMoreInteractions(mockStateListener); verifyNoMoreInteractions(mockStateListener);
assertThat(subchannels[i].isShutdown).isFalse();
final Subchannel wrappedSubchannel = wrappedSubchannels[i];
// Subchannel enters SHUTDOWN state as a response to shutdown(), and that will cancel the // Subchannel enters SHUTDOWN state as a response to shutdown(), and that will cancel the
// health check RPC // health check RPC
subchannel.shutdown(); syncContext.execute(new Runnable() {
@Override
public void run() {
wrappedSubchannel.shutdown();
}
});
assertThat(subchannels[i].isShutdown).isTrue();
assertThat(serverCall.cancelled).isTrue(); assertThat(serverCall.cancelled).isTrue();
verify(mockStateListener).onSubchannelState( verify(mockStateListener).onSubchannelState(
eq(ConnectivityStateInfo.forNonError(SHUTDOWN))); eq(ConnectivityStateInfo.forNonError(SHUTDOWN)));
@ -1004,34 +1016,51 @@ public class HealthCheckingLoadBalancerFactoryTest {
verify(origLb).handleResolvedAddresses(result); verify(origLb).handleResolvedAddresses(result);
verifyNoMoreInteractions(origLb); verifyNoMoreInteractions(origLb);
ServerSideCall[] serverCalls = new ServerSideCall[NUM_SUBCHANNELS];
Subchannel subchannel = createSubchannel(0, Attributes.EMPTY); final Subchannel[] wrappedSubchannels = new Subchannel[NUM_SUBCHANNELS];
SubchannelStateListener mockListener = mockStateListeners[0];
assertThat(unwrap(subchannel)).isSameInstanceAs(subchannels[0]); for (int i = 0; i < NUM_SUBCHANNELS; i++) {
Subchannel subchannel = createSubchannel(i, Attributes.EMPTY);
wrappedSubchannels[i] = subchannel;
SubchannelStateListener mockListener = mockStateListeners[i];
assertThat(unwrap(subchannel)).isSameInstanceAs(subchannels[i]);
// Trigger the health check // Trigger the health check
deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); deliverSubchannelState(i, ConnectivityStateInfo.forNonError(READY));
HealthImpl healthImpl = healthImpls[0]; HealthImpl healthImpl = healthImpls[i];
assertThat(healthImpl.calls).hasSize(1); assertThat(healthImpl.calls).hasSize(1);
ServerSideCall serverCall = healthImpl.calls.poll(); serverCalls[i] = healthImpl.calls.poll();
assertThat(serverCall.cancelled).isFalse(); assertThat(serverCalls[i].cancelled).isFalse();
verify(mockListener).onSubchannelState( verify(mockListener).onSubchannelState(
eq(ConnectivityStateInfo.forNonError(CONNECTING))); eq(ConnectivityStateInfo.forNonError(CONNECTING)));
}
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
for (int i = 0; i < NUM_SUBCHANNELS; i++) {
wrappedSubchannels[i].shutdown();
}
return null;
}
}).when(origLb).shutdown();
// Shut down the balancer // Shut down the balancer
hcLbEventDelivery.shutdown(); hcLbEventDelivery.shutdown();
verify(origLb).shutdown(); verify(origLb).shutdown();
// Health check stream should be cancelled // Health check stream should be cancelled
assertThat(serverCall.cancelled).isTrue(); for (int i = 0; i < NUM_SUBCHANNELS; i++) {
assertThat(serverCalls[i].cancelled).isTrue();
// LoadBalancer API requires no more callbacks on LoadBalancer after shutdown() is called. verifyNoMoreInteractions(origLb);
verifyNoMoreInteractions(origLb, mockListener); verify(mockStateListeners[i]).onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN));
// No more health check call is made or scheduled // No more health check call is made or scheduled
assertThat(healthImpl.calls).isEmpty(); assertThat(healthImpls[i].calls).isEmpty();
}
assertThat(clock.getPendingTasks()).isEmpty(); assertThat(clock.getPendingTasks()).isEmpty();
} }
@ -1156,6 +1185,7 @@ public class HealthCheckingLoadBalancerFactoryTest {
final ArrayList<String> logs = new ArrayList<>(); final ArrayList<String> logs = new ArrayList<>();
final int index; final int index;
SubchannelStateListener listener; SubchannelStateListener listener;
boolean isShutdown;
private final ChannelLogger logger = new ChannelLogger() { private final ChannelLogger logger = new ChannelLogger() {
@Override @Override
public void log(ChannelLogLevel level, String msg) { public void log(ChannelLogLevel level, String msg) {
@ -1183,6 +1213,7 @@ public class HealthCheckingLoadBalancerFactoryTest {
@Override @Override
public void shutdown() { public void shutdown() {
isShutdown = true;
deliverSubchannelState(index, ConnectivityStateInfo.forNonError(SHUTDOWN)); deliverSubchannelState(index, ConnectivityStateInfo.forNonError(SHUTDOWN));
} }