From 0c17c4c995fcd4b8e37250223120f72223fe150a Mon Sep 17 00:00:00 2001 From: Kun Zhang Date: Thu, 9 May 2019 18:13:46 -0700 Subject: [PATCH] api: make LoadBalancer.Helper and Subchannel further non-thread-safe. (#5718) I see more cases of wrapping Helper and Subchannel during the work of XdsLoadBalancer, we will require that all methods that involve mutable state to be called from the Synchronization Context. We will start logging warnings first, and make them throw in a future release. Helper.createSubchannel() is already doing so. This change adds warnings to the other eligible methods. https://github.com/grpc/grpc-java/issues/5015 --- api/src/main/java/io/grpc/LoadBalancer.java | 32 +++ .../io/grpc/internal/ManagedChannelImpl.java | 26 ++- .../ManagedChannelImplIdlenessTest.java | 53 ++++- .../grpc/internal/ManagedChannelImplTest.java | 210 +++++++++++------- 4 files changed, 224 insertions(+), 97 deletions(-) diff --git a/api/src/main/java/io/grpc/LoadBalancer.java b/api/src/main/java/io/grpc/LoadBalancer.java index 36568a85f4..479ca4b2a8 100644 --- a/api/src/main/java/io/grpc/LoadBalancer.java +++ b/api/src/main/java/io/grpc/LoadBalancer.java @@ -694,6 +694,10 @@ public abstract class LoadBalancer { * Equivalent to {@link #updateSubchannelAddresses(io.grpc.LoadBalancer.Subchannel, List)} with * the given single {@code EquivalentAddressGroup}. * + *

It should be called from the Synchronization Context. Currently will log a warning if + * violated. It will become an exception eventually. See #5015 for the background. + * * @since 1.4.0 */ public final void updateSubchannelAddresses( @@ -707,6 +711,10 @@ public abstract class LoadBalancer { * {@link #createSubchannel} when the new and old addresses overlap, since the subchannel can * continue using an existing connection. * + *

It should be called from the Synchronization Context. Currently will log a warning if + * violated. It will become an exception eventually. See #5015 for the background. + * * @throws IllegalArgumentException if {@code subchannel} was not returned from {@link * #createSubchannel} or {@code addrs} is empty * @since 1.14.0 @@ -776,6 +784,10 @@ public abstract class LoadBalancer { * updateBalancingState()} has never been called, the channel will buffer all RPCs until a * picker is provided. * + *

It should be called from the Synchronization Context. Currently will log a warning if + * violated. It will become an exception eventually. See #5015 for the background. + * *

The passed state will be the channel's new state. The SHUTDOWN state should not be passed * and its behavior is undefined. * @@ -787,6 +799,10 @@ public abstract class LoadBalancer { /** * Call {@link NameResolver#refresh} on the channel's resolver. * + *

It should be called from the Synchronization Context. Currently will log a warning if + * violated. It will become an exception eventually. See #5015 for the background. + * * @since 1.18.0 */ public void refreshNameResolution() { @@ -903,6 +919,10 @@ public abstract class LoadBalancer { * Shuts down the Subchannel. After this method is called, this Subchannel should no longer * be returned by the latest {@link SubchannelPicker picker}, and can be safely discarded. * + *

It should be called from the Synchronization Context. Currently will log a warning if + * violated. It will become an exception eventually. See #5015 for the background. + * * @since 1.2.0 */ public abstract void shutdown(); @@ -910,6 +930,10 @@ public abstract class LoadBalancer { /** * Asks the Subchannel to create a connection (aka transport), if there isn't an active one. * + *

It should be called from the Synchronization Context. Currently will log a warning if + * violated. It will become an exception eventually. See #5015 for the background. + * * @since 1.2.0 */ public abstract void requestConnection(); @@ -919,6 +943,10 @@ public abstract class LoadBalancer { * the Subchannel has only one {@link EquivalentAddressGroup}. Under the hood it calls * {@link #getAllAddresses}. * + *

It should be called from the Synchronization Context. Currently will log a warning if + * violated. It will become an exception eventually. See #5015 for the background. + * * @throws IllegalStateException if this subchannel has more than one EquivalentAddressGroup. * Use {@link #getAllAddresses} instead * @since 1.2.0 @@ -932,6 +960,10 @@ public abstract class LoadBalancer { /** * Returns the addresses that this Subchannel is bound to. The returned list will not be empty. * + *

It should be called from the Synchronization Context. Currently will log a warning if + * violated. It will become an exception eventually. See #5015 for the background. + * * @since 1.14.0 */ public List getAllAddresses() { diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index d8a97a202a..ae7303697a 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -1048,14 +1048,7 @@ final class ManagedChannelImpl extends ManagedChannel implements @Override public AbstractSubchannel createSubchannel( List addressGroups, Attributes attrs) { - try { - syncContext.throwIfNotInThisSynchronizationContext(); - } catch (IllegalStateException e) { - logger.log(Level.WARNING, - "We sugguest you call createSubchannel() from SynchronizationContext." - + " Otherwise, it may race with handleSubchannelState()." - + " See https://github.com/grpc/grpc-java/issues/5015", e); - } + logWarningIfNotInSyncContext("createSubchannel()"); checkNotNull(addressGroups, "addressGroups"); checkNotNull(attrs, "attrs"); // TODO(ejona): can we be even stricter? Like loadBalancer == null? @@ -1149,6 +1142,7 @@ final class ManagedChannelImpl extends ManagedChannel implements final ConnectivityState newState, final SubchannelPicker newPicker) { checkNotNull(newState, "newState"); checkNotNull(newPicker, "newPicker"); + logWarningIfNotInSyncContext("updateBalancingState()"); final class UpdateBalancingState implements Runnable { @Override public void run() { @@ -1170,6 +1164,7 @@ final class ManagedChannelImpl extends ManagedChannel implements @Override public void refreshNameResolution() { + logWarningIfNotInSyncContext("refreshNameResolution()"); final class LoadBalancerRefreshNameResolution implements Runnable { @Override public void run() { @@ -1185,6 +1180,7 @@ final class ManagedChannelImpl extends ManagedChannel implements LoadBalancer.Subchannel subchannel, List addrs) { checkArgument(subchannel instanceof SubchannelImpl, "subchannel must have been returned from createSubchannel"); + logWarningIfNotInSyncContext("updateSubchannelAddresses()"); ((SubchannelImpl) subchannel).subchannel.updateAddresses(addrs); } @@ -1478,6 +1474,7 @@ final class ManagedChannelImpl extends ManagedChannel implements @Override public void shutdown() { + logWarningIfNotInSyncContext("Subchannel.shutdown()"); synchronized (shutdownLock) { if (shutdownRequested) { if (terminating && delayedShutdownTask != null) { @@ -1521,11 +1518,13 @@ final class ManagedChannelImpl extends ManagedChannel implements @Override public void requestConnection() { + logWarningIfNotInSyncContext("Subchannel.requestConnection()"); subchannel.obtainActiveTransport(); } @Override public List getAllAddresses() { + logWarningIfNotInSyncContext("Subchannel.getAllAddresses()"); return subchannel.getAddressGroups(); } @@ -1784,4 +1783,15 @@ final class ManagedChannelImpl extends ManagedChannel implements } } } + + private void logWarningIfNotInSyncContext(String method) { + try { + syncContext.throwIfNotInThisSynchronizationContext(); + } catch (IllegalStateException e) { + logger.log(Level.WARNING, + method + " should be called from SynchronizationContext. " + + "This warning will become an exception in a future release. " + + "See https://github.com/grpc/grpc-java/issues/5015 for more details", e); + } + } } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java index d7ba03a9fc..f1f0d4e0d7 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java @@ -38,6 +38,7 @@ import io.grpc.CallOptions; import io.grpc.ChannelLogger; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; +import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; import io.grpc.IntegerMarshaller; import io.grpc.LoadBalancer; @@ -310,14 +311,14 @@ public class ManagedChannelImplIdlenessTest { // Assume LoadBalancer has received an address, then create a subchannel. Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); MockClientTransportInfo t0 = newTransports.poll(); t0.listener.transportReady(); SubchannelPicker mockPicker = mock(SubchannelPicker.class); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); // Delayed transport creates real streams in the app executor executor.runDueTasks(); @@ -350,13 +351,13 @@ public class ManagedChannelImplIdlenessTest { Helper helper = helperCaptor.getValue(); Subchannel subchannel = createSubchannelSafely(helper, servers.get(0), Attributes.EMPTY); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); MockClientTransportInfo t0 = newTransports.poll(); t0.listener.transportReady(); - helper.updateSubchannelAddresses(subchannel, servers.get(1)); + updateSubchannelAddressesSafely(helper, subchannel, servers.get(1)); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); MockClientTransportInfo t1 = newTransports.poll(); t1.listener.transportReady(); } @@ -370,15 +371,15 @@ public class ManagedChannelImplIdlenessTest { Helper helper = helperCaptor.getValue(); Subchannel subchannel = createSubchannelSafely(helper, servers.get(0), Attributes.EMPTY); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); MockClientTransportInfo t0 = newTransports.poll(); t0.listener.transportReady(); List changedList = new ArrayList<>(servers.get(0).getAddresses()); changedList.add(new FakeSocketAddress("aDifferentServer")); - helper.updateSubchannelAddresses(subchannel, new EquivalentAddressGroup(changedList)); + updateSubchannelAddressesSafely(helper, subchannel, new EquivalentAddressGroup(changedList)); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); assertNull(newTransports.poll()); } @@ -397,7 +398,7 @@ public class ManagedChannelImplIdlenessTest { SubchannelPicker failingPicker = mock(SubchannelPicker.class); when(failingPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withError(Status.UNAVAILABLE)); - helper.updateBalancingState(TRANSIENT_FAILURE, failingPicker); + updateBalancingStateSafely(helper, TRANSIENT_FAILURE, failingPicker); executor.runDueTasks(); verify(mockCallListener).onClose(same(Status.UNAVAILABLE), any(Metadata.class)); @@ -499,7 +500,7 @@ public class ManagedChannelImplIdlenessTest { } } - // We need this because createSubchannel() should be called from the SynchronizationContext + // Helper methods to call methods from SynchronizationContext private static Subchannel createSubchannelSafely( final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) { final AtomicReference resultCapture = new AtomicReference<>(); @@ -512,4 +513,36 @@ public class ManagedChannelImplIdlenessTest { }); return resultCapture.get(); } + + private static void requestConnectionSafely(Helper helper, final Subchannel subchannel) { + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + subchannel.requestConnection(); + } + }); + } + + private static void updateBalancingStateSafely( + final Helper helper, final ConnectivityState state, final SubchannelPicker picker) { + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + helper.updateBalancingState(state, picker); + } + }); + } + + private static void updateSubchannelAddressesSafely( + final Helper helper, final Subchannel subchannel, final EquivalentAddressGroup addrs) { + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + helper.updateSubchannelAddresses(subchannel, addrs); + } + }); + } } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 67f9273fe9..60c13d0b2c 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -360,7 +360,7 @@ public class ManagedChannelImplTest { LogRecord record = logRef.get(); assertThat(record.getLevel()).isEqualTo(Level.WARNING); assertThat(record.getMessage()).contains( - "We sugguest you call createSubchannel() from SynchronizationContext"); + "createSubchannel() should be called from SynchronizationContext"); assertThat(record.getThrown()).isInstanceOf(IllegalStateException.class); } finally { logger.removeHandler(handler); @@ -434,7 +434,7 @@ public class ManagedChannelImplTest { assertThat(getStats(channel).subchannels) .containsExactly(subchannel.getInternalSubchannel()); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); MockClientTransportInfo transportInfo = transports.poll(); assertNotNull(transportInfo); assertTrue(channelz.containsClientSocket(transportInfo.transport.getLogId())); @@ -445,7 +445,7 @@ public class ManagedChannelImplTest { // terminate subchannel assertTrue(channelz.containsSubchannel(subchannel.getInternalSubchannel().getLogId())); - subchannel.shutdown(); + shutdownSafely(helper, subchannel); timer.forwardTime(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS, TimeUnit.SECONDS); timer.runDueTasks(); assertFalse(channelz.containsSubchannel(subchannel.getInternalSubchannel().getLogId())); @@ -520,7 +520,7 @@ public class ManagedChannelImplTest { // Configure the picker so that first RPC goes to delayed transport, and second RPC goes to // real transport. Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); verify(mockTransportFactory) .newClientTransport( any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); @@ -539,7 +539,7 @@ public class ManagedChannelImplTest { when(mockPicker.pickSubchannel( new PickSubchannelArgsImpl(method, headers2, CallOptions.DEFAULT))).thenReturn( PickResult.withSubchannel(subchannel)); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); // First RPC, will be pending ClientCall call = channel.newCall(method, CallOptions.DEFAULT); @@ -598,7 +598,7 @@ public class ManagedChannelImplTest { SubchannelPicker picker2 = mock(SubchannelPicker.class); when(picker2.pickSubchannel(new PickSubchannelArgsImpl(method, headers, CallOptions.DEFAULT))) .thenReturn(PickResult.withSubchannel(subchannel)); - helper.updateBalancingState(READY, picker2); + updateBalancingStateSafely(helper, READY, picker2); executor.runDueTasks(); verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); verify(mockStream).start(any(ClientStreamListener.class)); @@ -616,7 +616,7 @@ public class ManagedChannelImplTest { verify(mockTransport, never()).shutdownNow(any(Status.class)); } // LoadBalancer should shutdown the subchannel - subchannel.shutdown(); + shutdownSafely(helper, subchannel); if (shutdownNow) { verify(mockTransport).shutdown(same(ManagedChannelImpl.SHUTDOWN_NOW_STATUS)); } else { @@ -660,8 +660,8 @@ public class ManagedChannelImplTest { Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - subchannel1.requestConnection(); - subchannel2.requestConnection(); + requestConnectionSafely(helper, subchannel1); + requestConnectionSafely(helper, subchannel2); verify(mockTransportFactory, times(2)) .newClientTransport( any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); @@ -729,7 +729,7 @@ public class ManagedChannelImplTest { verify(mockTransportFactory, never()) .newClientTransport( any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); verify(mockTransportFactory) .newClientTransport( any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); @@ -742,7 +742,7 @@ public class ManagedChannelImplTest { when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); assertEquals(0, callExecutor.numPendingTasks()); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); // Real streams are started in the call executor if they were previously buffered. assertEquals(1, callExecutor.runDueTasks()); @@ -763,7 +763,7 @@ public class ManagedChannelImplTest { transportListener.transportTerminated(); // Clean up as much as possible to allow the channel to terminate. - subchannel.shutdown(); + shutdownSafely(helper, subchannel); timer.forwardNanos( TimeUnit.SECONDS.toNanos(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS)); } @@ -843,7 +843,7 @@ public class ManagedChannelImplTest { Status status = Status.UNAVAILABLE.withDescription("for test"); when(picker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withDrop(status)); - helper.updateBalancingState(READY, picker); + updateBalancingStateSafely(helper, READY, picker); executor.runDueTasks(); verify(mockCallListener).onClose(same(status), any(Metadata.class)); @@ -987,7 +987,7 @@ public class ManagedChannelImplTest { Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); inOrder.verify(mockLoadBalancer).handleSubchannelState( same(subchannel), stateInfoCaptor.capture()); assertEquals(CONNECTING, stateInfoCaptor.getValue().getState()); @@ -1020,7 +1020,7 @@ public class ManagedChannelImplTest { assertEquals(READY, stateInfoCaptor.getValue().getState()); // A typical LoadBalancer will call this once the subchannel becomes READY - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); // Delayed transport uses the app executor to create real streams. executor.runDueTasks(); @@ -1069,7 +1069,7 @@ public class ManagedChannelImplTest { when(picker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(drop ? PickResult.withDrop(status) : PickResult.withError(status)); - helper.updateBalancingState(READY, picker); + updateBalancingStateSafely(helper, READY, picker); executor.runDueTasks(); if (shouldFail) { @@ -1137,7 +1137,7 @@ public class ManagedChannelImplTest { Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); inOrder.verify(mockLoadBalancer).handleSubchannelState( same(subchannel), stateInfoCaptor.capture()); @@ -1172,7 +1172,7 @@ public class ManagedChannelImplTest { SubchannelPicker picker2 = mock(SubchannelPicker.class); when(picker2.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withError(server2Error)); - helper.updateBalancingState(TRANSIENT_FAILURE, picker2); + updateBalancingStateSafely(helper, TRANSIENT_FAILURE, picker2); executor.runDueTasks(); // ... which fails the fail-fast call @@ -1193,14 +1193,24 @@ public class ManagedChannelImplTest { // createSubchannel() always return a new Subchannel Attributes attrs1 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr1").build(); Attributes attrs2 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr2").build(); - Subchannel sub1 = createSubchannelSafely(helper, addressGroup, attrs1); - Subchannel sub2 = createSubchannelSafely(helper, addressGroup, attrs2); + final Subchannel sub1 = createSubchannelSafely(helper, addressGroup, attrs1); + final Subchannel sub2 = createSubchannelSafely(helper, addressGroup, attrs2); assertNotSame(sub1, sub2); assertNotSame(attrs1, attrs2); assertSame(attrs1, sub1.getAttributes()); assertSame(attrs2, sub2.getAttributes()); - assertSame(addressGroup, sub1.getAddresses()); - assertSame(addressGroup, sub2.getAddresses()); + + final AtomicBoolean snippetPassed = new AtomicBoolean(false); + helper.getSynchronizationContext().execute(new Runnable() { + @Override + public void run() { + // getAddresses() must be called from sync context + assertSame(addressGroup, sub1.getAddresses()); + assertSame(addressGroup, sub2.getAddresses()); + snippetPassed.set(true); + } + }); + assertThat(snippetPassed.get()).isTrue(); // requestConnection() verify(mockTransportFactory, never()) @@ -1208,7 +1218,7 @@ public class ManagedChannelImplTest { any(SocketAddress.class), any(ClientTransportOptions.class), any(TransportLogger.class)); - sub1.requestConnection(); + requestConnectionSafely(helper, sub1); verify(mockTransportFactory) .newClientTransport( eq(socketAddress), @@ -1217,7 +1227,7 @@ public class ManagedChannelImplTest { MockClientTransportInfo transportInfo1 = transports.poll(); assertNotNull(transportInfo1); - sub2.requestConnection(); + requestConnectionSafely(helper, sub2); verify(mockTransportFactory, times(2)) .newClientTransport( eq(socketAddress), @@ -1226,17 +1236,17 @@ public class ManagedChannelImplTest { MockClientTransportInfo transportInfo2 = transports.poll(); assertNotNull(transportInfo2); - sub1.requestConnection(); - sub2.requestConnection(); + requestConnectionSafely(helper, sub1); + requestConnectionSafely(helper, sub2); // The subchannel doesn't matter since this isn't called verify(mockTransportFactory, times(2)) .newClientTransport( eq(socketAddress), eq(clientTransportOptions), isA(TransportLogger.class)); // shutdown() has a delay - sub1.shutdown(); + shutdownSafely(helper, sub1); timer.forwardTime(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS - 1, TimeUnit.SECONDS); - sub1.shutdown(); + shutdownSafely(helper, sub1); verify(transportInfo1.transport, never()).shutdown(any(Status.class)); timer.forwardTime(1, TimeUnit.SECONDS); verify(transportInfo1.transport).shutdown(same(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_STATUS)); @@ -1247,7 +1257,7 @@ public class ManagedChannelImplTest { verify(mockLoadBalancer).shutdown(); verify(transportInfo2.transport, never()).shutdown(any(Status.class)); - sub2.shutdown(); + shutdownSafely(helper, sub2); verify(transportInfo2.transport).shutdown(same(ManagedChannelImpl.SHUTDOWN_STATUS)); // Cleanup @@ -1263,8 +1273,8 @@ public class ManagedChannelImplTest { createChannel(); Subchannel sub1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel sub2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - sub1.requestConnection(); - sub2.requestConnection(); + requestConnectionSafely(helper, sub1); + requestConnectionSafely(helper, sub2); assertThat(transports).hasSize(2); MockClientTransportInfo ti1 = transports.poll(); @@ -1294,9 +1304,9 @@ public class ManagedChannelImplTest { channel.shutdown(); verify(mockLoadBalancer).shutdown(); - sub1.shutdown(); + shutdownSafely(helper, sub1); assertFalse(channel.isTerminated()); - sub2.shutdown(); + shutdownSafely(helper, sub2); assertTrue(channel.isTerminated()); verify(mockTransportFactory, never()) .newClientTransport( @@ -1499,7 +1509,7 @@ public class ManagedChannelImplTest { CallOptions callOptions = CallOptions.DEFAULT.withDeadlineAfter(5, TimeUnit.SECONDS); // Subchannel must be READY when creating the RPC. - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); verify(mockTransportFactory) .newClientTransport( any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); @@ -1524,7 +1534,7 @@ public class ManagedChannelImplTest { Channel sChannel = subchannel.asChannel(); Metadata headers = new Metadata(); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); verify(mockTransportFactory) .newClientTransport( any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); @@ -1553,7 +1563,7 @@ public class ManagedChannelImplTest { Metadata headers = new Metadata(); // Subchannel must be READY when creating the RPC. - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); verify(mockTransportFactory) .newClientTransport( any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); @@ -1656,7 +1666,7 @@ public class ManagedChannelImplTest { oobChannel.getSubchannel().requestConnection(); } else { Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); } MockClientTransportInfo transportInfo = transports.poll(); @@ -1743,7 +1753,7 @@ public class ManagedChannelImplTest { // Simulate name resolution results EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(socketAddress); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); verify(mockTransportFactory) .newClientTransport( same(socketAddress), eq(clientTransportOptions), any(ChannelLogger.class)); @@ -1766,7 +1776,7 @@ public class ManagedChannelImplTest { transportInfo.listener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); @@ -1816,7 +1826,7 @@ public class ManagedChannelImplTest { ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class); createChannel(); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); MockClientTransportInfo transportInfo = transports.poll(); transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; @@ -1826,7 +1836,7 @@ public class ManagedChannelImplTest { when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( PickResult.withSubchannel(subchannel, factory2)); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); CallOptions callOptions = CallOptions.DEFAULT.withStreamTracerFactory(factory1); ClientCall call = channel.newCall(method, callOptions); @@ -1854,7 +1864,7 @@ public class ManagedChannelImplTest { call.start(mockCallListener, new Metadata()); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); MockClientTransportInfo transportInfo = transports.poll(); transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; @@ -1864,7 +1874,7 @@ public class ManagedChannelImplTest { when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( PickResult.withSubchannel(subchannel, factory2)); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); assertEquals(1, executor.runDueTasks()); verify(mockPicker).pickSubchannel(any(PickSubchannelArgs.class)); @@ -1884,7 +1894,7 @@ public class ManagedChannelImplTest { createChannel(); assertEquals(IDLE, channel.getState(false)); - helper.updateBalancingState(TRANSIENT_FAILURE, mockPicker); + updateBalancingStateSafely(helper, TRANSIENT_FAILURE, mockPicker); assertEquals(TRANSIENT_FAILURE, channel.getState(false)); } @@ -1904,7 +1914,7 @@ public class ManagedChannelImplTest { verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); helper = helperCaptor.getValue(); - helper.updateBalancingState(CONNECTING, mockPicker); + updateBalancingStateSafely(helper, CONNECTING, mockPicker); assertEquals(CONNECTING, channel.getState(false)); assertEquals(CONNECTING, channel.getState(true)); verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); @@ -1917,7 +1927,7 @@ public class ManagedChannelImplTest { createChannel(); verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); - helper.updateBalancingState(IDLE, mockPicker); + updateBalancingStateSafely(helper, IDLE, mockPicker); assertEquals(IDLE, channel.getState(true)); verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); @@ -1944,7 +1954,7 @@ public class ManagedChannelImplTest { assertFalse(stateChanged.get()); // state change from IDLE to CONNECTING - helper.updateBalancingState(CONNECTING, mockPicker); + updateBalancingStateSafely(helper, CONNECTING, mockPicker); // onStateChanged callback should run executor.runDueTasks(); assertTrue(stateChanged.get()); @@ -1982,7 +1992,7 @@ public class ManagedChannelImplTest { stateChanged.set(false); channel.notifyWhenStateChanged(SHUTDOWN, onStateChanged); - helper.updateBalancingState(CONNECTING, mockPicker); + updateBalancingStateSafely(helper, CONNECTING, mockPicker); assertEquals(SHUTDOWN, channel.getState(false)); executor.runDueTasks(); @@ -1996,7 +2006,7 @@ public class ManagedChannelImplTest { createChannel(); assertEquals(IDLE, channel.getState(false)); - helper.updateBalancingState(CONNECTING, mockPicker); + updateBalancingStateSafely(helper, CONNECTING, mockPicker); assertEquals(CONNECTING, channel.getState(false)); timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis)); @@ -2040,7 +2050,7 @@ public class ManagedChannelImplTest { if (initialState == IDLE) { timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis)); } else { - helper.updateBalancingState(initialState, mockPicker); + updateBalancingStateSafely(helper, initialState, mockPicker); } assertEquals(initialState, channel.getState(false)); @@ -2083,7 +2093,7 @@ public class ManagedChannelImplTest { // A misbehaving balancer that calls updateBalancingState() after it's shut down will not be // able to revive it. - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); verifyPanicMode(panicReason); // Cannot be revived by exitIdleMode() @@ -2106,7 +2116,7 @@ public class ManagedChannelImplTest { when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withNoResult()); - helper.updateBalancingState(CONNECTING, mockPicker); + updateBalancingStateSafely(helper, CONNECTING, mockPicker); // Start RPCs that will be buffered in delayedTransport ClientCall call = @@ -2180,10 +2190,10 @@ public class ManagedChannelImplTest { // Updating on the old helper (whose balancer has been shutdown) does not change the channel // state. - helper.updateBalancingState(CONNECTING, mockPicker); + updateBalancingStateSafely(helper, CONNECTING, mockPicker); assertEquals(IDLE, channel.getState(false)); - helper2.updateBalancingState(CONNECTING, mockPicker); + updateBalancingStateSafely(helper2, CONNECTING, mockPicker); assertEquals(CONNECTING, channel.getState(false)); } @@ -2207,7 +2217,7 @@ public class ManagedChannelImplTest { // Move channel into TRANSIENT_FAILURE, which will fail the pending call when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withError(pickError)); - helper.updateBalancingState(TRANSIENT_FAILURE, mockPicker); + updateBalancingStateSafely(helper, TRANSIENT_FAILURE, mockPicker); assertEquals(TRANSIENT_FAILURE, channel.getState(false)); executor.runDueTasks(); verify(mockCallListener).onClose(same(pickError), any(Metadata.class)); @@ -2229,7 +2239,7 @@ public class ManagedChannelImplTest { // Establish a connection Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; @@ -2239,7 +2249,7 @@ public class ManagedChannelImplTest { when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); - helper2.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper2, READY, mockPicker); assertEquals(READY, channel.getState(false)); executor.runDueTasks(); @@ -2251,7 +2261,7 @@ public class ManagedChannelImplTest { @Test public void enterIdleEntersIdle() { createChannel(); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); assertEquals(READY, channel.getState(false)); channel.enterIdle(); @@ -2297,7 +2307,7 @@ public class ManagedChannelImplTest { // Establish a connection Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); ClientStream mockStream = mock(ClientStream.class); MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; @@ -2307,7 +2317,7 @@ public class ManagedChannelImplTest { transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); - helper2.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper2, READY, mockPicker); assertEquals(READY, channel.getState(false)); // Verify the original call was drained @@ -2327,7 +2337,7 @@ public class ManagedChannelImplTest { // Make the transport available with subchannel2 Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - subchannel2.requestConnection(); + requestConnectionSafely(helper, subchannel2); MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; @@ -2338,7 +2348,7 @@ public class ManagedChannelImplTest { when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel1)); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); verify(mockTransport, never()) @@ -2348,7 +2358,7 @@ public class ManagedChannelImplTest { when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel2)); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); @@ -2365,7 +2375,7 @@ public class ManagedChannelImplTest { Runnable onStateChanged = mock(Runnable.class); channel.notifyWhenStateChanged(IDLE, onStateChanged); - helper.updateBalancingState(SHUTDOWN, mockPicker); + updateBalancingStateSafely(helper, SHUTDOWN, mockPicker); assertEquals(IDLE, channel.getState(false)); executor.runDueTasks(); @@ -2381,7 +2391,7 @@ public class ManagedChannelImplTest { FakeNameResolverFactory.FakeNameResolver resolver = nameResolverFactory.resolvers.get(0); int initialRefreshCount = resolver.refreshCalled; - helper.refreshNameResolution(); + refreshNameResolutionSafely(helper); assertEquals(initialRefreshCount + 1, resolver.refreshCalled); } @@ -2648,7 +2658,7 @@ public class ManagedChannelImplTest { channelBuilder.maxTraceEvents(10); createChannel(); timer.forwardNanos(1234); - helper.updateBalancingState(CONNECTING, mockPicker); + updateBalancingStateSafely(helper, CONNECTING, mockPicker); assertThat(getStats(channel).channelTrace.events).contains(new ChannelTrace.Event.Builder() .setDescription("Entering CONNECTING state") .setSeverity(ChannelTrace.Event.Severity.CT_INFO) @@ -2720,14 +2730,14 @@ public class ManagedChannelImplTest { helper = helperCaptor.getValue(); assertEquals(IDLE, getStats(channel).state); - helper.updateBalancingState(CONNECTING, mockPicker); + updateBalancingStateSafely(helper, CONNECTING, mockPicker); assertEquals(CONNECTING, getStats(channel).state); AbstractSubchannel subchannel = (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); assertEquals(IDLE, getStats(subchannel).state); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); assertEquals(CONNECTING, getStats(subchannel).state); MockClientTransportInfo transportInfo = transports.poll(); @@ -2737,7 +2747,7 @@ public class ManagedChannelImplTest { assertEquals(READY, getStats(subchannel).state); assertEquals(CONNECTING, getStats(channel).state); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); assertEquals(READY, getStats(channel).state); channel.shutdownNow(); @@ -2779,7 +2789,7 @@ public class ManagedChannelImplTest { ClientStreamTracer.Factory factory = mock(ClientStreamTracer.Factory.class); AbstractSubchannel subchannel = (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); MockClientTransportInfo transportInfo = transports.poll(); transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; @@ -2791,7 +2801,7 @@ public class ManagedChannelImplTest { // subchannel stat bumped when call gets assigned to it assertEquals(0, getStats(subchannel).callsStarted); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); assertEquals(1, executor.runDueTasks()); verify(mockStream).start(streamListenerCaptor.capture()); assertEquals(1, getStats(subchannel).callsStarted); @@ -3018,7 +3028,7 @@ public class ManagedChannelImplTest { Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ClientStream mockStream = mock(ClientStream.class); @@ -3026,7 +3036,7 @@ public class ManagedChannelImplTest { when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) .thenReturn(mockStream).thenReturn(mockStream2); transportInfo.listener.transportReady(); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); ArgumentCaptor streamListenerCaptor = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -3067,7 +3077,7 @@ public class ManagedChannelImplTest { streamListenerCaptor.getValue().closed(Status.INTERNAL, new Metadata()); verify(mockLoadBalancer).shutdown(); // simulating the shutdown of load balancer triggers the shutdown of subchannel - subchannel.shutdown(); + shutdownSafely(helper, subchannel); transportInfo.listener.transportTerminated(); // simulating transport terminated assertTrue( "channel.isTerminated() is expected to be true but was false", @@ -3114,10 +3124,10 @@ public class ManagedChannelImplTest { .build()); // simulating request connection and then transport ready after resolved address - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); - subchannel.requestConnection(); + requestConnectionSafely(helper, subchannel); MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ClientStream mockStream = mock(ClientStream.class); @@ -3125,7 +3135,7 @@ public class ManagedChannelImplTest { when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) .thenReturn(mockStream).thenReturn(mockStream2); transportInfo.listener.transportReady(); - helper.updateBalancingState(READY, mockPicker); + updateBalancingStateSafely(helper, READY, mockPicker); ArgumentCaptor streamListenerCaptor = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -3165,7 +3175,7 @@ public class ManagedChannelImplTest { assertThat(timer.numPendingTasks()).isEqualTo(0); verify(mockLoadBalancer).shutdown(); // simulating the shutdown of load balancer triggers the shutdown of subchannel - subchannel.shutdown(); + shutdownSafely(helper, subchannel); transportInfo.listener.transportTerminated(); // simulating transport terminated assertTrue( "channel.isTerminated() is expected to be true but was false", @@ -3899,7 +3909,7 @@ public class ManagedChannelImplTest { return Iterables.getOnlyElement(timer.getPendingTasks(NAME_RESOLVER_REFRESH_TASK_FILTER), null); } - // We need this because createSubchannel() should be called from the SynchronizationContext + // Helper methods to call methods from SynchronizationContext private static Subchannel createSubchannelSafely( final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) { final AtomicReference resultCapture = new AtomicReference<>(); @@ -3913,6 +3923,48 @@ public class ManagedChannelImplTest { return resultCapture.get(); } + private static void requestConnectionSafely(Helper helper, final Subchannel subchannel) { + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + subchannel.requestConnection(); + } + }); + } + + private static void updateBalancingStateSafely( + final Helper helper, final ConnectivityState state, final SubchannelPicker picker) { + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + helper.updateBalancingState(state, picker); + } + }); + } + + private static void refreshNameResolutionSafely(final Helper helper) { + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + helper.refreshNameResolution(); + } + }); + } + + private static void shutdownSafely( + final Helper helper, final Subchannel subchannel) { + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + subchannel.shutdown(); + } + }); + } + @SuppressWarnings("unchecked") private static Map parseConfig(String json) throws Exception { return (Map) JsonParser.parse(json);