diff --git a/core/src/main/java/io/grpc/LoadBalancer.java b/core/src/main/java/io/grpc/LoadBalancer.java
index 00a7d964c1..989077eeff 100644
--- a/core/src/main/java/io/grpc/LoadBalancer.java
+++ b/core/src/main/java/io/grpc/LoadBalancer.java
@@ -477,6 +477,10 @@ public abstract class LoadBalancer {
* Subchannel, and can be accessed later through {@link Subchannel#getAttributes
* Subchannel.getAttributes()}.
*
+ *
It is recommended you call this method from the Synchronization Context, otherwise your
+ * logic around the creation may race with {@link #handleSubchannelState}. See
+ * #5015 for more discussions.
+ *
*
The LoadBalancer is responsible for closing unused Subchannels, and closing all
* Subchannels within {@link #shutdown}.
*
diff --git a/core/src/main/java/io/grpc/SynchronizationContext.java b/core/src/main/java/io/grpc/SynchronizationContext.java
index 55ee8873c6..9cbebe3934 100644
--- a/core/src/main/java/io/grpc/SynchronizationContext.java
+++ b/core/src/main/java/io/grpc/SynchronizationContext.java
@@ -17,6 +17,7 @@
package io.grpc;
import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
import java.lang.Thread.UncaughtExceptionHandler;
import java.util.ArrayDeque;
@@ -57,7 +58,7 @@ public final class SynchronizationContext implements Executor {
@GuardedBy("lock")
private final Queue queue = new ArrayDeque();
@GuardedBy("lock")
- private boolean draining;
+ private Thread drainingThread;
/**
* Creates a SynchronizationContext.
@@ -84,15 +85,15 @@ public final class SynchronizationContext implements Executor {
Runnable runnable;
synchronized (lock) {
if (!drainLeaseAcquired) {
- if (draining) {
+ if (drainingThread != null) {
return;
}
- draining = true;
+ drainingThread = Thread.currentThread();
drainLeaseAcquired = true;
}
runnable = queue.poll();
if (runnable == null) {
- draining = false;
+ drainingThread = null;
break;
}
}
@@ -129,6 +130,18 @@ public final class SynchronizationContext implements Executor {
drain();
}
+ /**
+ * Throw {@link IllegalStateException} if this method is not called from this synchronization
+ * context.
+ */
+ public void throwIfNotInThisSynchronizationContext() {
+ synchronized (lock) {
+ checkState(
+ Thread.currentThread() == drainingThread,
+ "Not called from the SynchronizationContext");
+ }
+ }
+
/**
* Schedules a task to be added and run via {@link #execute} after a delay.
*
diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java
index 679794d10f..9863a57c1d 100644
--- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java
+++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java
@@ -1026,6 +1026,14 @@ final class ManagedChannelImpl extends ManagedChannel implements
@Override
public AbstractSubchannel createSubchannel(
List addressGroups, Attributes attrs) {
+ try {
+ syncContext.throwIfNotInThisSynchronizationContext();
+ } catch (IllegalStateException e) {
+ logger.log(Level.WARNING,
+ "We sugguest you call createSubchannel() from SynchronizationContext."
+ + " Otherwise, it may race with handleSubchannelState()."
+ + " See https://github.com/grpc/grpc-java/issues/5015", e);
+ }
checkNotNull(addressGroups, "addressGroups");
checkNotNull(attrs, "attrs");
// TODO(ejona): can we be even stricter? Like loadBalancer == null?
diff --git a/core/src/test/java/io/grpc/SynchronizationContextTest.java b/core/src/test/java/io/grpc/SynchronizationContextTest.java
index 477e4e57db..66dfdbd4ee 100644
--- a/core/src/test/java/io/grpc/SynchronizationContextTest.java
+++ b/core/src/test/java/io/grpc/SynchronizationContextTest.java
@@ -19,6 +19,7 @@ package io.grpc;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.never;
@@ -152,6 +153,61 @@ public class SynchronizationContextTest {
assertSame(sideThread, task2Thread.get());
}
+ @Test
+ public void throwIfNotInThisSynchronizationContext() throws Exception {
+ final AtomicBoolean taskSuccess = new AtomicBoolean(false);
+ final CountDownLatch task1Running = new CountDownLatch(1);
+ final CountDownLatch task1Proceed = new CountDownLatch(1);
+
+ doAnswer(new Answer() {
+ @Override
+ public Void answer(InvocationOnMock invocation) {
+ task1Running.countDown();
+ syncContext.throwIfNotInThisSynchronizationContext();
+ try {
+ assertTrue(task1Proceed.await(5, TimeUnit.SECONDS));
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ taskSuccess.set(true);
+ return null;
+ }
+ }).when(task1).run();
+
+ Thread sideThread = new Thread() {
+ @Override
+ public void run() {
+ syncContext.execute(task1);
+ }
+ };
+ sideThread.start();
+
+ assertThat(task1Running.await(5, TimeUnit.SECONDS)).isTrue();
+
+ // syncContext is draining, but the current thread is not in the context
+ try {
+ syncContext.throwIfNotInThisSynchronizationContext();
+ fail("Should throw");
+ } catch (IllegalStateException e) {
+ assertThat(e.getMessage()).isEqualTo("Not called from the SynchronizationContext");
+ }
+
+ // Let task1 finish
+ task1Proceed.countDown();
+ sideThread.join();
+
+ // throwIfNotInThisSynchronizationContext() didn't throw in task1
+ assertThat(taskSuccess.get()).isTrue();
+
+ // syncContext is not draining, but the current thread is not in the context
+ try {
+ syncContext.throwIfNotInThisSynchronizationContext();
+ fail("Should throw");
+ } catch (IllegalStateException e) {
+ assertThat(e.getMessage()).isEqualTo("Not called from the SynchronizationContext");
+ }
+ }
+
@Test
public void taskThrows() {
InOrder inOrder = inOrder(task1, task2, task3);
diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java
index d0ced5faae..782da4446a 100644
--- a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java
+++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java
@@ -62,6 +62,7 @@ import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -262,7 +263,7 @@ public class ManagedChannelImplIdlenessTest {
assertTrue(channel.inUseStateAggregator.isInUse());
// Assume LoadBalancer has received an address, then create a subchannel.
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection();
MockClientTransportInfo t0 = newTransports.poll();
t0.listener.transportReady();
@@ -301,7 +302,7 @@ public class ManagedChannelImplIdlenessTest {
ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null);
verify(mockLoadBalancerFactory).newLoadBalancer(helperCaptor.capture());
Helper helper = helperCaptor.getValue();
- Subchannel subchannel = helper.createSubchannel(servers.get(0), Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, servers.get(0), Attributes.EMPTY);
subchannel.requestConnection();
MockClientTransportInfo t0 = newTransports.poll();
@@ -321,7 +322,7 @@ public class ManagedChannelImplIdlenessTest {
ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null);
verify(mockLoadBalancerFactory).newLoadBalancer(helperCaptor.capture());
Helper helper = helperCaptor.getValue();
- Subchannel subchannel = helper.createSubchannel(servers.get(0), Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, servers.get(0), Attributes.EMPTY);
subchannel.requestConnection();
MockClientTransportInfo t0 = newTransports.poll();
@@ -449,4 +450,18 @@ public class ManagedChannelImplIdlenessTest {
return "FakeSocketAddress-" + name;
}
}
+
+ // We need this because createSubchannel() should be called from the SynchronizationContext
+ private static Subchannel createSubchannelSafely(
+ final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) {
+ final AtomicReference resultCapture = new AtomicReference();
+ helper.getSynchronizationContext().execute(
+ new Runnable() {
+ @Override
+ public void run() {
+ resultCapture.set(helper.createSubchannel(addressGroup, attrs));
+ }
+ });
+ return resultCapture.get();
+ }
}
diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
index 2c9025e240..53ad75419a 100644
--- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
+++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
@@ -112,6 +112,10 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
+import java.util.logging.Handler;
+import java.util.logging.Level;
+import java.util.logging.LogRecord;
+import java.util.logging.Logger;
import javax.annotation.Nullable;
import org.junit.After;
import org.junit.Assert;
@@ -278,6 +282,38 @@ public class ManagedChannelImplTest {
}
}
+ @Test
+ public void createSubchannelOutsideSynchronizationContextShouldLogWarning() {
+ createChannel();
+ final AtomicReference logRef = new AtomicReference();
+ Handler handler = new Handler() {
+ @Override
+ public void publish(LogRecord record) {
+ logRef.set(record);
+ }
+
+ @Override
+ public void flush() {
+ }
+
+ @Override
+ public void close() throws SecurityException {
+ }
+ };
+ Logger logger = Logger.getLogger(ManagedChannelImpl.class.getName());
+ try {
+ logger.addHandler(handler);
+ helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ LogRecord record = logRef.get();
+ assertThat(record.getLevel()).isEqualTo(Level.WARNING);
+ assertThat(record.getMessage()).contains(
+ "We sugguest you call createSubchannel() from SynchronizationContext");
+ assertThat(record.getThrown()).isInstanceOf(IllegalStateException.class);
+ } finally {
+ logger.removeHandler(handler);
+ }
+ }
+
@Test
@SuppressWarnings("unchecked")
public void idleModeDisabled() {
@@ -338,7 +374,7 @@ public class ManagedChannelImplTest {
assertNotNull(channelz.getRootChannel(channel.getLogId().getId()));
AbstractSubchannel subchannel =
- (AbstractSubchannel) helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
// subchannels are not root channels
assertNull(channelz.getRootChannel(subchannel.getInternalSubchannel().getLogId().getId()));
assertTrue(channelz.containsSubchannel(subchannel.getInternalSubchannel().getLogId()));
@@ -430,7 +466,7 @@ public class ManagedChannelImplTest {
// Configure the picker so that first RPC goes to delayed transport, and second RPC goes to
// real transport.
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection();
verify(mockTransportFactory)
.newClientTransport(any(SocketAddress.class), any(ClientTransportOptions.class));
@@ -563,8 +599,8 @@ public class ManagedChannelImplTest {
verify(mockLoadBalancer).handleResolvedAddressGroups(
eq(Arrays.asList(addressGroup)), eq(Attributes.EMPTY));
- Subchannel subchannel1 = helper.createSubchannel(addressGroup, Attributes.EMPTY);
- Subchannel subchannel2 = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
+ Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel1.requestConnection();
subchannel2.requestConnection();
verify(mockTransportFactory, times(2))
@@ -629,7 +665,7 @@ public class ManagedChannelImplTest {
call.start(mockCallListener, headers);
// Make the transport available
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
verify(mockTransportFactory, never())
.newClientTransport(any(SocketAddress.class), any(ClientTransportOptions.class));
subchannel.requestConnection();
@@ -845,7 +881,7 @@ public class ManagedChannelImplTest {
EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(resolvedAddrs);
inOrder.verify(mockLoadBalancer).handleResolvedAddressGroups(
eq(Arrays.asList(addressGroup)), eq(Attributes.EMPTY));
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel));
subchannel.requestConnection();
@@ -990,7 +1026,7 @@ public class ManagedChannelImplTest {
EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(resolvedAddrs);
inOrder.verify(mockLoadBalancer).handleResolvedAddressGroups(
eq(Arrays.asList(addressGroup)), eq(Attributes.EMPTY));
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel));
subchannel.requestConnection();
@@ -1046,8 +1082,8 @@ public class ManagedChannelImplTest {
// createSubchannel() always return a new Subchannel
Attributes attrs1 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr1").build();
Attributes attrs2 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr2").build();
- Subchannel sub1 = helper.createSubchannel(addressGroup, attrs1);
- Subchannel sub2 = helper.createSubchannel(addressGroup, attrs2);
+ Subchannel sub1 = createSubchannelSafely(helper, addressGroup, attrs1);
+ Subchannel sub2 = createSubchannelSafely(helper, addressGroup, attrs2);
assertNotSame(sub1, sub2);
assertNotSame(attrs1, attrs2);
assertSame(attrs1, sub1.getAttributes());
@@ -1102,8 +1138,8 @@ public class ManagedChannelImplTest {
@Test
public void subchannelsWhenChannelShutdownNow() {
createChannel();
- Subchannel sub1 = helper.createSubchannel(addressGroup, Attributes.EMPTY);
- Subchannel sub2 = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel sub1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
+ Subchannel sub2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
sub1.requestConnection();
sub2.requestConnection();
@@ -1130,8 +1166,8 @@ public class ManagedChannelImplTest {
@Test
public void subchannelsNoConnectionShutdown() {
createChannel();
- Subchannel sub1 = helper.createSubchannel(addressGroup, Attributes.EMPTY);
- Subchannel sub2 = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel sub1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
+ Subchannel sub2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
channel.shutdown();
verify(mockLoadBalancer).shutdown();
@@ -1146,8 +1182,8 @@ public class ManagedChannelImplTest {
@Test
public void subchannelsNoConnectionShutdownNow() {
createChannel();
- helper.createSubchannel(addressGroup, Attributes.EMPTY);
- helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
+ createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
channel.shutdownNow();
verify(mockLoadBalancer).shutdown();
@@ -1324,7 +1360,7 @@ public class ManagedChannelImplTest {
@Test
public void subchannelChannel_normalUsage() {
createChannel();
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
verify(balancerRpcExecutorPool, never()).getObject();
Channel sChannel = subchannel.asChannel();
@@ -1354,7 +1390,7 @@ public class ManagedChannelImplTest {
@Test
public void subchannelChannel_failWhenNotReady() {
createChannel();
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
Channel sChannel = subchannel.asChannel();
Metadata headers = new Metadata();
@@ -1381,7 +1417,7 @@ public class ManagedChannelImplTest {
@Test
public void subchannelChannel_failWaitForReady() {
createChannel();
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
Channel sChannel = subchannel.asChannel();
Metadata headers = new Metadata();
@@ -1458,7 +1494,7 @@ public class ManagedChannelImplTest {
OobChannel oobChannel = (OobChannel) helper.createOobChannel(addressGroup, "oobAuthority");
oobChannel.getSubchannel().requestConnection();
} else {
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection();
}
@@ -1527,7 +1563,7 @@ public class ManagedChannelImplTest {
// Simulate name resolution results
EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(socketAddress);
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection();
verify(mockTransportFactory)
.newClientTransport(same(socketAddress), eq(clientTransportOptions));
@@ -1603,7 +1639,7 @@ public class ManagedChannelImplTest {
ClientStreamTracer.Factory factory1 = mock(ClientStreamTracer.Factory.class);
ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class);
createChannel();
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection();
MockClientTransportInfo transportInfo = transports.poll();
transportInfo.listener.transportReady();
@@ -1641,7 +1677,7 @@ public class ManagedChannelImplTest {
ClientCall call = channel.newCall(method, callOptions);
call.start(mockCallListener, new Metadata());
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection();
MockClientTransportInfo transportInfo = transports.poll();
transportInfo.listener.transportReady();
@@ -2004,7 +2040,7 @@ public class ManagedChannelImplTest {
Helper helper2 = helperCaptor.getValue();
// Establish a connection
- Subchannel subchannel = helper2.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY);
subchannel.requestConnection();
MockClientTransportInfo transportInfo = transports.poll();
ConnectionClientTransport mockTransport = transportInfo.transport;
@@ -2072,7 +2108,7 @@ public class ManagedChannelImplTest {
Helper helper2 = helperCaptor.getValue();
// Establish a connection
- Subchannel subchannel = helper2.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY);
subchannel.requestConnection();
ClientStream mockStream = mock(ClientStream.class);
MockClientTransportInfo transportInfo = transports.poll();
@@ -2101,8 +2137,8 @@ public class ManagedChannelImplTest {
call.start(mockCallListener, new Metadata());
// Make the transport available with subchannel2
- Subchannel subchannel1 = helper.createSubchannel(addressGroup, Attributes.EMPTY);
- Subchannel subchannel2 = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
+ Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel2.requestConnection();
MockClientTransportInfo transportInfo = transports.poll();
@@ -2228,7 +2264,7 @@ public class ManagedChannelImplTest {
createChannel();
assertEquals(TARGET, getStats(channel).target);
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
assertEquals(Collections.singletonList(addressGroup).toString(),
getStats((AbstractSubchannel) subchannel).target);
}
@@ -2251,7 +2287,7 @@ public class ManagedChannelImplTest {
createChannel();
timer.forwardNanos(1234);
AbstractSubchannel subchannel =
- (AbstractSubchannel) helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
assertThat(getStats(channel).channelTrace.events).contains(new ChannelTrace.Event.Builder()
.setDescription("Child channel created")
.setSeverity(ChannelTrace.Event.Severity.CT_INFO)
@@ -2403,7 +2439,7 @@ public class ManagedChannelImplTest {
channelBuilder.maxTraceEvents(10);
createChannel();
AbstractSubchannel subchannel =
- (AbstractSubchannel) helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
timer.forwardNanos(1234);
subchannel.obtainActiveTransport();
assertThat(getStats(subchannel).channelTrace.events).contains(new ChannelTrace.Event.Builder()
@@ -2466,7 +2502,7 @@ public class ManagedChannelImplTest {
assertEquals(CONNECTING, getStats(channel).state);
AbstractSubchannel subchannel =
- (AbstractSubchannel) helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
assertEquals(IDLE, getStats(subchannel).state);
subchannel.requestConnection();
@@ -2520,7 +2556,7 @@ public class ManagedChannelImplTest {
ClientStream mockStream = mock(ClientStream.class);
ClientStreamTracer.Factory factory = mock(ClientStreamTracer.Factory.class);
AbstractSubchannel subchannel =
- (AbstractSubchannel) helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
subchannel.requestConnection();
MockClientTransportInfo transportInfo = transports.poll();
transportInfo.listener.transportReady();
@@ -2754,7 +2790,7 @@ public class ManagedChannelImplTest {
.handleResolvedAddressGroups(nameResolverFactory.servers, attributesWithRetryPolicy);
// simulating request connection and then transport ready after resolved address
- Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY);
+ Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY);
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel));
subchannel.requestConnection();
@@ -3076,4 +3112,18 @@ public class ManagedChannelImplTest {
private FakeClock.ScheduledTask getNameResolverRefresh() {
return Iterables.getOnlyElement(timer.getPendingTasks(NAME_RESOLVER_REFRESH_TASK_FILTER), null);
}
+
+ // We need this because createSubchannel() should be called from the SynchronizationContext
+ private static Subchannel createSubchannelSafely(
+ final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) {
+ final AtomicReference resultCapture = new AtomicReference();
+ helper.getSynchronizationContext().execute(
+ new Runnable() {
+ @Override
+ public void run() {
+ resultCapture.set(helper.createSubchannel(addressGroup, attrs));
+ }
+ });
+ return resultCapture.get();
+ }
}