diff --git a/core/src/main/java/io/grpc/LoadBalancer.java b/core/src/main/java/io/grpc/LoadBalancer.java index 00a7d964c1..989077eeff 100644 --- a/core/src/main/java/io/grpc/LoadBalancer.java +++ b/core/src/main/java/io/grpc/LoadBalancer.java @@ -477,6 +477,10 @@ public abstract class LoadBalancer { * Subchannel, and can be accessed later through {@link Subchannel#getAttributes * Subchannel.getAttributes()}. * + *

It is recommended you call this method from the Synchronization Context, otherwise your + * logic around the creation may race with {@link #handleSubchannelState}. See + * #5015 for more discussions. + * *

The LoadBalancer is responsible for closing unused Subchannels, and closing all * Subchannels within {@link #shutdown}. * diff --git a/core/src/main/java/io/grpc/SynchronizationContext.java b/core/src/main/java/io/grpc/SynchronizationContext.java index 55ee8873c6..9cbebe3934 100644 --- a/core/src/main/java/io/grpc/SynchronizationContext.java +++ b/core/src/main/java/io/grpc/SynchronizationContext.java @@ -17,6 +17,7 @@ package io.grpc; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; import java.lang.Thread.UncaughtExceptionHandler; import java.util.ArrayDeque; @@ -57,7 +58,7 @@ public final class SynchronizationContext implements Executor { @GuardedBy("lock") private final Queue queue = new ArrayDeque(); @GuardedBy("lock") - private boolean draining; + private Thread drainingThread; /** * Creates a SynchronizationContext. @@ -84,15 +85,15 @@ public final class SynchronizationContext implements Executor { Runnable runnable; synchronized (lock) { if (!drainLeaseAcquired) { - if (draining) { + if (drainingThread != null) { return; } - draining = true; + drainingThread = Thread.currentThread(); drainLeaseAcquired = true; } runnable = queue.poll(); if (runnable == null) { - draining = false; + drainingThread = null; break; } } @@ -129,6 +130,18 @@ public final class SynchronizationContext implements Executor { drain(); } + /** + * Throw {@link IllegalStateException} if this method is not called from this synchronization + * context. + */ + public void throwIfNotInThisSynchronizationContext() { + synchronized (lock) { + checkState( + Thread.currentThread() == drainingThread, + "Not called from the SynchronizationContext"); + } + } + /** * Schedules a task to be added and run via {@link #execute} after a delay. * diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 679794d10f..9863a57c1d 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -1026,6 +1026,14 @@ final class ManagedChannelImpl extends ManagedChannel implements @Override public AbstractSubchannel createSubchannel( List addressGroups, Attributes attrs) { + try { + syncContext.throwIfNotInThisSynchronizationContext(); + } catch (IllegalStateException e) { + logger.log(Level.WARNING, + "We sugguest you call createSubchannel() from SynchronizationContext." + + " Otherwise, it may race with handleSubchannelState()." + + " See https://github.com/grpc/grpc-java/issues/5015", e); + } checkNotNull(addressGroups, "addressGroups"); checkNotNull(attrs, "attrs"); // TODO(ejona): can we be even stricter? Like loadBalancer == null? diff --git a/core/src/test/java/io/grpc/SynchronizationContextTest.java b/core/src/test/java/io/grpc/SynchronizationContextTest.java index 477e4e57db..66dfdbd4ee 100644 --- a/core/src/test/java/io/grpc/SynchronizationContextTest.java +++ b/core/src/test/java/io/grpc/SynchronizationContextTest.java @@ -19,6 +19,7 @@ package io.grpc; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.never; @@ -152,6 +153,61 @@ public class SynchronizationContextTest { assertSame(sideThread, task2Thread.get()); } + @Test + public void throwIfNotInThisSynchronizationContext() throws Exception { + final AtomicBoolean taskSuccess = new AtomicBoolean(false); + final CountDownLatch task1Running = new CountDownLatch(1); + final CountDownLatch task1Proceed = new CountDownLatch(1); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + task1Running.countDown(); + syncContext.throwIfNotInThisSynchronizationContext(); + try { + assertTrue(task1Proceed.await(5, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + taskSuccess.set(true); + return null; + } + }).when(task1).run(); + + Thread sideThread = new Thread() { + @Override + public void run() { + syncContext.execute(task1); + } + }; + sideThread.start(); + + assertThat(task1Running.await(5, TimeUnit.SECONDS)).isTrue(); + + // syncContext is draining, but the current thread is not in the context + try { + syncContext.throwIfNotInThisSynchronizationContext(); + fail("Should throw"); + } catch (IllegalStateException e) { + assertThat(e.getMessage()).isEqualTo("Not called from the SynchronizationContext"); + } + + // Let task1 finish + task1Proceed.countDown(); + sideThread.join(); + + // throwIfNotInThisSynchronizationContext() didn't throw in task1 + assertThat(taskSuccess.get()).isTrue(); + + // syncContext is not draining, but the current thread is not in the context + try { + syncContext.throwIfNotInThisSynchronizationContext(); + fail("Should throw"); + } catch (IllegalStateException e) { + assertThat(e.getMessage()).isEqualTo("Not called from the SynchronizationContext"); + } + } + @Test public void taskThrows() { InOrder inOrder = inOrder(task1, task2, task3); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java index d0ced5faae..782da4446a 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java @@ -62,6 +62,7 @@ import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -262,7 +263,7 @@ public class ManagedChannelImplIdlenessTest { assertTrue(channel.inUseStateAggregator.isInUse()); // Assume LoadBalancer has received an address, then create a subchannel. - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); subchannel.requestConnection(); MockClientTransportInfo t0 = newTransports.poll(); t0.listener.transportReady(); @@ -301,7 +302,7 @@ public class ManagedChannelImplIdlenessTest { ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); verify(mockLoadBalancerFactory).newLoadBalancer(helperCaptor.capture()); Helper helper = helperCaptor.getValue(); - Subchannel subchannel = helper.createSubchannel(servers.get(0), Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, servers.get(0), Attributes.EMPTY); subchannel.requestConnection(); MockClientTransportInfo t0 = newTransports.poll(); @@ -321,7 +322,7 @@ public class ManagedChannelImplIdlenessTest { ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); verify(mockLoadBalancerFactory).newLoadBalancer(helperCaptor.capture()); Helper helper = helperCaptor.getValue(); - Subchannel subchannel = helper.createSubchannel(servers.get(0), Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, servers.get(0), Attributes.EMPTY); subchannel.requestConnection(); MockClientTransportInfo t0 = newTransports.poll(); @@ -449,4 +450,18 @@ public class ManagedChannelImplIdlenessTest { return "FakeSocketAddress-" + name; } } + + // We need this because createSubchannel() should be called from the SynchronizationContext + private static Subchannel createSubchannelSafely( + final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) { + final AtomicReference resultCapture = new AtomicReference(); + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + resultCapture.set(helper.createSubchannel(addressGroup, attrs)); + } + }); + return resultCapture.get(); + } } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 2c9025e240..53ad75419a 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -112,6 +112,10 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; import javax.annotation.Nullable; import org.junit.After; import org.junit.Assert; @@ -278,6 +282,38 @@ public class ManagedChannelImplTest { } } + @Test + public void createSubchannelOutsideSynchronizationContextShouldLogWarning() { + createChannel(); + final AtomicReference logRef = new AtomicReference(); + Handler handler = new Handler() { + @Override + public void publish(LogRecord record) { + logRef.set(record); + } + + @Override + public void flush() { + } + + @Override + public void close() throws SecurityException { + } + }; + Logger logger = Logger.getLogger(ManagedChannelImpl.class.getName()); + try { + logger.addHandler(handler); + helper.createSubchannel(addressGroup, Attributes.EMPTY); + LogRecord record = logRef.get(); + assertThat(record.getLevel()).isEqualTo(Level.WARNING); + assertThat(record.getMessage()).contains( + "We sugguest you call createSubchannel() from SynchronizationContext"); + assertThat(record.getThrown()).isInstanceOf(IllegalStateException.class); + } finally { + logger.removeHandler(handler); + } + } + @Test @SuppressWarnings("unchecked") public void idleModeDisabled() { @@ -338,7 +374,7 @@ public class ManagedChannelImplTest { assertNotNull(channelz.getRootChannel(channel.getLogId().getId())); AbstractSubchannel subchannel = - (AbstractSubchannel) helper.createSubchannel(addressGroup, Attributes.EMPTY); + (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); // subchannels are not root channels assertNull(channelz.getRootChannel(subchannel.getInternalSubchannel().getLogId().getId())); assertTrue(channelz.containsSubchannel(subchannel.getInternalSubchannel().getLogId())); @@ -430,7 +466,7 @@ public class ManagedChannelImplTest { // Configure the picker so that first RPC goes to delayed transport, and second RPC goes to // real transport. - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); subchannel.requestConnection(); verify(mockTransportFactory) .newClientTransport(any(SocketAddress.class), any(ClientTransportOptions.class)); @@ -563,8 +599,8 @@ public class ManagedChannelImplTest { verify(mockLoadBalancer).handleResolvedAddressGroups( eq(Arrays.asList(addressGroup)), eq(Attributes.EMPTY)); - Subchannel subchannel1 = helper.createSubchannel(addressGroup, Attributes.EMPTY); - Subchannel subchannel2 = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); subchannel1.requestConnection(); subchannel2.requestConnection(); verify(mockTransportFactory, times(2)) @@ -629,7 +665,7 @@ public class ManagedChannelImplTest { call.start(mockCallListener, headers); // Make the transport available - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); verify(mockTransportFactory, never()) .newClientTransport(any(SocketAddress.class), any(ClientTransportOptions.class)); subchannel.requestConnection(); @@ -845,7 +881,7 @@ public class ManagedChannelImplTest { EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(resolvedAddrs); inOrder.verify(mockLoadBalancer).handleResolvedAddressGroups( eq(Arrays.asList(addressGroup)), eq(Attributes.EMPTY)); - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); subchannel.requestConnection(); @@ -990,7 +1026,7 @@ public class ManagedChannelImplTest { EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(resolvedAddrs); inOrder.verify(mockLoadBalancer).handleResolvedAddressGroups( eq(Arrays.asList(addressGroup)), eq(Attributes.EMPTY)); - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); subchannel.requestConnection(); @@ -1046,8 +1082,8 @@ public class ManagedChannelImplTest { // createSubchannel() always return a new Subchannel Attributes attrs1 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr1").build(); Attributes attrs2 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr2").build(); - Subchannel sub1 = helper.createSubchannel(addressGroup, attrs1); - Subchannel sub2 = helper.createSubchannel(addressGroup, attrs2); + Subchannel sub1 = createSubchannelSafely(helper, addressGroup, attrs1); + Subchannel sub2 = createSubchannelSafely(helper, addressGroup, attrs2); assertNotSame(sub1, sub2); assertNotSame(attrs1, attrs2); assertSame(attrs1, sub1.getAttributes()); @@ -1102,8 +1138,8 @@ public class ManagedChannelImplTest { @Test public void subchannelsWhenChannelShutdownNow() { createChannel(); - Subchannel sub1 = helper.createSubchannel(addressGroup, Attributes.EMPTY); - Subchannel sub2 = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel sub1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel sub2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); sub1.requestConnection(); sub2.requestConnection(); @@ -1130,8 +1166,8 @@ public class ManagedChannelImplTest { @Test public void subchannelsNoConnectionShutdown() { createChannel(); - Subchannel sub1 = helper.createSubchannel(addressGroup, Attributes.EMPTY); - Subchannel sub2 = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel sub1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel sub2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); channel.shutdown(); verify(mockLoadBalancer).shutdown(); @@ -1146,8 +1182,8 @@ public class ManagedChannelImplTest { @Test public void subchannelsNoConnectionShutdownNow() { createChannel(); - helper.createSubchannel(addressGroup, Attributes.EMPTY); - helper.createSubchannel(addressGroup, Attributes.EMPTY); + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); channel.shutdownNow(); verify(mockLoadBalancer).shutdown(); @@ -1324,7 +1360,7 @@ public class ManagedChannelImplTest { @Test public void subchannelChannel_normalUsage() { createChannel(); - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); verify(balancerRpcExecutorPool, never()).getObject(); Channel sChannel = subchannel.asChannel(); @@ -1354,7 +1390,7 @@ public class ManagedChannelImplTest { @Test public void subchannelChannel_failWhenNotReady() { createChannel(); - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Channel sChannel = subchannel.asChannel(); Metadata headers = new Metadata(); @@ -1381,7 +1417,7 @@ public class ManagedChannelImplTest { @Test public void subchannelChannel_failWaitForReady() { createChannel(); - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); Channel sChannel = subchannel.asChannel(); Metadata headers = new Metadata(); @@ -1458,7 +1494,7 @@ public class ManagedChannelImplTest { OobChannel oobChannel = (OobChannel) helper.createOobChannel(addressGroup, "oobAuthority"); oobChannel.getSubchannel().requestConnection(); } else { - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); subchannel.requestConnection(); } @@ -1527,7 +1563,7 @@ public class ManagedChannelImplTest { // Simulate name resolution results EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(socketAddress); - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); subchannel.requestConnection(); verify(mockTransportFactory) .newClientTransport(same(socketAddress), eq(clientTransportOptions)); @@ -1603,7 +1639,7 @@ public class ManagedChannelImplTest { ClientStreamTracer.Factory factory1 = mock(ClientStreamTracer.Factory.class); ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class); createChannel(); - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); subchannel.requestConnection(); MockClientTransportInfo transportInfo = transports.poll(); transportInfo.listener.transportReady(); @@ -1641,7 +1677,7 @@ public class ManagedChannelImplTest { ClientCall call = channel.newCall(method, callOptions); call.start(mockCallListener, new Metadata()); - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); subchannel.requestConnection(); MockClientTransportInfo transportInfo = transports.poll(); transportInfo.listener.transportReady(); @@ -2004,7 +2040,7 @@ public class ManagedChannelImplTest { Helper helper2 = helperCaptor.getValue(); // Establish a connection - Subchannel subchannel = helper2.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY); subchannel.requestConnection(); MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; @@ -2072,7 +2108,7 @@ public class ManagedChannelImplTest { Helper helper2 = helperCaptor.getValue(); // Establish a connection - Subchannel subchannel = helper2.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY); subchannel.requestConnection(); ClientStream mockStream = mock(ClientStream.class); MockClientTransportInfo transportInfo = transports.poll(); @@ -2101,8 +2137,8 @@ public class ManagedChannelImplTest { call.start(mockCallListener, new Metadata()); // Make the transport available with subchannel2 - Subchannel subchannel1 = helper.createSubchannel(addressGroup, Attributes.EMPTY); - Subchannel subchannel2 = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); subchannel2.requestConnection(); MockClientTransportInfo transportInfo = transports.poll(); @@ -2228,7 +2264,7 @@ public class ManagedChannelImplTest { createChannel(); assertEquals(TARGET, getStats(channel).target); - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); assertEquals(Collections.singletonList(addressGroup).toString(), getStats((AbstractSubchannel) subchannel).target); } @@ -2251,7 +2287,7 @@ public class ManagedChannelImplTest { createChannel(); timer.forwardNanos(1234); AbstractSubchannel subchannel = - (AbstractSubchannel) helper.createSubchannel(addressGroup, Attributes.EMPTY); + (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); assertThat(getStats(channel).channelTrace.events).contains(new ChannelTrace.Event.Builder() .setDescription("Child channel created") .setSeverity(ChannelTrace.Event.Severity.CT_INFO) @@ -2403,7 +2439,7 @@ public class ManagedChannelImplTest { channelBuilder.maxTraceEvents(10); createChannel(); AbstractSubchannel subchannel = - (AbstractSubchannel) helper.createSubchannel(addressGroup, Attributes.EMPTY); + (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); timer.forwardNanos(1234); subchannel.obtainActiveTransport(); assertThat(getStats(subchannel).channelTrace.events).contains(new ChannelTrace.Event.Builder() @@ -2466,7 +2502,7 @@ public class ManagedChannelImplTest { assertEquals(CONNECTING, getStats(channel).state); AbstractSubchannel subchannel = - (AbstractSubchannel) helper.createSubchannel(addressGroup, Attributes.EMPTY); + (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); assertEquals(IDLE, getStats(subchannel).state); subchannel.requestConnection(); @@ -2520,7 +2556,7 @@ public class ManagedChannelImplTest { ClientStream mockStream = mock(ClientStream.class); ClientStreamTracer.Factory factory = mock(ClientStreamTracer.Factory.class); AbstractSubchannel subchannel = - (AbstractSubchannel) helper.createSubchannel(addressGroup, Attributes.EMPTY); + (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); subchannel.requestConnection(); MockClientTransportInfo transportInfo = transports.poll(); transportInfo.listener.transportReady(); @@ -2754,7 +2790,7 @@ public class ManagedChannelImplTest { .handleResolvedAddressGroups(nameResolverFactory.servers, attributesWithRetryPolicy); // simulating request connection and then transport ready after resolved address - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); subchannel.requestConnection(); @@ -3076,4 +3112,18 @@ public class ManagedChannelImplTest { private FakeClock.ScheduledTask getNameResolverRefresh() { return Iterables.getOnlyElement(timer.getPendingTasks(NAME_RESOLVER_REFRESH_TASK_FILTER), null); } + + // We need this because createSubchannel() should be called from the SynchronizationContext + private static Subchannel createSubchannelSafely( + final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) { + final AtomicReference resultCapture = new AtomicReference(); + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + resultCapture.set(helper.createSubchannel(addressGroup, attrs)); + } + }); + return resultCapture.get(); + } }