api: make LoadBalancer.Helper and Subchannel further non-thread-safe. (#5718)

I see more cases of wrapping Helper and Subchannel during the work of
XdsLoadBalancer, we will require that all methods that involve mutable
state to be called from the Synchronization Context.  We will start
logging warnings first, and make them throw in a future release.

Helper.createSubchannel() is already doing so.  This change adds
warnings to the other eligible methods.

https://github.com/grpc/grpc-java/issues/5015
This commit is contained in:
Kun Zhang 2019-05-09 18:13:46 -07:00 committed by GitHub
parent 175a423c10
commit 0c17c4c995
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 224 additions and 97 deletions

View File

@ -694,6 +694,10 @@ public abstract class LoadBalancer {
* Equivalent to {@link #updateSubchannelAddresses(io.grpc.LoadBalancer.Subchannel, List)} with * Equivalent to {@link #updateSubchannelAddresses(io.grpc.LoadBalancer.Subchannel, List)} with
* the given single {@code EquivalentAddressGroup}. * the given single {@code EquivalentAddressGroup}.
* *
* <p>It should be called from the Synchronization Context. Currently will log a warning if
* violated. It will become an exception eventually. See <a
* href="https://github.com/grpc/grpc-java/issues/5015">#5015</a> for the background.
*
* @since 1.4.0 * @since 1.4.0
*/ */
public final void updateSubchannelAddresses( public final void updateSubchannelAddresses(
@ -707,6 +711,10 @@ public abstract class LoadBalancer {
* {@link #createSubchannel} when the new and old addresses overlap, since the subchannel can * {@link #createSubchannel} when the new and old addresses overlap, since the subchannel can
* continue using an existing connection. * continue using an existing connection.
* *
* <p>It should be called from the Synchronization Context. Currently will log a warning if
* violated. It will become an exception eventually. See <a
* href="https://github.com/grpc/grpc-java/issues/5015">#5015</a> for the background.
*
* @throws IllegalArgumentException if {@code subchannel} was not returned from {@link * @throws IllegalArgumentException if {@code subchannel} was not returned from {@link
* #createSubchannel} or {@code addrs} is empty * #createSubchannel} or {@code addrs} is empty
* @since 1.14.0 * @since 1.14.0
@ -776,6 +784,10 @@ public abstract class LoadBalancer {
* updateBalancingState()} has never been called, the channel will buffer all RPCs until a * updateBalancingState()} has never been called, the channel will buffer all RPCs until a
* picker is provided. * picker is provided.
* *
* <p>It should be called from the Synchronization Context. Currently will log a warning if
* violated. It will become an exception eventually. See <a
* href="https://github.com/grpc/grpc-java/issues/5015">#5015</a> for the background.
*
* <p>The passed state will be the channel's new state. The SHUTDOWN state should not be passed * <p>The passed state will be the channel's new state. The SHUTDOWN state should not be passed
* and its behavior is undefined. * and its behavior is undefined.
* *
@ -787,6 +799,10 @@ public abstract class LoadBalancer {
/** /**
* Call {@link NameResolver#refresh} on the channel's resolver. * Call {@link NameResolver#refresh} on the channel's resolver.
* *
* <p>It should be called from the Synchronization Context. Currently will log a warning if
* violated. It will become an exception eventually. See <a
* href="https://github.com/grpc/grpc-java/issues/5015">#5015</a> for the background.
*
* @since 1.18.0 * @since 1.18.0
*/ */
public void refreshNameResolution() { public void refreshNameResolution() {
@ -903,6 +919,10 @@ public abstract class LoadBalancer {
* Shuts down the Subchannel. After this method is called, this Subchannel should no longer * Shuts down the Subchannel. After this method is called, this Subchannel should no longer
* be returned by the latest {@link SubchannelPicker picker}, and can be safely discarded. * be returned by the latest {@link SubchannelPicker picker}, and can be safely discarded.
* *
* <p>It should be called from the Synchronization Context. Currently will log a warning if
* violated. It will become an exception eventually. See <a
* href="https://github.com/grpc/grpc-java/issues/5015">#5015</a> for the background.
*
* @since 1.2.0 * @since 1.2.0
*/ */
public abstract void shutdown(); public abstract void shutdown();
@ -910,6 +930,10 @@ public abstract class LoadBalancer {
/** /**
* Asks the Subchannel to create a connection (aka transport), if there isn't an active one. * Asks the Subchannel to create a connection (aka transport), if there isn't an active one.
* *
* <p>It should be called from the Synchronization Context. Currently will log a warning if
* violated. It will become an exception eventually. See <a
* href="https://github.com/grpc/grpc-java/issues/5015">#5015</a> for the background.
*
* @since 1.2.0 * @since 1.2.0
*/ */
public abstract void requestConnection(); public abstract void requestConnection();
@ -919,6 +943,10 @@ public abstract class LoadBalancer {
* the Subchannel has only one {@link EquivalentAddressGroup}. Under the hood it calls * the Subchannel has only one {@link EquivalentAddressGroup}. Under the hood it calls
* {@link #getAllAddresses}. * {@link #getAllAddresses}.
* *
* <p>It should be called from the Synchronization Context. Currently will log a warning if
* violated. It will become an exception eventually. See <a
* href="https://github.com/grpc/grpc-java/issues/5015">#5015</a> for the background.
*
* @throws IllegalStateException if this subchannel has more than one EquivalentAddressGroup. * @throws IllegalStateException if this subchannel has more than one EquivalentAddressGroup.
* Use {@link #getAllAddresses} instead * Use {@link #getAllAddresses} instead
* @since 1.2.0 * @since 1.2.0
@ -932,6 +960,10 @@ public abstract class LoadBalancer {
/** /**
* Returns the addresses that this Subchannel is bound to. The returned list will not be empty. * Returns the addresses that this Subchannel is bound to. The returned list will not be empty.
* *
* <p>It should be called from the Synchronization Context. Currently will log a warning if
* violated. It will become an exception eventually. See <a
* href="https://github.com/grpc/grpc-java/issues/5015">#5015</a> for the background.
*
* @since 1.14.0 * @since 1.14.0
*/ */
public List<EquivalentAddressGroup> getAllAddresses() { public List<EquivalentAddressGroup> getAllAddresses() {

View File

@ -1048,14 +1048,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
@Override @Override
public AbstractSubchannel createSubchannel( public AbstractSubchannel createSubchannel(
List<EquivalentAddressGroup> addressGroups, Attributes attrs) { List<EquivalentAddressGroup> addressGroups, Attributes attrs) {
try { logWarningIfNotInSyncContext("createSubchannel()");
syncContext.throwIfNotInThisSynchronizationContext();
} catch (IllegalStateException e) {
logger.log(Level.WARNING,
"We sugguest you call createSubchannel() from SynchronizationContext."
+ " Otherwise, it may race with handleSubchannelState()."
+ " See https://github.com/grpc/grpc-java/issues/5015", e);
}
checkNotNull(addressGroups, "addressGroups"); checkNotNull(addressGroups, "addressGroups");
checkNotNull(attrs, "attrs"); checkNotNull(attrs, "attrs");
// TODO(ejona): can we be even stricter? Like loadBalancer == null? // TODO(ejona): can we be even stricter? Like loadBalancer == null?
@ -1149,6 +1142,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
final ConnectivityState newState, final SubchannelPicker newPicker) { final ConnectivityState newState, final SubchannelPicker newPicker) {
checkNotNull(newState, "newState"); checkNotNull(newState, "newState");
checkNotNull(newPicker, "newPicker"); checkNotNull(newPicker, "newPicker");
logWarningIfNotInSyncContext("updateBalancingState()");
final class UpdateBalancingState implements Runnable { final class UpdateBalancingState implements Runnable {
@Override @Override
public void run() { public void run() {
@ -1170,6 +1164,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
@Override @Override
public void refreshNameResolution() { public void refreshNameResolution() {
logWarningIfNotInSyncContext("refreshNameResolution()");
final class LoadBalancerRefreshNameResolution implements Runnable { final class LoadBalancerRefreshNameResolution implements Runnable {
@Override @Override
public void run() { public void run() {
@ -1185,6 +1180,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
LoadBalancer.Subchannel subchannel, List<EquivalentAddressGroup> addrs) { LoadBalancer.Subchannel subchannel, List<EquivalentAddressGroup> addrs) {
checkArgument(subchannel instanceof SubchannelImpl, checkArgument(subchannel instanceof SubchannelImpl,
"subchannel must have been returned from createSubchannel"); "subchannel must have been returned from createSubchannel");
logWarningIfNotInSyncContext("updateSubchannelAddresses()");
((SubchannelImpl) subchannel).subchannel.updateAddresses(addrs); ((SubchannelImpl) subchannel).subchannel.updateAddresses(addrs);
} }
@ -1478,6 +1474,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
@Override @Override
public void shutdown() { public void shutdown() {
logWarningIfNotInSyncContext("Subchannel.shutdown()");
synchronized (shutdownLock) { synchronized (shutdownLock) {
if (shutdownRequested) { if (shutdownRequested) {
if (terminating && delayedShutdownTask != null) { if (terminating && delayedShutdownTask != null) {
@ -1521,11 +1518,13 @@ final class ManagedChannelImpl extends ManagedChannel implements
@Override @Override
public void requestConnection() { public void requestConnection() {
logWarningIfNotInSyncContext("Subchannel.requestConnection()");
subchannel.obtainActiveTransport(); subchannel.obtainActiveTransport();
} }
@Override @Override
public List<EquivalentAddressGroup> getAllAddresses() { public List<EquivalentAddressGroup> getAllAddresses() {
logWarningIfNotInSyncContext("Subchannel.getAllAddresses()");
return subchannel.getAddressGroups(); return subchannel.getAddressGroups();
} }
@ -1784,4 +1783,15 @@ final class ManagedChannelImpl extends ManagedChannel implements
} }
} }
} }
private void logWarningIfNotInSyncContext(String method) {
try {
syncContext.throwIfNotInThisSynchronizationContext();
} catch (IllegalStateException e) {
logger.log(Level.WARNING,
method + " should be called from SynchronizationContext. "
+ "This warning will become an exception in a future release. "
+ "See https://github.com/grpc/grpc-java/issues/5015 for more details", e);
}
}
} }

View File

@ -38,6 +38,7 @@ import io.grpc.CallOptions;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
import io.grpc.ClientCall; import io.grpc.ClientCall;
import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptor;
import io.grpc.ConnectivityState;
import io.grpc.EquivalentAddressGroup; import io.grpc.EquivalentAddressGroup;
import io.grpc.IntegerMarshaller; import io.grpc.IntegerMarshaller;
import io.grpc.LoadBalancer; import io.grpc.LoadBalancer;
@ -310,14 +311,14 @@ public class ManagedChannelImplIdlenessTest {
// Assume LoadBalancer has received an address, then create a subchannel. // Assume LoadBalancer has received an address, then create a subchannel.
Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
MockClientTransportInfo t0 = newTransports.poll(); MockClientTransportInfo t0 = newTransports.poll();
t0.listener.transportReady(); t0.listener.transportReady();
SubchannelPicker mockPicker = mock(SubchannelPicker.class); SubchannelPicker mockPicker = mock(SubchannelPicker.class);
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel)); .thenReturn(PickResult.withSubchannel(subchannel));
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
// Delayed transport creates real streams in the app executor // Delayed transport creates real streams in the app executor
executor.runDueTasks(); executor.runDueTasks();
@ -350,13 +351,13 @@ public class ManagedChannelImplIdlenessTest {
Helper helper = helperCaptor.getValue(); Helper helper = helperCaptor.getValue();
Subchannel subchannel = createSubchannelSafely(helper, servers.get(0), Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, servers.get(0), Attributes.EMPTY);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
MockClientTransportInfo t0 = newTransports.poll(); MockClientTransportInfo t0 = newTransports.poll();
t0.listener.transportReady(); t0.listener.transportReady();
helper.updateSubchannelAddresses(subchannel, servers.get(1)); updateSubchannelAddressesSafely(helper, subchannel, servers.get(1));
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
MockClientTransportInfo t1 = newTransports.poll(); MockClientTransportInfo t1 = newTransports.poll();
t1.listener.transportReady(); t1.listener.transportReady();
} }
@ -370,15 +371,15 @@ public class ManagedChannelImplIdlenessTest {
Helper helper = helperCaptor.getValue(); Helper helper = helperCaptor.getValue();
Subchannel subchannel = createSubchannelSafely(helper, servers.get(0), Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, servers.get(0), Attributes.EMPTY);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
MockClientTransportInfo t0 = newTransports.poll(); MockClientTransportInfo t0 = newTransports.poll();
t0.listener.transportReady(); t0.listener.transportReady();
List<SocketAddress> changedList = new ArrayList<>(servers.get(0).getAddresses()); List<SocketAddress> changedList = new ArrayList<>(servers.get(0).getAddresses());
changedList.add(new FakeSocketAddress("aDifferentServer")); changedList.add(new FakeSocketAddress("aDifferentServer"));
helper.updateSubchannelAddresses(subchannel, new EquivalentAddressGroup(changedList)); updateSubchannelAddressesSafely(helper, subchannel, new EquivalentAddressGroup(changedList));
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
assertNull(newTransports.poll()); assertNull(newTransports.poll());
} }
@ -397,7 +398,7 @@ public class ManagedChannelImplIdlenessTest {
SubchannelPicker failingPicker = mock(SubchannelPicker.class); SubchannelPicker failingPicker = mock(SubchannelPicker.class);
when(failingPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(failingPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withError(Status.UNAVAILABLE)); .thenReturn(PickResult.withError(Status.UNAVAILABLE));
helper.updateBalancingState(TRANSIENT_FAILURE, failingPicker); updateBalancingStateSafely(helper, TRANSIENT_FAILURE, failingPicker);
executor.runDueTasks(); executor.runDueTasks();
verify(mockCallListener).onClose(same(Status.UNAVAILABLE), any(Metadata.class)); verify(mockCallListener).onClose(same(Status.UNAVAILABLE), any(Metadata.class));
@ -499,7 +500,7 @@ public class ManagedChannelImplIdlenessTest {
} }
} }
// We need this because createSubchannel() should be called from the SynchronizationContext // Helper methods to call methods from SynchronizationContext
private static Subchannel createSubchannelSafely( private static Subchannel createSubchannelSafely(
final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) { final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) {
final AtomicReference<Subchannel> resultCapture = new AtomicReference<>(); final AtomicReference<Subchannel> resultCapture = new AtomicReference<>();
@ -512,4 +513,36 @@ public class ManagedChannelImplIdlenessTest {
}); });
return resultCapture.get(); return resultCapture.get();
} }
private static void requestConnectionSafely(Helper helper, final Subchannel subchannel) {
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
subchannel.requestConnection();
}
});
}
private static void updateBalancingStateSafely(
final Helper helper, final ConnectivityState state, final SubchannelPicker picker) {
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
helper.updateBalancingState(state, picker);
}
});
}
private static void updateSubchannelAddressesSafely(
final Helper helper, final Subchannel subchannel, final EquivalentAddressGroup addrs) {
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
helper.updateSubchannelAddresses(subchannel, addrs);
}
});
}
} }

View File

@ -360,7 +360,7 @@ public class ManagedChannelImplTest {
LogRecord record = logRef.get(); LogRecord record = logRef.get();
assertThat(record.getLevel()).isEqualTo(Level.WARNING); assertThat(record.getLevel()).isEqualTo(Level.WARNING);
assertThat(record.getMessage()).contains( assertThat(record.getMessage()).contains(
"We sugguest you call createSubchannel() from SynchronizationContext"); "createSubchannel() should be called from SynchronizationContext");
assertThat(record.getThrown()).isInstanceOf(IllegalStateException.class); assertThat(record.getThrown()).isInstanceOf(IllegalStateException.class);
} finally { } finally {
logger.removeHandler(handler); logger.removeHandler(handler);
@ -434,7 +434,7 @@ public class ManagedChannelImplTest {
assertThat(getStats(channel).subchannels) assertThat(getStats(channel).subchannels)
.containsExactly(subchannel.getInternalSubchannel()); .containsExactly(subchannel.getInternalSubchannel());
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
assertNotNull(transportInfo); assertNotNull(transportInfo);
assertTrue(channelz.containsClientSocket(transportInfo.transport.getLogId())); assertTrue(channelz.containsClientSocket(transportInfo.transport.getLogId()));
@ -445,7 +445,7 @@ public class ManagedChannelImplTest {
// terminate subchannel // terminate subchannel
assertTrue(channelz.containsSubchannel(subchannel.getInternalSubchannel().getLogId())); assertTrue(channelz.containsSubchannel(subchannel.getInternalSubchannel().getLogId()));
subchannel.shutdown(); shutdownSafely(helper, subchannel);
timer.forwardTime(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS, TimeUnit.SECONDS); timer.forwardTime(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS, TimeUnit.SECONDS);
timer.runDueTasks(); timer.runDueTasks();
assertFalse(channelz.containsSubchannel(subchannel.getInternalSubchannel().getLogId())); assertFalse(channelz.containsSubchannel(subchannel.getInternalSubchannel().getLogId()));
@ -520,7 +520,7 @@ public class ManagedChannelImplTest {
// Configure the picker so that first RPC goes to delayed transport, and second RPC goes to // Configure the picker so that first RPC goes to delayed transport, and second RPC goes to
// real transport. // real transport.
Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
verify(mockTransportFactory) verify(mockTransportFactory)
.newClientTransport( .newClientTransport(
any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class));
@ -539,7 +539,7 @@ public class ManagedChannelImplTest {
when(mockPicker.pickSubchannel( when(mockPicker.pickSubchannel(
new PickSubchannelArgsImpl(method, headers2, CallOptions.DEFAULT))).thenReturn( new PickSubchannelArgsImpl(method, headers2, CallOptions.DEFAULT))).thenReturn(
PickResult.withSubchannel(subchannel)); PickResult.withSubchannel(subchannel));
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
// First RPC, will be pending // First RPC, will be pending
ClientCall<String, Integer> call = channel.newCall(method, CallOptions.DEFAULT); ClientCall<String, Integer> call = channel.newCall(method, CallOptions.DEFAULT);
@ -598,7 +598,7 @@ public class ManagedChannelImplTest {
SubchannelPicker picker2 = mock(SubchannelPicker.class); SubchannelPicker picker2 = mock(SubchannelPicker.class);
when(picker2.pickSubchannel(new PickSubchannelArgsImpl(method, headers, CallOptions.DEFAULT))) when(picker2.pickSubchannel(new PickSubchannelArgsImpl(method, headers, CallOptions.DEFAULT)))
.thenReturn(PickResult.withSubchannel(subchannel)); .thenReturn(PickResult.withSubchannel(subchannel));
helper.updateBalancingState(READY, picker2); updateBalancingStateSafely(helper, READY, picker2);
executor.runDueTasks(); executor.runDueTasks();
verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT));
verify(mockStream).start(any(ClientStreamListener.class)); verify(mockStream).start(any(ClientStreamListener.class));
@ -616,7 +616,7 @@ public class ManagedChannelImplTest {
verify(mockTransport, never()).shutdownNow(any(Status.class)); verify(mockTransport, never()).shutdownNow(any(Status.class));
} }
// LoadBalancer should shutdown the subchannel // LoadBalancer should shutdown the subchannel
subchannel.shutdown(); shutdownSafely(helper, subchannel);
if (shutdownNow) { if (shutdownNow) {
verify(mockTransport).shutdown(same(ManagedChannelImpl.SHUTDOWN_NOW_STATUS)); verify(mockTransport).shutdown(same(ManagedChannelImpl.SHUTDOWN_NOW_STATUS));
} else { } else {
@ -660,8 +660,8 @@ public class ManagedChannelImplTest {
Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel1.requestConnection(); requestConnectionSafely(helper, subchannel1);
subchannel2.requestConnection(); requestConnectionSafely(helper, subchannel2);
verify(mockTransportFactory, times(2)) verify(mockTransportFactory, times(2))
.newClientTransport( .newClientTransport(
any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class));
@ -729,7 +729,7 @@ public class ManagedChannelImplTest {
verify(mockTransportFactory, never()) verify(mockTransportFactory, never())
.newClientTransport( .newClientTransport(
any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class));
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
verify(mockTransportFactory) verify(mockTransportFactory)
.newClientTransport( .newClientTransport(
any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class));
@ -742,7 +742,7 @@ public class ManagedChannelImplTest {
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel)); .thenReturn(PickResult.withSubchannel(subchannel));
assertEquals(0, callExecutor.numPendingTasks()); assertEquals(0, callExecutor.numPendingTasks());
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
// Real streams are started in the call executor if they were previously buffered. // Real streams are started in the call executor if they were previously buffered.
assertEquals(1, callExecutor.runDueTasks()); assertEquals(1, callExecutor.runDueTasks());
@ -763,7 +763,7 @@ public class ManagedChannelImplTest {
transportListener.transportTerminated(); transportListener.transportTerminated();
// Clean up as much as possible to allow the channel to terminate. // Clean up as much as possible to allow the channel to terminate.
subchannel.shutdown(); shutdownSafely(helper, subchannel);
timer.forwardNanos( timer.forwardNanos(
TimeUnit.SECONDS.toNanos(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS)); TimeUnit.SECONDS.toNanos(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS));
} }
@ -843,7 +843,7 @@ public class ManagedChannelImplTest {
Status status = Status.UNAVAILABLE.withDescription("for test"); Status status = Status.UNAVAILABLE.withDescription("for test");
when(picker.pickSubchannel(any(PickSubchannelArgs.class))) when(picker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withDrop(status)); .thenReturn(PickResult.withDrop(status));
helper.updateBalancingState(READY, picker); updateBalancingStateSafely(helper, READY, picker);
executor.runDueTasks(); executor.runDueTasks();
verify(mockCallListener).onClose(same(status), any(Metadata.class)); verify(mockCallListener).onClose(same(status), any(Metadata.class));
@ -987,7 +987,7 @@ public class ManagedChannelImplTest {
Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel)); .thenReturn(PickResult.withSubchannel(subchannel));
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
inOrder.verify(mockLoadBalancer).handleSubchannelState( inOrder.verify(mockLoadBalancer).handleSubchannelState(
same(subchannel), stateInfoCaptor.capture()); same(subchannel), stateInfoCaptor.capture());
assertEquals(CONNECTING, stateInfoCaptor.getValue().getState()); assertEquals(CONNECTING, stateInfoCaptor.getValue().getState());
@ -1020,7 +1020,7 @@ public class ManagedChannelImplTest {
assertEquals(READY, stateInfoCaptor.getValue().getState()); assertEquals(READY, stateInfoCaptor.getValue().getState());
// A typical LoadBalancer will call this once the subchannel becomes READY // A typical LoadBalancer will call this once the subchannel becomes READY
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
// Delayed transport uses the app executor to create real streams. // Delayed transport uses the app executor to create real streams.
executor.runDueTasks(); executor.runDueTasks();
@ -1069,7 +1069,7 @@ public class ManagedChannelImplTest {
when(picker.pickSubchannel(any(PickSubchannelArgs.class))) when(picker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(drop ? PickResult.withDrop(status) : PickResult.withError(status)); .thenReturn(drop ? PickResult.withDrop(status) : PickResult.withError(status));
helper.updateBalancingState(READY, picker); updateBalancingStateSafely(helper, READY, picker);
executor.runDueTasks(); executor.runDueTasks();
if (shouldFail) { if (shouldFail) {
@ -1137,7 +1137,7 @@ public class ManagedChannelImplTest {
Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel)); .thenReturn(PickResult.withSubchannel(subchannel));
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
inOrder.verify(mockLoadBalancer).handleSubchannelState( inOrder.verify(mockLoadBalancer).handleSubchannelState(
same(subchannel), stateInfoCaptor.capture()); same(subchannel), stateInfoCaptor.capture());
@ -1172,7 +1172,7 @@ public class ManagedChannelImplTest {
SubchannelPicker picker2 = mock(SubchannelPicker.class); SubchannelPicker picker2 = mock(SubchannelPicker.class);
when(picker2.pickSubchannel(any(PickSubchannelArgs.class))) when(picker2.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withError(server2Error)); .thenReturn(PickResult.withError(server2Error));
helper.updateBalancingState(TRANSIENT_FAILURE, picker2); updateBalancingStateSafely(helper, TRANSIENT_FAILURE, picker2);
executor.runDueTasks(); executor.runDueTasks();
// ... which fails the fail-fast call // ... which fails the fail-fast call
@ -1193,14 +1193,24 @@ public class ManagedChannelImplTest {
// createSubchannel() always return a new Subchannel // createSubchannel() always return a new Subchannel
Attributes attrs1 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr1").build(); Attributes attrs1 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr1").build();
Attributes attrs2 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr2").build(); Attributes attrs2 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr2").build();
Subchannel sub1 = createSubchannelSafely(helper, addressGroup, attrs1); final Subchannel sub1 = createSubchannelSafely(helper, addressGroup, attrs1);
Subchannel sub2 = createSubchannelSafely(helper, addressGroup, attrs2); final Subchannel sub2 = createSubchannelSafely(helper, addressGroup, attrs2);
assertNotSame(sub1, sub2); assertNotSame(sub1, sub2);
assertNotSame(attrs1, attrs2); assertNotSame(attrs1, attrs2);
assertSame(attrs1, sub1.getAttributes()); assertSame(attrs1, sub1.getAttributes());
assertSame(attrs2, sub2.getAttributes()); assertSame(attrs2, sub2.getAttributes());
final AtomicBoolean snippetPassed = new AtomicBoolean(false);
helper.getSynchronizationContext().execute(new Runnable() {
@Override
public void run() {
// getAddresses() must be called from sync context
assertSame(addressGroup, sub1.getAddresses()); assertSame(addressGroup, sub1.getAddresses());
assertSame(addressGroup, sub2.getAddresses()); assertSame(addressGroup, sub2.getAddresses());
snippetPassed.set(true);
}
});
assertThat(snippetPassed.get()).isTrue();
// requestConnection() // requestConnection()
verify(mockTransportFactory, never()) verify(mockTransportFactory, never())
@ -1208,7 +1218,7 @@ public class ManagedChannelImplTest {
any(SocketAddress.class), any(SocketAddress.class),
any(ClientTransportOptions.class), any(ClientTransportOptions.class),
any(TransportLogger.class)); any(TransportLogger.class));
sub1.requestConnection(); requestConnectionSafely(helper, sub1);
verify(mockTransportFactory) verify(mockTransportFactory)
.newClientTransport( .newClientTransport(
eq(socketAddress), eq(socketAddress),
@ -1217,7 +1227,7 @@ public class ManagedChannelImplTest {
MockClientTransportInfo transportInfo1 = transports.poll(); MockClientTransportInfo transportInfo1 = transports.poll();
assertNotNull(transportInfo1); assertNotNull(transportInfo1);
sub2.requestConnection(); requestConnectionSafely(helper, sub2);
verify(mockTransportFactory, times(2)) verify(mockTransportFactory, times(2))
.newClientTransport( .newClientTransport(
eq(socketAddress), eq(socketAddress),
@ -1226,17 +1236,17 @@ public class ManagedChannelImplTest {
MockClientTransportInfo transportInfo2 = transports.poll(); MockClientTransportInfo transportInfo2 = transports.poll();
assertNotNull(transportInfo2); assertNotNull(transportInfo2);
sub1.requestConnection(); requestConnectionSafely(helper, sub1);
sub2.requestConnection(); requestConnectionSafely(helper, sub2);
// The subchannel doesn't matter since this isn't called // The subchannel doesn't matter since this isn't called
verify(mockTransportFactory, times(2)) verify(mockTransportFactory, times(2))
.newClientTransport( .newClientTransport(
eq(socketAddress), eq(clientTransportOptions), isA(TransportLogger.class)); eq(socketAddress), eq(clientTransportOptions), isA(TransportLogger.class));
// shutdown() has a delay // shutdown() has a delay
sub1.shutdown(); shutdownSafely(helper, sub1);
timer.forwardTime(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS - 1, TimeUnit.SECONDS); timer.forwardTime(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS - 1, TimeUnit.SECONDS);
sub1.shutdown(); shutdownSafely(helper, sub1);
verify(transportInfo1.transport, never()).shutdown(any(Status.class)); verify(transportInfo1.transport, never()).shutdown(any(Status.class));
timer.forwardTime(1, TimeUnit.SECONDS); timer.forwardTime(1, TimeUnit.SECONDS);
verify(transportInfo1.transport).shutdown(same(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_STATUS)); verify(transportInfo1.transport).shutdown(same(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_STATUS));
@ -1247,7 +1257,7 @@ public class ManagedChannelImplTest {
verify(mockLoadBalancer).shutdown(); verify(mockLoadBalancer).shutdown();
verify(transportInfo2.transport, never()).shutdown(any(Status.class)); verify(transportInfo2.transport, never()).shutdown(any(Status.class));
sub2.shutdown(); shutdownSafely(helper, sub2);
verify(transportInfo2.transport).shutdown(same(ManagedChannelImpl.SHUTDOWN_STATUS)); verify(transportInfo2.transport).shutdown(same(ManagedChannelImpl.SHUTDOWN_STATUS));
// Cleanup // Cleanup
@ -1263,8 +1273,8 @@ public class ManagedChannelImplTest {
createChannel(); createChannel();
Subchannel sub1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel sub1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
Subchannel sub2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel sub2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
sub1.requestConnection(); requestConnectionSafely(helper, sub1);
sub2.requestConnection(); requestConnectionSafely(helper, sub2);
assertThat(transports).hasSize(2); assertThat(transports).hasSize(2);
MockClientTransportInfo ti1 = transports.poll(); MockClientTransportInfo ti1 = transports.poll();
@ -1294,9 +1304,9 @@ public class ManagedChannelImplTest {
channel.shutdown(); channel.shutdown();
verify(mockLoadBalancer).shutdown(); verify(mockLoadBalancer).shutdown();
sub1.shutdown(); shutdownSafely(helper, sub1);
assertFalse(channel.isTerminated()); assertFalse(channel.isTerminated());
sub2.shutdown(); shutdownSafely(helper, sub2);
assertTrue(channel.isTerminated()); assertTrue(channel.isTerminated());
verify(mockTransportFactory, never()) verify(mockTransportFactory, never())
.newClientTransport( .newClientTransport(
@ -1499,7 +1509,7 @@ public class ManagedChannelImplTest {
CallOptions callOptions = CallOptions.DEFAULT.withDeadlineAfter(5, TimeUnit.SECONDS); CallOptions callOptions = CallOptions.DEFAULT.withDeadlineAfter(5, TimeUnit.SECONDS);
// Subchannel must be READY when creating the RPC. // Subchannel must be READY when creating the RPC.
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
verify(mockTransportFactory) verify(mockTransportFactory)
.newClientTransport( .newClientTransport(
any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class));
@ -1524,7 +1534,7 @@ public class ManagedChannelImplTest {
Channel sChannel = subchannel.asChannel(); Channel sChannel = subchannel.asChannel();
Metadata headers = new Metadata(); Metadata headers = new Metadata();
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
verify(mockTransportFactory) verify(mockTransportFactory)
.newClientTransport( .newClientTransport(
any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class));
@ -1553,7 +1563,7 @@ public class ManagedChannelImplTest {
Metadata headers = new Metadata(); Metadata headers = new Metadata();
// Subchannel must be READY when creating the RPC. // Subchannel must be READY when creating the RPC.
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
verify(mockTransportFactory) verify(mockTransportFactory)
.newClientTransport( .newClientTransport(
any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class));
@ -1656,7 +1666,7 @@ public class ManagedChannelImplTest {
oobChannel.getSubchannel().requestConnection(); oobChannel.getSubchannel().requestConnection();
} else { } else {
Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
} }
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
@ -1743,7 +1753,7 @@ public class ManagedChannelImplTest {
// Simulate name resolution results // Simulate name resolution results
EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(socketAddress); EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(socketAddress);
Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
verify(mockTransportFactory) verify(mockTransportFactory)
.newClientTransport( .newClientTransport(
same(socketAddress), eq(clientTransportOptions), any(ChannelLogger.class)); same(socketAddress), eq(clientTransportOptions), any(ChannelLogger.class));
@ -1766,7 +1776,7 @@ public class ManagedChannelImplTest {
transportInfo.listener.transportReady(); transportInfo.listener.transportReady();
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel)); .thenReturn(PickResult.withSubchannel(subchannel));
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
executor.runDueTasks(); executor.runDueTasks();
ArgumentCaptor<RequestInfo> infoCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor<RequestInfo> infoCaptor = ArgumentCaptor.forClass(null);
ArgumentCaptor<CallCredentials.MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor<CallCredentials.MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null);
@ -1816,7 +1826,7 @@ public class ManagedChannelImplTest {
ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class); ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class);
createChannel(); createChannel();
Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
transportInfo.listener.transportReady(); transportInfo.listener.transportReady();
ClientTransport mockTransport = transportInfo.transport; ClientTransport mockTransport = transportInfo.transport;
@ -1826,7 +1836,7 @@ public class ManagedChannelImplTest {
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(
PickResult.withSubchannel(subchannel, factory2)); PickResult.withSubchannel(subchannel, factory2));
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
CallOptions callOptions = CallOptions.DEFAULT.withStreamTracerFactory(factory1); CallOptions callOptions = CallOptions.DEFAULT.withStreamTracerFactory(factory1);
ClientCall<String, Integer> call = channel.newCall(method, callOptions); ClientCall<String, Integer> call = channel.newCall(method, callOptions);
@ -1854,7 +1864,7 @@ public class ManagedChannelImplTest {
call.start(mockCallListener, new Metadata()); call.start(mockCallListener, new Metadata());
Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
transportInfo.listener.transportReady(); transportInfo.listener.transportReady();
ClientTransport mockTransport = transportInfo.transport; ClientTransport mockTransport = transportInfo.transport;
@ -1864,7 +1874,7 @@ public class ManagedChannelImplTest {
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(
PickResult.withSubchannel(subchannel, factory2)); PickResult.withSubchannel(subchannel, factory2));
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
assertEquals(1, executor.runDueTasks()); assertEquals(1, executor.runDueTasks());
verify(mockPicker).pickSubchannel(any(PickSubchannelArgs.class)); verify(mockPicker).pickSubchannel(any(PickSubchannelArgs.class));
@ -1884,7 +1894,7 @@ public class ManagedChannelImplTest {
createChannel(); createChannel();
assertEquals(IDLE, channel.getState(false)); assertEquals(IDLE, channel.getState(false));
helper.updateBalancingState(TRANSIENT_FAILURE, mockPicker); updateBalancingStateSafely(helper, TRANSIENT_FAILURE, mockPicker);
assertEquals(TRANSIENT_FAILURE, channel.getState(false)); assertEquals(TRANSIENT_FAILURE, channel.getState(false));
} }
@ -1904,7 +1914,7 @@ public class ManagedChannelImplTest {
verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture());
helper = helperCaptor.getValue(); helper = helperCaptor.getValue();
helper.updateBalancingState(CONNECTING, mockPicker); updateBalancingStateSafely(helper, CONNECTING, mockPicker);
assertEquals(CONNECTING, channel.getState(false)); assertEquals(CONNECTING, channel.getState(false));
assertEquals(CONNECTING, channel.getState(true)); assertEquals(CONNECTING, channel.getState(true));
verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class));
@ -1917,7 +1927,7 @@ public class ManagedChannelImplTest {
createChannel(); createChannel();
verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class));
helper.updateBalancingState(IDLE, mockPicker); updateBalancingStateSafely(helper, IDLE, mockPicker);
assertEquals(IDLE, channel.getState(true)); assertEquals(IDLE, channel.getState(true));
verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class));
@ -1944,7 +1954,7 @@ public class ManagedChannelImplTest {
assertFalse(stateChanged.get()); assertFalse(stateChanged.get());
// state change from IDLE to CONNECTING // state change from IDLE to CONNECTING
helper.updateBalancingState(CONNECTING, mockPicker); updateBalancingStateSafely(helper, CONNECTING, mockPicker);
// onStateChanged callback should run // onStateChanged callback should run
executor.runDueTasks(); executor.runDueTasks();
assertTrue(stateChanged.get()); assertTrue(stateChanged.get());
@ -1982,7 +1992,7 @@ public class ManagedChannelImplTest {
stateChanged.set(false); stateChanged.set(false);
channel.notifyWhenStateChanged(SHUTDOWN, onStateChanged); channel.notifyWhenStateChanged(SHUTDOWN, onStateChanged);
helper.updateBalancingState(CONNECTING, mockPicker); updateBalancingStateSafely(helper, CONNECTING, mockPicker);
assertEquals(SHUTDOWN, channel.getState(false)); assertEquals(SHUTDOWN, channel.getState(false));
executor.runDueTasks(); executor.runDueTasks();
@ -1996,7 +2006,7 @@ public class ManagedChannelImplTest {
createChannel(); createChannel();
assertEquals(IDLE, channel.getState(false)); assertEquals(IDLE, channel.getState(false));
helper.updateBalancingState(CONNECTING, mockPicker); updateBalancingStateSafely(helper, CONNECTING, mockPicker);
assertEquals(CONNECTING, channel.getState(false)); assertEquals(CONNECTING, channel.getState(false));
timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis)); timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis));
@ -2040,7 +2050,7 @@ public class ManagedChannelImplTest {
if (initialState == IDLE) { if (initialState == IDLE) {
timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis)); timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis));
} else { } else {
helper.updateBalancingState(initialState, mockPicker); updateBalancingStateSafely(helper, initialState, mockPicker);
} }
assertEquals(initialState, channel.getState(false)); assertEquals(initialState, channel.getState(false));
@ -2083,7 +2093,7 @@ public class ManagedChannelImplTest {
// A misbehaving balancer that calls updateBalancingState() after it's shut down will not be // A misbehaving balancer that calls updateBalancingState() after it's shut down will not be
// able to revive it. // able to revive it.
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
verifyPanicMode(panicReason); verifyPanicMode(panicReason);
// Cannot be revived by exitIdleMode() // Cannot be revived by exitIdleMode()
@ -2106,7 +2116,7 @@ public class ManagedChannelImplTest {
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withNoResult()); .thenReturn(PickResult.withNoResult());
helper.updateBalancingState(CONNECTING, mockPicker); updateBalancingStateSafely(helper, CONNECTING, mockPicker);
// Start RPCs that will be buffered in delayedTransport // Start RPCs that will be buffered in delayedTransport
ClientCall<String, Integer> call = ClientCall<String, Integer> call =
@ -2180,10 +2190,10 @@ public class ManagedChannelImplTest {
// Updating on the old helper (whose balancer has been shutdown) does not change the channel // Updating on the old helper (whose balancer has been shutdown) does not change the channel
// state. // state.
helper.updateBalancingState(CONNECTING, mockPicker); updateBalancingStateSafely(helper, CONNECTING, mockPicker);
assertEquals(IDLE, channel.getState(false)); assertEquals(IDLE, channel.getState(false));
helper2.updateBalancingState(CONNECTING, mockPicker); updateBalancingStateSafely(helper2, CONNECTING, mockPicker);
assertEquals(CONNECTING, channel.getState(false)); assertEquals(CONNECTING, channel.getState(false));
} }
@ -2207,7 +2217,7 @@ public class ManagedChannelImplTest {
// Move channel into TRANSIENT_FAILURE, which will fail the pending call // Move channel into TRANSIENT_FAILURE, which will fail the pending call
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withError(pickError)); .thenReturn(PickResult.withError(pickError));
helper.updateBalancingState(TRANSIENT_FAILURE, mockPicker); updateBalancingStateSafely(helper, TRANSIENT_FAILURE, mockPicker);
assertEquals(TRANSIENT_FAILURE, channel.getState(false)); assertEquals(TRANSIENT_FAILURE, channel.getState(false));
executor.runDueTasks(); executor.runDueTasks();
verify(mockCallListener).onClose(same(pickError), any(Metadata.class)); verify(mockCallListener).onClose(same(pickError), any(Metadata.class));
@ -2229,7 +2239,7 @@ public class ManagedChannelImplTest {
// Establish a connection // Establish a connection
Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
ConnectionClientTransport mockTransport = transportInfo.transport; ConnectionClientTransport mockTransport = transportInfo.transport;
ManagedClientTransport.Listener transportListener = transportInfo.listener; ManagedClientTransport.Listener transportListener = transportInfo.listener;
@ -2239,7 +2249,7 @@ public class ManagedChannelImplTest {
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel)); .thenReturn(PickResult.withSubchannel(subchannel));
helper2.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper2, READY, mockPicker);
assertEquals(READY, channel.getState(false)); assertEquals(READY, channel.getState(false));
executor.runDueTasks(); executor.runDueTasks();
@ -2251,7 +2261,7 @@ public class ManagedChannelImplTest {
@Test @Test
public void enterIdleEntersIdle() { public void enterIdleEntersIdle() {
createChannel(); createChannel();
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
assertEquals(READY, channel.getState(false)); assertEquals(READY, channel.getState(false));
channel.enterIdle(); channel.enterIdle();
@ -2297,7 +2307,7 @@ public class ManagedChannelImplTest {
// Establish a connection // Establish a connection
Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream = mock(ClientStream.class);
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
ConnectionClientTransport mockTransport = transportInfo.transport; ConnectionClientTransport mockTransport = transportInfo.transport;
@ -2307,7 +2317,7 @@ public class ManagedChannelImplTest {
transportListener.transportReady(); transportListener.transportReady();
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel)); .thenReturn(PickResult.withSubchannel(subchannel));
helper2.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper2, READY, mockPicker);
assertEquals(READY, channel.getState(false)); assertEquals(READY, channel.getState(false));
// Verify the original call was drained // Verify the original call was drained
@ -2327,7 +2337,7 @@ public class ManagedChannelImplTest {
// Make the transport available with subchannel2 // Make the transport available with subchannel2
Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel2.requestConnection(); requestConnectionSafely(helper, subchannel2);
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
ConnectionClientTransport mockTransport = transportInfo.transport; ConnectionClientTransport mockTransport = transportInfo.transport;
@ -2338,7 +2348,7 @@ public class ManagedChannelImplTest {
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel1)); .thenReturn(PickResult.withSubchannel(subchannel1));
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
executor.runDueTasks(); executor.runDueTasks();
verify(mockTransport, never()) verify(mockTransport, never())
@ -2348,7 +2358,7 @@ public class ManagedChannelImplTest {
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel2)); .thenReturn(PickResult.withSubchannel(subchannel2));
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
executor.runDueTasks(); executor.runDueTasks();
verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class));
@ -2365,7 +2375,7 @@ public class ManagedChannelImplTest {
Runnable onStateChanged = mock(Runnable.class); Runnable onStateChanged = mock(Runnable.class);
channel.notifyWhenStateChanged(IDLE, onStateChanged); channel.notifyWhenStateChanged(IDLE, onStateChanged);
helper.updateBalancingState(SHUTDOWN, mockPicker); updateBalancingStateSafely(helper, SHUTDOWN, mockPicker);
assertEquals(IDLE, channel.getState(false)); assertEquals(IDLE, channel.getState(false));
executor.runDueTasks(); executor.runDueTasks();
@ -2381,7 +2391,7 @@ public class ManagedChannelImplTest {
FakeNameResolverFactory.FakeNameResolver resolver = nameResolverFactory.resolvers.get(0); FakeNameResolverFactory.FakeNameResolver resolver = nameResolverFactory.resolvers.get(0);
int initialRefreshCount = resolver.refreshCalled; int initialRefreshCount = resolver.refreshCalled;
helper.refreshNameResolution(); refreshNameResolutionSafely(helper);
assertEquals(initialRefreshCount + 1, resolver.refreshCalled); assertEquals(initialRefreshCount + 1, resolver.refreshCalled);
} }
@ -2648,7 +2658,7 @@ public class ManagedChannelImplTest {
channelBuilder.maxTraceEvents(10); channelBuilder.maxTraceEvents(10);
createChannel(); createChannel();
timer.forwardNanos(1234); timer.forwardNanos(1234);
helper.updateBalancingState(CONNECTING, mockPicker); updateBalancingStateSafely(helper, CONNECTING, mockPicker);
assertThat(getStats(channel).channelTrace.events).contains(new ChannelTrace.Event.Builder() assertThat(getStats(channel).channelTrace.events).contains(new ChannelTrace.Event.Builder()
.setDescription("Entering CONNECTING state") .setDescription("Entering CONNECTING state")
.setSeverity(ChannelTrace.Event.Severity.CT_INFO) .setSeverity(ChannelTrace.Event.Severity.CT_INFO)
@ -2720,14 +2730,14 @@ public class ManagedChannelImplTest {
helper = helperCaptor.getValue(); helper = helperCaptor.getValue();
assertEquals(IDLE, getStats(channel).state); assertEquals(IDLE, getStats(channel).state);
helper.updateBalancingState(CONNECTING, mockPicker); updateBalancingStateSafely(helper, CONNECTING, mockPicker);
assertEquals(CONNECTING, getStats(channel).state); assertEquals(CONNECTING, getStats(channel).state);
AbstractSubchannel subchannel = AbstractSubchannel subchannel =
(AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
assertEquals(IDLE, getStats(subchannel).state); assertEquals(IDLE, getStats(subchannel).state);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
assertEquals(CONNECTING, getStats(subchannel).state); assertEquals(CONNECTING, getStats(subchannel).state);
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
@ -2737,7 +2747,7 @@ public class ManagedChannelImplTest {
assertEquals(READY, getStats(subchannel).state); assertEquals(READY, getStats(subchannel).state);
assertEquals(CONNECTING, getStats(channel).state); assertEquals(CONNECTING, getStats(channel).state);
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
assertEquals(READY, getStats(channel).state); assertEquals(READY, getStats(channel).state);
channel.shutdownNow(); channel.shutdownNow();
@ -2779,7 +2789,7 @@ public class ManagedChannelImplTest {
ClientStreamTracer.Factory factory = mock(ClientStreamTracer.Factory.class); ClientStreamTracer.Factory factory = mock(ClientStreamTracer.Factory.class);
AbstractSubchannel subchannel = AbstractSubchannel subchannel =
(AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
transportInfo.listener.transportReady(); transportInfo.listener.transportReady();
ClientTransport mockTransport = transportInfo.transport; ClientTransport mockTransport = transportInfo.transport;
@ -2791,7 +2801,7 @@ public class ManagedChannelImplTest {
// subchannel stat bumped when call gets assigned to it // subchannel stat bumped when call gets assigned to it
assertEquals(0, getStats(subchannel).callsStarted); assertEquals(0, getStats(subchannel).callsStarted);
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
assertEquals(1, executor.runDueTasks()); assertEquals(1, executor.runDueTasks());
verify(mockStream).start(streamListenerCaptor.capture()); verify(mockStream).start(streamListenerCaptor.capture());
assertEquals(1, getStats(subchannel).callsStarted); assertEquals(1, getStats(subchannel).callsStarted);
@ -3018,7 +3028,7 @@ public class ManagedChannelImplTest {
Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel)); .thenReturn(PickResult.withSubchannel(subchannel));
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
ConnectionClientTransport mockTransport = transportInfo.transport; ConnectionClientTransport mockTransport = transportInfo.transport;
ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream = mock(ClientStream.class);
@ -3026,7 +3036,7 @@ public class ManagedChannelImplTest {
when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class)))
.thenReturn(mockStream).thenReturn(mockStream2); .thenReturn(mockStream).thenReturn(mockStream2);
transportInfo.listener.transportReady(); transportInfo.listener.transportReady();
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
ArgumentCaptor<ClientStreamListener> streamListenerCaptor = ArgumentCaptor<ClientStreamListener> streamListenerCaptor =
ArgumentCaptor.forClass(ClientStreamListener.class); ArgumentCaptor.forClass(ClientStreamListener.class);
@ -3067,7 +3077,7 @@ public class ManagedChannelImplTest {
streamListenerCaptor.getValue().closed(Status.INTERNAL, new Metadata()); streamListenerCaptor.getValue().closed(Status.INTERNAL, new Metadata());
verify(mockLoadBalancer).shutdown(); verify(mockLoadBalancer).shutdown();
// simulating the shutdown of load balancer triggers the shutdown of subchannel // simulating the shutdown of load balancer triggers the shutdown of subchannel
subchannel.shutdown(); shutdownSafely(helper, subchannel);
transportInfo.listener.transportTerminated(); // simulating transport terminated transportInfo.listener.transportTerminated(); // simulating transport terminated
assertTrue( assertTrue(
"channel.isTerminated() is expected to be true but was false", "channel.isTerminated() is expected to be true but was false",
@ -3114,10 +3124,10 @@ public class ManagedChannelImplTest {
.build()); .build());
// simulating request connection and then transport ready after resolved address // simulating request connection and then transport ready after resolved address
Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel)); .thenReturn(PickResult.withSubchannel(subchannel));
subchannel.requestConnection(); requestConnectionSafely(helper, subchannel);
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
ConnectionClientTransport mockTransport = transportInfo.transport; ConnectionClientTransport mockTransport = transportInfo.transport;
ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream = mock(ClientStream.class);
@ -3125,7 +3135,7 @@ public class ManagedChannelImplTest {
when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class)))
.thenReturn(mockStream).thenReturn(mockStream2); .thenReturn(mockStream).thenReturn(mockStream2);
transportInfo.listener.transportReady(); transportInfo.listener.transportReady();
helper.updateBalancingState(READY, mockPicker); updateBalancingStateSafely(helper, READY, mockPicker);
ArgumentCaptor<ClientStreamListener> streamListenerCaptor = ArgumentCaptor<ClientStreamListener> streamListenerCaptor =
ArgumentCaptor.forClass(ClientStreamListener.class); ArgumentCaptor.forClass(ClientStreamListener.class);
@ -3165,7 +3175,7 @@ public class ManagedChannelImplTest {
assertThat(timer.numPendingTasks()).isEqualTo(0); assertThat(timer.numPendingTasks()).isEqualTo(0);
verify(mockLoadBalancer).shutdown(); verify(mockLoadBalancer).shutdown();
// simulating the shutdown of load balancer triggers the shutdown of subchannel // simulating the shutdown of load balancer triggers the shutdown of subchannel
subchannel.shutdown(); shutdownSafely(helper, subchannel);
transportInfo.listener.transportTerminated(); // simulating transport terminated transportInfo.listener.transportTerminated(); // simulating transport terminated
assertTrue( assertTrue(
"channel.isTerminated() is expected to be true but was false", "channel.isTerminated() is expected to be true but was false",
@ -3899,7 +3909,7 @@ public class ManagedChannelImplTest {
return Iterables.getOnlyElement(timer.getPendingTasks(NAME_RESOLVER_REFRESH_TASK_FILTER), null); return Iterables.getOnlyElement(timer.getPendingTasks(NAME_RESOLVER_REFRESH_TASK_FILTER), null);
} }
// We need this because createSubchannel() should be called from the SynchronizationContext // Helper methods to call methods from SynchronizationContext
private static Subchannel createSubchannelSafely( private static Subchannel createSubchannelSafely(
final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) { final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) {
final AtomicReference<Subchannel> resultCapture = new AtomicReference<>(); final AtomicReference<Subchannel> resultCapture = new AtomicReference<>();
@ -3913,6 +3923,48 @@ public class ManagedChannelImplTest {
return resultCapture.get(); return resultCapture.get();
} }
private static void requestConnectionSafely(Helper helper, final Subchannel subchannel) {
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
subchannel.requestConnection();
}
});
}
private static void updateBalancingStateSafely(
final Helper helper, final ConnectivityState state, final SubchannelPicker picker) {
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
helper.updateBalancingState(state, picker);
}
});
}
private static void refreshNameResolutionSafely(final Helper helper) {
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
helper.refreshNameResolution();
}
});
}
private static void shutdownSafely(
final Helper helper, final Subchannel subchannel) {
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
subchannel.shutdown();
}
});
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static Map<String, Object> parseConfig(String json) throws Exception { private static Map<String, Object> parseConfig(String json) throws Exception {
return (Map<String, Object>) JsonParser.parse(json); return (Map<String, Object>) JsonParser.parse(json);