diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index a64ea29e50..debc2d0fff 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -340,7 +340,7 @@ public final class ManagedChannelImpl extends ManagedChannel { TransportSet ts; synchronized (lock) { if (shutdown) { - return null; + return NULL_VALUE_TRANSPORT_FUTURE; } ts = transports.get(addr); if (ts == null) { diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 9c2e880e60..08e3dd9fc6 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -42,6 +42,7 @@ import static org.mockito.Matchers.isA; import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -76,6 +77,7 @@ import org.mockito.stubbing.Answer; import java.net.SocketAddress; import java.net.URI; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -141,7 +143,8 @@ public class ManagedChannelImplTest { @Test public void immediateDeadlineExceeded() { - ManagedChannel channel = createChannel(new FakeNameResolverFactory(server), NO_INTERCEPTOR); + ManagedChannel channel = createChannel( + new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); ClientCall call = channel.newCall(method, CallOptions.DEFAULT.withDeadlineNanoTime(System.nanoTime())); call.start(mockCallListener, new Metadata()); @@ -151,7 +154,8 @@ public class ManagedChannelImplTest { @Test public void shutdownWithNoTransportsEverCreated() { - ManagedChannel channel = createChannel(new FakeNameResolverFactory(server), NO_INTERCEPTOR); + ManagedChannel channel = createChannel( + new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); verifyNoMoreInteractions(mockTransportFactory); channel.shutdown(); assertTrue(channel.isShutdown()); @@ -161,7 +165,7 @@ public class ManagedChannelImplTest { @Test public void twoCallsAndGracefulShutdown() { ManagedChannel channel = createChannel( - new FakeNameResolverFactory(server), Collections.emptyList()); + new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); verifyNoMoreInteractions(mockTransportFactory); ClientCall call = channel.newCall(method, CallOptions.DEFAULT); verifyNoMoreInteractions(mockTransportFactory); @@ -241,14 +245,15 @@ public class ManagedChannelImplTest { } }; ManagedChannel channel = createChannel( - new FakeNameResolverFactory(server), Arrays.asList(interceptor)); + new FakeNameResolverFactory(server, true), Arrays.asList(interceptor)); assertNotNull(channel.newCall(method, CallOptions.DEFAULT)); assertEquals(1, atomic.get()); } @Test public void testNoDeadlockOnShutdown() { - ManagedChannel channel = createChannel(new FakeNameResolverFactory(server), NO_INTERCEPTOR); + ManagedChannel channel = createChannel( + new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); // Force creation of transport ClientCall call = channel.newCall(method, CallOptions.DEFAULT); Metadata headers = new Metadata(); @@ -309,6 +314,23 @@ public class ManagedChannelImplTest { assertSame(error, status); } + @Test + public void nameResolvedAfterChannelShutdown() { + FakeNameResolverFactory nameResolverFactory = new FakeNameResolverFactory(server, false); + ManagedChannel channel = createChannel(nameResolverFactory, NO_INTERCEPTOR); + ClientCall call = channel.newCall(method, CallOptions.DEFAULT); + Metadata headers = new Metadata(); + call.start(mockCallListener, headers); + channel.shutdown(); + assertTrue(channel.isShutdown()); + assertTrue(channel.isTerminated()); + // Name resolved after the channel is shut down, which is possible if the name resolution takes + // time and is not cancellable. The resolved address will still be passed to the LoadBalancer. + nameResolverFactory.allResolved(); + verify(mockTransportFactory, never()) + .newClientTransport(any(SocketAddress.class), any(String.class)); + } + private static class FakeBackoffPolicyProvider implements BackoffPolicy.Provider { @Override public BackoffPolicy get() { @@ -323,9 +345,12 @@ public class ManagedChannelImplTest { private class FakeNameResolverFactory extends NameResolver.Factory { final ResolvedServerInfo server; + final boolean resolvedAtStart; + final ArrayList resolvers = new ArrayList(); - FakeNameResolverFactory(ResolvedServerInfo server) { + FakeNameResolverFactory(ResolvedServerInfo server, boolean resolvedAtStart) { this.server = server; + this.resolvedAtStart = resolvedAtStart; } @Override @@ -333,17 +358,36 @@ public class ManagedChannelImplTest { assertEquals("fake", targetUri.getScheme()); assertEquals(serviceName, targetUri.getAuthority()); assertSame(NAME_RESOLVER_PARAMS, params); - return new NameResolver() { - @Override public String getServiceAuthority() { - return serviceName; - } + FakeNameResolver resolver = new FakeNameResolver(); + resolvers.add(resolver); + return resolver; + } - @Override public void start(final Listener listener) { - listener.onUpdate(Collections.singletonList(server), Attributes.EMPTY); - } + void allResolved() { + for (FakeNameResolver resolver : resolvers) { + resolver.resolved(); + } + } - @Override public void shutdown() {} - }; + private class FakeNameResolver extends NameResolver { + Listener listener; + + @Override public String getServiceAuthority() { + return serviceName; + } + + @Override public void start(final Listener listener) { + this.listener = listener; + if (resolvedAtStart) { + resolved(); + } + } + + void resolved() { + listener.onUpdate(Collections.singletonList(server), Attributes.EMPTY); + } + + @Override public void shutdown() {} } }