core: make NameResolver not thread-safe (#5364)

Resolves #2649

As a prerequisite, added `getSynchronizationContext()` to `NameResolver.Helper`.

`DnsNameResolver` has gone through a small refactor around the `Resolve` runnable, which makes it a little simpler.
This commit is contained in:
Kun Zhang 2019-02-20 11:45:38 -08:00 committed by GitHub
parent 05b6156d43
commit b867f8e4fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 159 additions and 102 deletions

View File

@ -37,10 +37,14 @@ import javax.annotation.concurrent.ThreadSafe;
* {@link Listener} is responsible for eventually (after an appropriate backoff period) invoking
* {@link #refresh()}.
*
* <p>Implementations <strong>don't need to be thread-safe</strong>. All methods are guaranteed to
* be called sequentially. Additionally, all methods that have side-effects, i.e., {@link #start},
* {@link #shutdown} and {@link #refresh} are called from the same {@link SynchronizationContext} as
* returned by {@link Helper#getSynchronizationContext}.
*
* @since 1.0.0
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1770")
@ThreadSafe
public abstract class NameResolver {
/**
* Returns the authority used to authenticate connections to servers. It <strong>must</strong> be
@ -209,18 +213,34 @@ public abstract class NameResolver {
/**
* A utility object passed to {@link Factory#newNameResolver(URI, NameResolver.Helper)}.
*
* @since 1.19.0
*/
public abstract static class Helper {
/**
* The port number used in case the target or the underlying naming system doesn't provide a
* port number.
*
* @since 1.19.0
*/
public abstract int getDefaultPort();
/**
* If the NameResolver wants to support proxy, it should inquire this {@link ProxyDetector}.
* See documentation on {@link ProxyDetector} about how proxies work in gRPC.
*
* @since 1.19.0
*/
public abstract ProxyDetector getProxyDetector();
/**
* Returns the {@link SynchronizationContext} where {@link #start}, {@link #shutdown} and {@link
* #refresh} are run from.
*
* @since 1.20.0
*/
public SynchronizationContext getSynchronizationContext() {
throw new UnsupportedOperationException("Not implemented");
}
}
}

View File

@ -30,6 +30,7 @@ import io.grpc.NameResolver;
import io.grpc.ProxiedSocketAddress;
import io.grpc.ProxyDetector;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.internal.SharedResourceHolder.Resource;
import java.io.IOException;
import java.lang.reflect.Constructor;
@ -52,7 +53,6 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
/**
* A DNS-based {@link NameResolver}.
@ -138,19 +138,23 @@ final class DnsNameResolver extends NameResolver {
private final String host;
private final int port;
private final Resource<Executor> executorResource;
@GuardedBy("this")
private boolean shutdown;
@GuardedBy("this")
private Executor executor;
@GuardedBy("this")
private boolean resolving;
@GuardedBy("this")
private Listener listener;
private final long cacheTtlNanos;
private final SynchronizationContext syncContext;
private final Runnable resolveRunnable;
// Following fields must be accessed from syncContext
private final Stopwatch stopwatch;
private ResolutionResults cachedResolutionResults;
private boolean shutdown;
private Executor executor;
private boolean resolving;
// The field must be accessed from syncContext, although the methods on a Listener can be called
// from any thread.
private Listener listener;
DnsNameResolver(@Nullable String nsAuthority, String name, Helper helper,
Resource<Executor> executorResource, Stopwatch stopwatch, boolean isAndroid) {
Preconditions.checkNotNull(helper, "helper");
// TODO: if a DNS server is provided as nsAuthority, use it.
// https://www.captechconsulting.com/blogs/accessing-the-dusty-corners-of-dns-with-java
this.executorResource = executorResource;
@ -167,16 +171,19 @@ final class DnsNameResolver extends NameResolver {
port = nameUri.getPort();
}
this.proxyDetector = Preconditions.checkNotNull(helper.getProxyDetector(), "proxyDetector");
this.resolveRunnable = new Resolve(this, stopwatch, getNetworkAddressCacheTtlNanos(isAndroid));
this.cacheTtlNanos = getNetworkAddressCacheTtlNanos(isAndroid);
this.stopwatch = Preconditions.checkNotNull(stopwatch, "stopwatch");
this.syncContext =
Preconditions.checkNotNull(helper.getSynchronizationContext(), "syncContext");
}
@Override
public final String getServiceAuthority() {
public String getServiceAuthority() {
return authority;
}
@Override
public final synchronized void start(Listener listener) {
public void start(Listener listener) {
Preconditions.checkState(this.listener == null, "already started");
executor = SharedResourceHolder.get(executorResource);
this.listener = Preconditions.checkNotNull(listener, "listener");
@ -184,64 +191,45 @@ final class DnsNameResolver extends NameResolver {
}
@Override
public final synchronized void refresh() {
public void refresh() {
Preconditions.checkState(listener != null, "not started");
resolve();
}
@VisibleForTesting
static final class Resolve implements Runnable {
private final class Resolve implements Runnable {
private final Listener savedListener;
private final DnsNameResolver resolver;
private final Stopwatch stopwatch;
private final long cacheTtlNanos;
private ResolutionResults cachedResolutionResults = null;
Resolve(DnsNameResolver resolver, Stopwatch stopwatch, long cacheTtlNanos) {
this.resolver = resolver;
this.stopwatch = Preconditions.checkNotNull(stopwatch, "stopwatch");
this.cacheTtlNanos = cacheTtlNanos;
Resolve(Listener savedListener) {
this.savedListener = Preconditions.checkNotNull(savedListener, "savedListener");
}
@Override
public void run() {
if (logger.isLoggable(Level.FINER)) {
logger.finer("Attempting DNS resolution of " + resolver.host);
}
Listener savedListener;
synchronized (resolver) {
if (resolver.shutdown || !cacheRefreshRequired()) {
return;
}
savedListener = resolver.listener;
resolver.resolving = true;
logger.finer("Attempting DNS resolution of " + host);
}
try {
resolveInternal(savedListener);
resolveInternal();
} finally {
synchronized (resolver) {
resolver.resolving = false;
}
syncContext.execute(new Runnable() {
@Override
public void run() {
resolving = false;
}
});
}
}
private boolean cacheRefreshRequired() {
return cachedResolutionResults == null
|| cacheTtlNanos == 0
|| (cacheTtlNanos > 0 && stopwatch.elapsed(TimeUnit.NANOSECONDS) > cacheTtlNanos);
}
@VisibleForTesting
void resolveInternal(Listener savedListener) {
void resolveInternal() {
InetSocketAddress destination =
InetSocketAddress.createUnresolved(resolver.host, resolver.port);
InetSocketAddress.createUnresolved(host, port);
ProxiedSocketAddress proxiedAddr;
try {
proxiedAddr = resolver.proxyDetector.proxyFor(destination);
proxiedAddr = proxyDetector.proxyFor(destination);
} catch (IOException e) {
savedListener.onError(
Status.UNAVAILABLE.withDescription("Unable to resolve host " + resolver.host)
.withCause(e));
Status.UNAVAILABLE.withDescription("Unable to resolve host " + host).withCause(e));
return;
}
if (proxiedAddr != null) {
@ -256,37 +244,42 @@ final class DnsNameResolver extends NameResolver {
ResolutionResults resolutionResults;
try {
ResourceResolver resourceResolver = null;
if (shouldUseJndi(enableJndi, enableJndiLocalhost, resolver.host)) {
resourceResolver = resolver.getResourceResolver();
if (shouldUseJndi(enableJndi, enableJndiLocalhost, host)) {
resourceResolver = getResourceResolver();
}
resolutionResults = resolveAll(
resolver.addressResolver,
final ResolutionResults results = resolveAll(
addressResolver,
resourceResolver,
enableSrv,
enableTxt,
resolver.host);
cachedResolutionResults = resolutionResults;
if (cacheTtlNanos > 0) {
stopwatch.reset().start();
}
host);
resolutionResults = results;
syncContext.execute(new Runnable() {
@Override
public void run() {
cachedResolutionResults = results;
if (cacheTtlNanos > 0) {
stopwatch.reset().start();
}
}
});
if (logger.isLoggable(Level.FINER)) {
logger.finer("Found DNS results " + resolutionResults + " for " + resolver.host);
logger.finer("Found DNS results " + resolutionResults + " for " + host);
}
} catch (Exception e) {
savedListener.onError(
Status.UNAVAILABLE.withDescription("Unable to resolve host " + resolver.host)
.withCause(e));
Status.UNAVAILABLE.withDescription("Unable to resolve host " + host).withCause(e));
return;
}
// Each address forms an EAG
List<EquivalentAddressGroup> servers = new ArrayList<>();
for (InetAddress inetAddr : resolutionResults.addresses) {
servers.add(new EquivalentAddressGroup(new InetSocketAddress(inetAddr, resolver.port)));
servers.add(new EquivalentAddressGroup(new InetSocketAddress(inetAddr, port)));
}
servers.addAll(resolutionResults.balancerAddresses);
if (servers.isEmpty()) {
savedListener.onError(Status.UNAVAILABLE.withDescription(
"No DNS backend or balancer addresses found for " + resolver.host));
"No DNS backend or balancer addresses found for " + host));
return;
}
@ -298,7 +291,7 @@ final class DnsNameResolver extends NameResolver {
parseTxtResults(resolutionResults.txtRecords)) {
try {
serviceConfig =
maybeChooseServiceConfig(possibleConfig, resolver.random, getLocalHostname());
maybeChooseServiceConfig(possibleConfig, random, getLocalHostname());
} catch (RuntimeException e) {
logger.log(Level.WARNING, "Bad service config choice " + possibleConfig, e);
}
@ -313,22 +306,28 @@ final class DnsNameResolver extends NameResolver {
attrs.set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig);
}
} else {
logger.log(Level.FINE, "No TXT records found for {0}", new Object[]{resolver.host});
logger.log(Level.FINE, "No TXT records found for {0}", new Object[]{host});
}
savedListener.onAddresses(servers, attrs.build());
}
}
@GuardedBy("this")
private void resolve() {
if (resolving || shutdown) {
if (resolving || shutdown || !cacheRefreshRequired()) {
return;
}
executor.execute(resolveRunnable);
resolving = true;
executor.execute(new Resolve(listener));
}
private boolean cacheRefreshRequired() {
return cachedResolutionResults == null
|| cacheTtlNanos == 0
|| (cacheTtlNanos > 0 && stopwatch.elapsed(TimeUnit.NANOSECONDS) > cacheTtlNanos);
}
@Override
public final synchronized void shutdown() {
public void shutdown() {
if (shutdown) {
return;
}

View File

@ -310,6 +310,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
// Must be called from syncContext
private void shutdownNameResolverAndLoadBalancer(boolean channelIsActive) {
syncContext.throwIfNotInThisSynchronizationContext();
if (channelIsActive) {
checkState(nameResolverStarted, "nameResolver is not started");
checkState(lbHelper != null, "lbHelper is null");
@ -338,6 +339,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
*/
@VisibleForTesting
void exitIdleMode() {
syncContext.throwIfNotInThisSynchronizationContext();
if (shutdown.get() || panicMode) {
return;
}
@ -557,6 +559,11 @@ final class ManagedChannelImpl extends ManagedChannel implements
public ProxyDetector getProxyDetector() {
return proxyDetector;
}
@Override
public SynchronizationContext getSynchronizationContext() {
return syncContext;
}
};
this.nameResolver = getNameResolver(target, nameResolverFactory, nameResolverHelper);
this.timeProvider = checkNotNull(timeProvider, "timeProvider");

View File

@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue;
import io.grpc.NameResolver;
import io.grpc.ProxyDetector;
import io.grpc.SynchronizationContext;
import java.net.URI;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -30,8 +31,14 @@ import org.junit.runners.JUnit4;
/** Unit tests for {@link DnsNameResolverProvider}. */
@RunWith(JUnit4.class)
public class DnsNameResolverProviderTest {
private static final NameResolver.Helper HELPER = new NameResolver.Helper() {
private final SynchronizationContext syncContext = new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
throw new AssertionError(e);
}
});
private final NameResolver.Helper helper = new NameResolver.Helper() {
@Override
public int getDefaultPort() {
throw new UnsupportedOperationException("Should not be called");
@ -41,6 +48,11 @@ public class DnsNameResolverProviderTest {
public ProxyDetector getProxyDetector() {
return GrpcUtil.getDefaultProxyDetector();
}
@Override
public SynchronizationContext getSynchronizationContext() {
return syncContext;
}
};
private DnsNameResolverProvider provider = new DnsNameResolverProvider();
@ -53,8 +65,8 @@ public class DnsNameResolverProviderTest {
@Test
public void newNameResolver() {
assertSame(DnsNameResolver.class,
provider.newNameResolver(URI.create("dns:///localhost:443"), HELPER).getClass());
provider.newNameResolver(URI.create("dns:///localhost:443"), helper).getClass());
assertNull(
provider.newNameResolver(URI.create("notdns:///localhost:443"), HELPER));
provider.newNameResolver(URI.create("notdns:///localhost:443"), helper));
}
}

View File

@ -42,6 +42,7 @@ import io.grpc.NameResolver;
import io.grpc.ProxyDetector;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.SynchronizationContext;
import io.grpc.internal.DnsNameResolver.AddressResolver;
import io.grpc.internal.DnsNameResolver.ResolutionResults;
import io.grpc.internal.DnsNameResolver.ResourceResolver;
@ -95,7 +96,14 @@ public class DnsNameResolverTest {
private final Map<String, Object> serviceConfig = new LinkedHashMap<>();
private static final int DEFAULT_PORT = 887;
private static final NameResolver.Helper HELPER = new NameResolver.Helper() {
private final SynchronizationContext syncContext = new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
throw new AssertionError(e);
}
});
private final NameResolver.Helper helper = new NameResolver.Helper() {
@Override
public int getDefaultPort() {
return DEFAULT_PORT;
@ -105,6 +113,11 @@ public class DnsNameResolverTest {
public ProxyDetector getProxyDetector() {
return GrpcUtil.getDefaultProxyDetector();
}
@Override
public SynchronizationContext getSynchronizationContext() {
return syncContext;
}
};
private final DnsNameResolverProvider provider = new DnsNameResolverProvider();
@ -132,27 +145,27 @@ public class DnsNameResolverTest {
@Mock
private RecordFetcher recordFetcher;
private DnsNameResolver newResolver(String name, int port) {
return newResolver(name, port, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted());
private DnsNameResolver newResolver(String name, int defaultPort) {
return newResolver(
name, defaultPort, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted());
}
private DnsNameResolver newResolver(String name, int port, boolean isAndroid) {
return
newResolver(
name, port, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted(), isAndroid);
private DnsNameResolver newResolver(String name, int defaultPort, boolean isAndroid) {
return newResolver(
name, defaultPort, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted(), isAndroid);
}
private DnsNameResolver newResolver(
String name,
int port,
int defaultPort,
ProxyDetector proxyDetector,
Stopwatch stopwatch) {
return newResolver(name, port, proxyDetector, stopwatch, false);
return newResolver(name, defaultPort, proxyDetector, stopwatch, false);
}
private DnsNameResolver newResolver(
String name,
final int port,
final int defaultPort,
final ProxyDetector proxyDetector,
Stopwatch stopwatch,
boolean isAndroid) {
@ -162,13 +175,18 @@ public class DnsNameResolverTest {
new NameResolver.Helper() {
@Override
public int getDefaultPort() {
return port;
return defaultPort;
}
@Override
public ProxyDetector getProxyDetector() {
return proxyDetector;
}
@Override
public SynchronizationContext getSynchronizationContext() {
return syncContext;
}
},
fakeExecutorResource,
stopwatch,
@ -292,17 +310,16 @@ public class DnsNameResolverTest {
@Test
public void resolveAll_failsOnEmptyResult() throws Exception {
String hostname = "dns:///addr.fake:1234";
DnsNameResolver nrf =
new DnsNameResolverProvider().newNameResolver(new URI(hostname), HELPER);
nrf.setAddressResolver(new AddressResolver() {
DnsNameResolver nr = newResolver("dns:///addr.fake:1234", 443);
nr.setAddressResolver(new AddressResolver() {
@Override
public List<InetAddress> resolveAddress(String host) throws Exception {
return Collections.emptyList();
}
});
new DnsNameResolver.Resolve(nrf, Stopwatch.createUnstarted(), 0).resolveInternal(mockListener);
nr.start(mockListener);
assertThat(fakeExecutor.runDueTasks()).isEqualTo(1);
ArgumentCaptor<Status> ac = ArgumentCaptor.forClass(Status.class);
verify(mockListener).onError(ac.capture());
@ -334,10 +351,9 @@ public class DnsNameResolverTest {
fakeTicker.advance(1, TimeUnit.DAYS);
resolver.refresh();
assertEquals(1, fakeExecutor.runDueTasks());
verifyNoMoreInteractions(mockListener);
assertAnswerMatches(answer1, 81, resultCaptor.getValue());
assertEquals(0, fakeExecutor.runDueTasks());
assertEquals(0, fakeClock.numPendingTasks());
verifyNoMoreInteractions(mockListener);
resolver.shutdown();
@ -369,10 +385,9 @@ public class DnsNameResolverTest {
// this refresh should return cached result
fakeTicker.advance(ttl - 1, TimeUnit.SECONDS);
resolver.refresh();
assertEquals(1, fakeExecutor.runDueTasks());
verifyNoMoreInteractions(mockListener);
assertAnswerMatches(answer, 81, resultCaptor.getValue());
assertEquals(0, fakeExecutor.runDueTasks());
assertEquals(0, fakeClock.numPendingTasks());
verifyNoMoreInteractions(mockListener);
resolver.shutdown();
@ -444,10 +459,9 @@ public class DnsNameResolverTest {
fakeTicker.advance(DnsNameResolver.DEFAULT_NETWORK_CACHE_TTL_SECONDS, TimeUnit.SECONDS);
resolver.refresh();
assertEquals(1, fakeExecutor.runDueTasks());
verifyNoMoreInteractions(mockListener);
assertAnswerMatches(answer1, 81, resultCaptor.getValue());
assertEquals(0, fakeExecutor.runDueTasks());
assertEquals(0, fakeClock.numPendingTasks());
verifyNoMoreInteractions(mockListener);
fakeTicker.advance(1, TimeUnit.SECONDS);
resolver.refresh();
@ -1006,7 +1020,7 @@ public class DnsNameResolverTest {
private void testInvalidUri(URI uri) {
try {
provider.newNameResolver(uri, HELPER);
provider.newNameResolver(uri, helper);
fail("Should have failed");
} catch (IllegalArgumentException e) {
// expected
@ -1014,7 +1028,7 @@ public class DnsNameResolverTest {
}
private void testValidUri(URI uri, String exportedAuthority, int expectedPort) {
DnsNameResolver resolver = provider.newNameResolver(uri, HELPER);
DnsNameResolver resolver = provider.newNameResolver(uri, helper);
assertNotNull(resolver);
assertEquals(expectedPort, resolver.getPort());
assertEquals(exportedAuthority, resolver.getServiceAuthority());

View File

@ -262,7 +262,12 @@ public class ManagedChannelImplTest {
int numExpectedTasks = 0;
// Force-exit the initial idle-mode
channel.exitIdleMode();
channel.syncContext.execute(new Runnable() {
@Override
public void run() {
channel.exitIdleMode();
}
});
if (channelBuilder.idleTimeoutMillis != ManagedChannelImpl.IDLE_TIMEOUT_MILLIS_DISABLE) {
numExpectedTasks += 1;
}