diff --git a/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java b/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java index a1f1d5b017..8f2c6bc7a4 100644 --- a/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java +++ b/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java @@ -121,7 +121,7 @@ final class HealthCheckingLoadBalancerFactory extends Factory { HealthCheckState hcState = new HealthCheckState( this, originalSubchannel, syncContext, delegate.getScheduledExecutorService()); hcStates.add(hcState); - Subchannel subchannel = new SubchannelImpl(originalSubchannel, hcState); + Subchannel subchannel = new SubchannelImpl(originalSubchannel, this, hcState); if (healthCheckedService != null) { hcState.setServiceName(healthCheckedService); } @@ -144,10 +144,12 @@ final class HealthCheckingLoadBalancerFactory extends Factory { @VisibleForTesting static final class SubchannelImpl extends ForwardingSubchannel { final Subchannel delegate; + final HelperImpl helperImpl; final HealthCheckState hcState; - SubchannelImpl(Subchannel delegate, HealthCheckState hcState) { + SubchannelImpl(Subchannel delegate, HelperImpl helperImpl, HealthCheckState hcState) { this.delegate = checkNotNull(delegate, "delegate"); + this.helperImpl = checkNotNull(helperImpl, "helperImpl"); this.hcState = checkNotNull(hcState, "hcState"); } @@ -161,6 +163,13 @@ final class HealthCheckingLoadBalancerFactory extends Factory { hcState.init(listener); delegate().start(hcState); } + + @Override + public void shutdown() { + helperImpl.getSynchronizationContext().throwIfNotInThisSynchronizationContext(); + delegate().shutdown(); + helperImpl.hcStates.remove(hcState); + } } private static final class HealthCheckingLoadBalancer extends ForwardingLoadBalancer { @@ -282,9 +291,6 @@ final class HealthCheckingLoadBalancerFactory extends Factory { // may be available on the new connection. disabled = false; } - if (Objects.equal(rawState.getState(), SHUTDOWN)) { - helperImpl.hcStates.remove(this); - } this.rawState = rawState; adjustHealthCheck(); } diff --git a/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java b/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java index 20747bf1b4..e9096612f8 100644 --- a/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java +++ b/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java @@ -258,6 +258,7 @@ public class HealthCheckingLoadBalancerFactoryTest { verify(origHelper, atLeast(0)).getScheduledExecutorService(); verifyNoMoreInteractions(origHelper); verifyNoMoreInteractions(origLb); + Subchannel[] wrappedSubchannels = new Subchannel[NUM_SUBCHANNELS]; // Simulate that the orignal LB creates Subchannels for (int i = 0; i < NUM_SUBCHANNELS; i++) { @@ -265,8 +266,8 @@ public class HealthCheckingLoadBalancerFactoryTest { String subchannelAttrValue = "eag attr " + i; Attributes attrs = Attributes.newBuilder() .set(SUBCHANNEL_ATTR_KEY, subchannelAttrValue).build(); - // We don't wrap Subchannels, thus origLb gets the original Subchannels. - assertThat(unwrap(createSubchannel(i, attrs))).isSameInstanceAs(subchannels[i]); + wrappedSubchannels[i] = createSubchannel(i, attrs); + assertThat(unwrap(wrappedSubchannels[i])).isSameInstanceAs(subchannels[i]); verify(origHelper, times(i + 1)).createSubchannel(createArgsCaptor.capture()); assertThat(createArgsCaptor.getValue().getAddresses()).isEqualTo(eagLists[i]); assertThat(createArgsCaptor.getValue().getAttributes().get(SUBCHANNEL_ATTR_KEY)) @@ -340,9 +341,17 @@ public class HealthCheckingLoadBalancerFactoryTest { assertThat(serverCall.cancelled).isFalse(); verifyNoMoreInteractions(mockStateListener); + assertThat(subchannels[i].isShutdown).isFalse(); + final Subchannel wrappedSubchannel = wrappedSubchannels[i]; // Subchannel enters SHUTDOWN state as a response to shutdown(), and that will cancel the // health check RPC - subchannel.shutdown(); + syncContext.execute(new Runnable() { + @Override + public void run() { + wrappedSubchannel.shutdown(); + } + }); + assertThat(subchannels[i].isShutdown).isTrue(); assertThat(serverCall.cancelled).isTrue(); verify(mockStateListener).onSubchannelState( eq(ConnectivityStateInfo.forNonError(SHUTDOWN))); @@ -1004,34 +1013,38 @@ public class HealthCheckingLoadBalancerFactoryTest { verify(origLb).handleResolvedAddresses(result); verifyNoMoreInteractions(origLb); + ServerSideCall[] serverCalls = new ServerSideCall[NUM_SUBCHANNELS]; - Subchannel subchannel = createSubchannel(0, Attributes.EMPTY); - SubchannelStateListener mockListener = mockStateListeners[0]; - assertThat(unwrap(subchannel)).isSameInstanceAs(subchannels[0]); + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + Subchannel subchannel = createSubchannel(i, Attributes.EMPTY); + SubchannelStateListener mockListener = mockStateListeners[i]; + assertThat(unwrap(subchannel)).isSameInstanceAs(subchannels[i]); - // Trigger the health check - deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + // Trigger the health check + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(READY)); - HealthImpl healthImpl = healthImpls[0]; - assertThat(healthImpl.calls).hasSize(1); - ServerSideCall serverCall = healthImpl.calls.poll(); - assertThat(serverCall.cancelled).isFalse(); + HealthImpl healthImpl = healthImpls[i]; + assertThat(healthImpl.calls).hasSize(1); + serverCalls[i] = healthImpl.calls.poll(); + assertThat(serverCalls[i].cancelled).isFalse(); - verify(mockListener).onSubchannelState( - eq(ConnectivityStateInfo.forNonError(CONNECTING))); + verify(mockListener).onSubchannelState( + eq(ConnectivityStateInfo.forNonError(CONNECTING))); + } // Shut down the balancer hcLbEventDelivery.shutdown(); verify(origLb).shutdown(); // Health check stream should be cancelled - assertThat(serverCall.cancelled).isTrue(); + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + assertThat(serverCalls[i].cancelled).isTrue(); + // LoadBalancer API requires no more callbacks on LoadBalancer after shutdown() is called. + verifyNoMoreInteractions(origLb, mockStateListeners[i]); + // No more health check call is made or scheduled + assertThat(healthImpls[i].calls).isEmpty(); + } - // LoadBalancer API requires no more callbacks on LoadBalancer after shutdown() is called. - verifyNoMoreInteractions(origLb, mockListener); - - // No more health check call is made or scheduled - assertThat(healthImpl.calls).isEmpty(); assertThat(clock.getPendingTasks()).isEmpty(); } @@ -1156,6 +1169,7 @@ public class HealthCheckingLoadBalancerFactoryTest { final ArrayList logs = new ArrayList<>(); final int index; SubchannelStateListener listener; + boolean isShutdown; private final ChannelLogger logger = new ChannelLogger() { @Override public void log(ChannelLogLevel level, String msg) { @@ -1183,6 +1197,7 @@ public class HealthCheckingLoadBalancerFactoryTest { @Override public void shutdown() { + isShutdown = true; deliverSubchannelState(index, ConnectivityStateInfo.forNonError(SHUTDOWN)); }