From ff33ecd339eacdf3937313015bdddfa4f95dd5c6 Mon Sep 17 00:00:00 2001 From: Kun Zhang Date: Fri, 14 Jun 2019 16:47:17 -0700 Subject: [PATCH] services: fix HealthCheckingLoadBalancer.shutdown() (#5887) The issue: HealthCheckingLoadBalancer.shutdown() calls hcState.onSubchannelState(SHUTDOWN) which removes that hcState from helper.hcStates. Therefore, if more than one Subchannels are present, ConcurrentModificationException will be thrown. This is an alternative approach from #5848 that was reverted in #5875. Thanks to #5883, HealthCheckingLoadBalancer.shutdown() no longer has to fake SHUTDOWN notifications, and can completely rely on Subchannels' real SHUTDOWN notifications for triggering the clean-up. --- .../HealthCheckingLoadBalancerFactory.java | 16 +---- ...HealthCheckingLoadBalancerFactoryTest.java | 71 +++++++++++++------ 2 files changed, 52 insertions(+), 35 deletions(-) diff --git a/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java b/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java index a1f1d5b017..9fc9fd1da3 100644 --- a/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java +++ b/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java @@ -98,7 +98,6 @@ final class HealthCheckingLoadBalancerFactory extends Factory { private final SynchronizationContext syncContext; @Nullable String healthCheckedService; - private boolean balancerShutdown; final HashSet hcStates = new HashSet<>(); @@ -190,19 +189,6 @@ final class HealthCheckingLoadBalancerFactory extends Factory { super.handleResolvedAddresses(resolvedAddresses); } - @Override - public void shutdown() { - super.shutdown(); - helper.balancerShutdown = true; - for (HealthCheckState hcState : helper.hcStates) { - // ManagedChannel will stop calling onSubchannelState() after shutdown() is called, - // which is required by LoadBalancer API semantics. We need to deliver the final SHUTDOWN - // signal to health checkers so that they can cancel the streams. - hcState.onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); - } - helper.hcStates.clear(); - } - @Override public String toString() { return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); @@ -341,7 +327,7 @@ final class HealthCheckingLoadBalancerFactory extends Factory { private void gotoState(ConnectivityStateInfo newState) { checkState(subchannel != null, "init() not called"); - if (!helperImpl.balancerShutdown && !Objects.equal(concludedState, newState)) { + if (!Objects.equal(concludedState, newState)) { concludedState = newState; stateListener.onSubchannelState(concludedState); } diff --git a/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java b/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java index 20747bf1b4..734dfd41e4 100644 --- a/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java +++ b/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java @@ -29,6 +29,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -95,6 +96,8 @@ import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.mockito.hamcrest.MockitoHamcrest; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; /** Tests for {@link HealthCheckingLoadBalancerFactory}. */ @RunWith(JUnit4.class) @@ -258,6 +261,7 @@ public class HealthCheckingLoadBalancerFactoryTest { verify(origHelper, atLeast(0)).getScheduledExecutorService(); verifyNoMoreInteractions(origHelper); verifyNoMoreInteractions(origLb); + Subchannel[] wrappedSubchannels = new Subchannel[NUM_SUBCHANNELS]; // Simulate that the orignal LB creates Subchannels for (int i = 0; i < NUM_SUBCHANNELS; i++) { @@ -265,8 +269,8 @@ public class HealthCheckingLoadBalancerFactoryTest { String subchannelAttrValue = "eag attr " + i; Attributes attrs = Attributes.newBuilder() .set(SUBCHANNEL_ATTR_KEY, subchannelAttrValue).build(); - // We don't wrap Subchannels, thus origLb gets the original Subchannels. - assertThat(unwrap(createSubchannel(i, attrs))).isSameInstanceAs(subchannels[i]); + wrappedSubchannels[i] = createSubchannel(i, attrs); + assertThat(unwrap(wrappedSubchannels[i])).isSameInstanceAs(subchannels[i]); verify(origHelper, times(i + 1)).createSubchannel(createArgsCaptor.capture()); assertThat(createArgsCaptor.getValue().getAddresses()).isEqualTo(eagLists[i]); assertThat(createArgsCaptor.getValue().getAttributes().get(SUBCHANNEL_ATTR_KEY)) @@ -340,9 +344,17 @@ public class HealthCheckingLoadBalancerFactoryTest { assertThat(serverCall.cancelled).isFalse(); verifyNoMoreInteractions(mockStateListener); + assertThat(subchannels[i].isShutdown).isFalse(); + final Subchannel wrappedSubchannel = wrappedSubchannels[i]; // Subchannel enters SHUTDOWN state as a response to shutdown(), and that will cancel the // health check RPC - subchannel.shutdown(); + syncContext.execute(new Runnable() { + @Override + public void run() { + wrappedSubchannel.shutdown(); + } + }); + assertThat(subchannels[i].isShutdown).isTrue(); assertThat(serverCall.cancelled).isTrue(); verify(mockStateListener).onSubchannelState( eq(ConnectivityStateInfo.forNonError(SHUTDOWN))); @@ -1004,34 +1016,51 @@ public class HealthCheckingLoadBalancerFactoryTest { verify(origLb).handleResolvedAddresses(result); verifyNoMoreInteractions(origLb); + ServerSideCall[] serverCalls = new ServerSideCall[NUM_SUBCHANNELS]; - Subchannel subchannel = createSubchannel(0, Attributes.EMPTY); - SubchannelStateListener mockListener = mockStateListeners[0]; - assertThat(unwrap(subchannel)).isSameInstanceAs(subchannels[0]); + final Subchannel[] wrappedSubchannels = new Subchannel[NUM_SUBCHANNELS]; - // Trigger the health check - deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + Subchannel subchannel = createSubchannel(i, Attributes.EMPTY); + wrappedSubchannels[i] = subchannel; + SubchannelStateListener mockListener = mockStateListeners[i]; + assertThat(unwrap(subchannel)).isSameInstanceAs(subchannels[i]); - HealthImpl healthImpl = healthImpls[0]; - assertThat(healthImpl.calls).hasSize(1); - ServerSideCall serverCall = healthImpl.calls.poll(); - assertThat(serverCall.cancelled).isFalse(); + // Trigger the health check + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(READY)); - verify(mockListener).onSubchannelState( - eq(ConnectivityStateInfo.forNonError(CONNECTING))); + HealthImpl healthImpl = healthImpls[i]; + assertThat(healthImpl.calls).hasSize(1); + serverCalls[i] = healthImpl.calls.poll(); + assertThat(serverCalls[i].cancelled).isFalse(); + + verify(mockListener).onSubchannelState( + eq(ConnectivityStateInfo.forNonError(CONNECTING))); + } + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + wrappedSubchannels[i].shutdown(); + } + return null; + } + }).when(origLb).shutdown(); // Shut down the balancer hcLbEventDelivery.shutdown(); verify(origLb).shutdown(); // Health check stream should be cancelled - assertThat(serverCall.cancelled).isTrue(); + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + assertThat(serverCalls[i].cancelled).isTrue(); + verifyNoMoreInteractions(origLb); + verify(mockStateListeners[i]).onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); + // No more health check call is made or scheduled + assertThat(healthImpls[i].calls).isEmpty(); + } - // LoadBalancer API requires no more callbacks on LoadBalancer after shutdown() is called. - verifyNoMoreInteractions(origLb, mockListener); - - // No more health check call is made or scheduled - assertThat(healthImpl.calls).isEmpty(); assertThat(clock.getPendingTasks()).isEmpty(); } @@ -1156,6 +1185,7 @@ public class HealthCheckingLoadBalancerFactoryTest { final ArrayList logs = new ArrayList<>(); final int index; SubchannelStateListener listener; + boolean isShutdown; private final ChannelLogger logger = new ChannelLogger() { @Override public void log(ChannelLogLevel level, String msg) { @@ -1183,6 +1213,7 @@ public class HealthCheckingLoadBalancerFactoryTest { @Override public void shutdown() { + isShutdown = true; deliverSubchannelState(index, ConnectivityStateInfo.forNonError(SHUTDOWN)); }