From b867f8e4fcd5e46f6bad1af776340aea0f753cdf Mon Sep 17 00:00:00 2001 From: Kun Zhang Date: Wed, 20 Feb 2019 11:45:38 -0800 Subject: [PATCH] core: make NameResolver not thread-safe (#5364) Resolves #2649 As a prerequisite, added `getSynchronizationContext()` to `NameResolver.Helper`. `DnsNameResolver` has gone through a small refactor around the `Resolve` runnable, which makes it a little simpler. --- core/src/main/java/io/grpc/NameResolver.java | 22 ++- .../io/grpc/internal/DnsNameResolver.java | 137 +++++++++--------- .../io/grpc/internal/ManagedChannelImpl.java | 7 + .../internal/DnsNameResolverProviderTest.java | 20 ++- .../io/grpc/internal/DnsNameResolverTest.java | 68 +++++---- .../grpc/internal/ManagedChannelImplTest.java | 7 +- 6 files changed, 159 insertions(+), 102 deletions(-) diff --git a/core/src/main/java/io/grpc/NameResolver.java b/core/src/main/java/io/grpc/NameResolver.java index c188eca0d9..ca3bff90dd 100644 --- a/core/src/main/java/io/grpc/NameResolver.java +++ b/core/src/main/java/io/grpc/NameResolver.java @@ -37,10 +37,14 @@ import javax.annotation.concurrent.ThreadSafe; * {@link Listener} is responsible for eventually (after an appropriate backoff period) invoking * {@link #refresh()}. * + *

Implementations don't need to be thread-safe. All methods are guaranteed to + * be called sequentially. Additionally, all methods that have side-effects, i.e., {@link #start}, + * {@link #shutdown} and {@link #refresh} are called from the same {@link SynchronizationContext} as + * returned by {@link Helper#getSynchronizationContext}. + * * @since 1.0.0 */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1770") -@ThreadSafe public abstract class NameResolver { /** * Returns the authority used to authenticate connections to servers. It must be @@ -209,18 +213,34 @@ public abstract class NameResolver { /** * A utility object passed to {@link Factory#newNameResolver(URI, NameResolver.Helper)}. + * + * @since 1.19.0 */ public abstract static class Helper { /** * The port number used in case the target or the underlying naming system doesn't provide a * port number. + * + * @since 1.19.0 */ public abstract int getDefaultPort(); /** * If the NameResolver wants to support proxy, it should inquire this {@link ProxyDetector}. * See documentation on {@link ProxyDetector} about how proxies work in gRPC. + * + * @since 1.19.0 */ public abstract ProxyDetector getProxyDetector(); + + /** + * Returns the {@link SynchronizationContext} where {@link #start}, {@link #shutdown} and {@link + * #refresh} are run from. + * + * @since 1.20.0 + */ + public SynchronizationContext getSynchronizationContext() { + throw new UnsupportedOperationException("Not implemented"); + } } } diff --git a/core/src/main/java/io/grpc/internal/DnsNameResolver.java b/core/src/main/java/io/grpc/internal/DnsNameResolver.java index fcbb3bb1b1..0ef208193c 100644 --- a/core/src/main/java/io/grpc/internal/DnsNameResolver.java +++ b/core/src/main/java/io/grpc/internal/DnsNameResolver.java @@ -30,6 +30,7 @@ import io.grpc.NameResolver; import io.grpc.ProxiedSocketAddress; import io.grpc.ProxyDetector; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.internal.SharedResourceHolder.Resource; import java.io.IOException; import java.lang.reflect.Constructor; @@ -52,7 +53,6 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * A DNS-based {@link NameResolver}. @@ -138,19 +138,23 @@ final class DnsNameResolver extends NameResolver { private final String host; private final int port; private final Resource executorResource; - @GuardedBy("this") - private boolean shutdown; - @GuardedBy("this") - private Executor executor; - @GuardedBy("this") - private boolean resolving; - @GuardedBy("this") - private Listener listener; + private final long cacheTtlNanos; + private final SynchronizationContext syncContext; - private final Runnable resolveRunnable; + // Following fields must be accessed from syncContext + private final Stopwatch stopwatch; + private ResolutionResults cachedResolutionResults; + private boolean shutdown; + private Executor executor; + private boolean resolving; + + // The field must be accessed from syncContext, although the methods on a Listener can be called + // from any thread. + private Listener listener; DnsNameResolver(@Nullable String nsAuthority, String name, Helper helper, Resource executorResource, Stopwatch stopwatch, boolean isAndroid) { + Preconditions.checkNotNull(helper, "helper"); // TODO: if a DNS server is provided as nsAuthority, use it. // https://www.captechconsulting.com/blogs/accessing-the-dusty-corners-of-dns-with-java this.executorResource = executorResource; @@ -167,16 +171,19 @@ final class DnsNameResolver extends NameResolver { port = nameUri.getPort(); } this.proxyDetector = Preconditions.checkNotNull(helper.getProxyDetector(), "proxyDetector"); - this.resolveRunnable = new Resolve(this, stopwatch, getNetworkAddressCacheTtlNanos(isAndroid)); + this.cacheTtlNanos = getNetworkAddressCacheTtlNanos(isAndroid); + this.stopwatch = Preconditions.checkNotNull(stopwatch, "stopwatch"); + this.syncContext = + Preconditions.checkNotNull(helper.getSynchronizationContext(), "syncContext"); } @Override - public final String getServiceAuthority() { + public String getServiceAuthority() { return authority; } @Override - public final synchronized void start(Listener listener) { + public void start(Listener listener) { Preconditions.checkState(this.listener == null, "already started"); executor = SharedResourceHolder.get(executorResource); this.listener = Preconditions.checkNotNull(listener, "listener"); @@ -184,64 +191,45 @@ final class DnsNameResolver extends NameResolver { } @Override - public final synchronized void refresh() { + public void refresh() { Preconditions.checkState(listener != null, "not started"); resolve(); } - @VisibleForTesting - static final class Resolve implements Runnable { + private final class Resolve implements Runnable { + private final Listener savedListener; - private final DnsNameResolver resolver; - private final Stopwatch stopwatch; - private final long cacheTtlNanos; - private ResolutionResults cachedResolutionResults = null; - - Resolve(DnsNameResolver resolver, Stopwatch stopwatch, long cacheTtlNanos) { - this.resolver = resolver; - this.stopwatch = Preconditions.checkNotNull(stopwatch, "stopwatch"); - this.cacheTtlNanos = cacheTtlNanos; + Resolve(Listener savedListener) { + this.savedListener = Preconditions.checkNotNull(savedListener, "savedListener"); } @Override public void run() { if (logger.isLoggable(Level.FINER)) { - logger.finer("Attempting DNS resolution of " + resolver.host); - } - Listener savedListener; - synchronized (resolver) { - if (resolver.shutdown || !cacheRefreshRequired()) { - return; - } - savedListener = resolver.listener; - resolver.resolving = true; + logger.finer("Attempting DNS resolution of " + host); } try { - resolveInternal(savedListener); + resolveInternal(); } finally { - synchronized (resolver) { - resolver.resolving = false; - } + syncContext.execute(new Runnable() { + @Override + public void run() { + resolving = false; + } + }); } } - private boolean cacheRefreshRequired() { - return cachedResolutionResults == null - || cacheTtlNanos == 0 - || (cacheTtlNanos > 0 && stopwatch.elapsed(TimeUnit.NANOSECONDS) > cacheTtlNanos); - } - @VisibleForTesting - void resolveInternal(Listener savedListener) { + void resolveInternal() { InetSocketAddress destination = - InetSocketAddress.createUnresolved(resolver.host, resolver.port); + InetSocketAddress.createUnresolved(host, port); ProxiedSocketAddress proxiedAddr; try { - proxiedAddr = resolver.proxyDetector.proxyFor(destination); + proxiedAddr = proxyDetector.proxyFor(destination); } catch (IOException e) { savedListener.onError( - Status.UNAVAILABLE.withDescription("Unable to resolve host " + resolver.host) - .withCause(e)); + Status.UNAVAILABLE.withDescription("Unable to resolve host " + host).withCause(e)); return; } if (proxiedAddr != null) { @@ -256,37 +244,42 @@ final class DnsNameResolver extends NameResolver { ResolutionResults resolutionResults; try { ResourceResolver resourceResolver = null; - if (shouldUseJndi(enableJndi, enableJndiLocalhost, resolver.host)) { - resourceResolver = resolver.getResourceResolver(); + if (shouldUseJndi(enableJndi, enableJndiLocalhost, host)) { + resourceResolver = getResourceResolver(); } - resolutionResults = resolveAll( - resolver.addressResolver, + final ResolutionResults results = resolveAll( + addressResolver, resourceResolver, enableSrv, enableTxt, - resolver.host); - cachedResolutionResults = resolutionResults; - if (cacheTtlNanos > 0) { - stopwatch.reset().start(); - } + host); + resolutionResults = results; + syncContext.execute(new Runnable() { + @Override + public void run() { + cachedResolutionResults = results; + if (cacheTtlNanos > 0) { + stopwatch.reset().start(); + } + } + }); if (logger.isLoggable(Level.FINER)) { - logger.finer("Found DNS results " + resolutionResults + " for " + resolver.host); + logger.finer("Found DNS results " + resolutionResults + " for " + host); } } catch (Exception e) { savedListener.onError( - Status.UNAVAILABLE.withDescription("Unable to resolve host " + resolver.host) - .withCause(e)); + Status.UNAVAILABLE.withDescription("Unable to resolve host " + host).withCause(e)); return; } // Each address forms an EAG List servers = new ArrayList<>(); for (InetAddress inetAddr : resolutionResults.addresses) { - servers.add(new EquivalentAddressGroup(new InetSocketAddress(inetAddr, resolver.port))); + servers.add(new EquivalentAddressGroup(new InetSocketAddress(inetAddr, port))); } servers.addAll(resolutionResults.balancerAddresses); if (servers.isEmpty()) { savedListener.onError(Status.UNAVAILABLE.withDescription( - "No DNS backend or balancer addresses found for " + resolver.host)); + "No DNS backend or balancer addresses found for " + host)); return; } @@ -298,7 +291,7 @@ final class DnsNameResolver extends NameResolver { parseTxtResults(resolutionResults.txtRecords)) { try { serviceConfig = - maybeChooseServiceConfig(possibleConfig, resolver.random, getLocalHostname()); + maybeChooseServiceConfig(possibleConfig, random, getLocalHostname()); } catch (RuntimeException e) { logger.log(Level.WARNING, "Bad service config choice " + possibleConfig, e); } @@ -313,22 +306,28 @@ final class DnsNameResolver extends NameResolver { attrs.set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig); } } else { - logger.log(Level.FINE, "No TXT records found for {0}", new Object[]{resolver.host}); + logger.log(Level.FINE, "No TXT records found for {0}", new Object[]{host}); } savedListener.onAddresses(servers, attrs.build()); } } - @GuardedBy("this") private void resolve() { - if (resolving || shutdown) { + if (resolving || shutdown || !cacheRefreshRequired()) { return; } - executor.execute(resolveRunnable); + resolving = true; + executor.execute(new Resolve(listener)); + } + + private boolean cacheRefreshRequired() { + return cachedResolutionResults == null + || cacheTtlNanos == 0 + || (cacheTtlNanos > 0 && stopwatch.elapsed(TimeUnit.NANOSECONDS) > cacheTtlNanos); } @Override - public final synchronized void shutdown() { + public void shutdown() { if (shutdown) { return; } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index a6bb50463c..397db1b224 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -310,6 +310,7 @@ final class ManagedChannelImpl extends ManagedChannel implements // Must be called from syncContext private void shutdownNameResolverAndLoadBalancer(boolean channelIsActive) { + syncContext.throwIfNotInThisSynchronizationContext(); if (channelIsActive) { checkState(nameResolverStarted, "nameResolver is not started"); checkState(lbHelper != null, "lbHelper is null"); @@ -338,6 +339,7 @@ final class ManagedChannelImpl extends ManagedChannel implements */ @VisibleForTesting void exitIdleMode() { + syncContext.throwIfNotInThisSynchronizationContext(); if (shutdown.get() || panicMode) { return; } @@ -557,6 +559,11 @@ final class ManagedChannelImpl extends ManagedChannel implements public ProxyDetector getProxyDetector() { return proxyDetector; } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } }; this.nameResolver = getNameResolver(target, nameResolverFactory, nameResolverHelper); this.timeProvider = checkNotNull(timeProvider, "timeProvider"); diff --git a/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java b/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java index 4922d6350d..33959a283d 100644 --- a/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java +++ b/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue; import io.grpc.NameResolver; import io.grpc.ProxyDetector; +import io.grpc.SynchronizationContext; import java.net.URI; import org.junit.Test; import org.junit.runner.RunWith; @@ -30,8 +31,14 @@ import org.junit.runners.JUnit4; /** Unit tests for {@link DnsNameResolverProvider}. */ @RunWith(JUnit4.class) public class DnsNameResolverProviderTest { - - private static final NameResolver.Helper HELPER = new NameResolver.Helper() { + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + private final NameResolver.Helper helper = new NameResolver.Helper() { @Override public int getDefaultPort() { throw new UnsupportedOperationException("Should not be called"); @@ -41,6 +48,11 @@ public class DnsNameResolverProviderTest { public ProxyDetector getProxyDetector() { return GrpcUtil.getDefaultProxyDetector(); } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } }; private DnsNameResolverProvider provider = new DnsNameResolverProvider(); @@ -53,8 +65,8 @@ public class DnsNameResolverProviderTest { @Test public void newNameResolver() { assertSame(DnsNameResolver.class, - provider.newNameResolver(URI.create("dns:///localhost:443"), HELPER).getClass()); + provider.newNameResolver(URI.create("dns:///localhost:443"), helper).getClass()); assertNull( - provider.newNameResolver(URI.create("notdns:///localhost:443"), HELPER)); + provider.newNameResolver(URI.create("notdns:///localhost:443"), helper)); } } diff --git a/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java b/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java index 367c855834..52f6ae4a28 100644 --- a/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java @@ -42,6 +42,7 @@ import io.grpc.NameResolver; import io.grpc.ProxyDetector; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.SynchronizationContext; import io.grpc.internal.DnsNameResolver.AddressResolver; import io.grpc.internal.DnsNameResolver.ResolutionResults; import io.grpc.internal.DnsNameResolver.ResourceResolver; @@ -95,7 +96,14 @@ public class DnsNameResolverTest { private final Map serviceConfig = new LinkedHashMap<>(); private static final int DEFAULT_PORT = 887; - private static final NameResolver.Helper HELPER = new NameResolver.Helper() { + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + private final NameResolver.Helper helper = new NameResolver.Helper() { @Override public int getDefaultPort() { return DEFAULT_PORT; @@ -105,6 +113,11 @@ public class DnsNameResolverTest { public ProxyDetector getProxyDetector() { return GrpcUtil.getDefaultProxyDetector(); } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } }; private final DnsNameResolverProvider provider = new DnsNameResolverProvider(); @@ -132,27 +145,27 @@ public class DnsNameResolverTest { @Mock private RecordFetcher recordFetcher; - private DnsNameResolver newResolver(String name, int port) { - return newResolver(name, port, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted()); + private DnsNameResolver newResolver(String name, int defaultPort) { + return newResolver( + name, defaultPort, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted()); } - private DnsNameResolver newResolver(String name, int port, boolean isAndroid) { - return - newResolver( - name, port, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted(), isAndroid); + private DnsNameResolver newResolver(String name, int defaultPort, boolean isAndroid) { + return newResolver( + name, defaultPort, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted(), isAndroid); } private DnsNameResolver newResolver( String name, - int port, + int defaultPort, ProxyDetector proxyDetector, Stopwatch stopwatch) { - return newResolver(name, port, proxyDetector, stopwatch, false); + return newResolver(name, defaultPort, proxyDetector, stopwatch, false); } private DnsNameResolver newResolver( String name, - final int port, + final int defaultPort, final ProxyDetector proxyDetector, Stopwatch stopwatch, boolean isAndroid) { @@ -162,13 +175,18 @@ public class DnsNameResolverTest { new NameResolver.Helper() { @Override public int getDefaultPort() { - return port; + return defaultPort; } @Override public ProxyDetector getProxyDetector() { return proxyDetector; } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } }, fakeExecutorResource, stopwatch, @@ -292,17 +310,16 @@ public class DnsNameResolverTest { @Test public void resolveAll_failsOnEmptyResult() throws Exception { - String hostname = "dns:///addr.fake:1234"; - DnsNameResolver nrf = - new DnsNameResolverProvider().newNameResolver(new URI(hostname), HELPER); - nrf.setAddressResolver(new AddressResolver() { + DnsNameResolver nr = newResolver("dns:///addr.fake:1234", 443); + nr.setAddressResolver(new AddressResolver() { @Override public List resolveAddress(String host) throws Exception { return Collections.emptyList(); } }); - new DnsNameResolver.Resolve(nrf, Stopwatch.createUnstarted(), 0).resolveInternal(mockListener); + nr.start(mockListener); + assertThat(fakeExecutor.runDueTasks()).isEqualTo(1); ArgumentCaptor ac = ArgumentCaptor.forClass(Status.class); verify(mockListener).onError(ac.capture()); @@ -334,10 +351,9 @@ public class DnsNameResolverTest { fakeTicker.advance(1, TimeUnit.DAYS); resolver.refresh(); - assertEquals(1, fakeExecutor.runDueTasks()); - verifyNoMoreInteractions(mockListener); - assertAnswerMatches(answer1, 81, resultCaptor.getValue()); + assertEquals(0, fakeExecutor.runDueTasks()); assertEquals(0, fakeClock.numPendingTasks()); + verifyNoMoreInteractions(mockListener); resolver.shutdown(); @@ -369,10 +385,9 @@ public class DnsNameResolverTest { // this refresh should return cached result fakeTicker.advance(ttl - 1, TimeUnit.SECONDS); resolver.refresh(); - assertEquals(1, fakeExecutor.runDueTasks()); - verifyNoMoreInteractions(mockListener); - assertAnswerMatches(answer, 81, resultCaptor.getValue()); + assertEquals(0, fakeExecutor.runDueTasks()); assertEquals(0, fakeClock.numPendingTasks()); + verifyNoMoreInteractions(mockListener); resolver.shutdown(); @@ -444,10 +459,9 @@ public class DnsNameResolverTest { fakeTicker.advance(DnsNameResolver.DEFAULT_NETWORK_CACHE_TTL_SECONDS, TimeUnit.SECONDS); resolver.refresh(); - assertEquals(1, fakeExecutor.runDueTasks()); - verifyNoMoreInteractions(mockListener); - assertAnswerMatches(answer1, 81, resultCaptor.getValue()); + assertEquals(0, fakeExecutor.runDueTasks()); assertEquals(0, fakeClock.numPendingTasks()); + verifyNoMoreInteractions(mockListener); fakeTicker.advance(1, TimeUnit.SECONDS); resolver.refresh(); @@ -1006,7 +1020,7 @@ public class DnsNameResolverTest { private void testInvalidUri(URI uri) { try { - provider.newNameResolver(uri, HELPER); + provider.newNameResolver(uri, helper); fail("Should have failed"); } catch (IllegalArgumentException e) { // expected @@ -1014,7 +1028,7 @@ public class DnsNameResolverTest { } private void testValidUri(URI uri, String exportedAuthority, int expectedPort) { - DnsNameResolver resolver = provider.newNameResolver(uri, HELPER); + DnsNameResolver resolver = provider.newNameResolver(uri, helper); assertNotNull(resolver); assertEquals(expectedPort, resolver.getPort()); assertEquals(exportedAuthority, resolver.getServiceAuthority()); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 7a4b6d8762..a23526222c 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -262,7 +262,12 @@ public class ManagedChannelImplTest { int numExpectedTasks = 0; // Force-exit the initial idle-mode - channel.exitIdleMode(); + channel.syncContext.execute(new Runnable() { + @Override + public void run() { + channel.exitIdleMode(); + } + }); if (channelBuilder.idleTimeoutMillis != ManagedChannelImpl.IDLE_TIMEOUT_MILLIS_DISABLE) { numExpectedTasks += 1; }