diff --git a/core/src/main/java/io/grpc/DnsNameResolver.java b/core/src/main/java/io/grpc/DnsNameResolver.java index ce787ebbed..ae5fa04578 100644 --- a/core/src/main/java/io/grpc/DnsNameResolver.java +++ b/core/src/main/java/io/grpc/DnsNameResolver.java @@ -31,16 +31,21 @@ package io.grpc; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourceHolder; +import io.grpc.internal.SharedResourceHolder.Resource; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.URI; +import java.net.UnknownHostException; import java.util.ArrayList; import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -50,21 +55,33 @@ import javax.annotation.concurrent.GuardedBy; * * @see DnsNameResolverFactory */ -final class DnsNameResolver extends NameResolver { +class DnsNameResolver extends NameResolver { private final String authority; private final String host; private final int port; + private final Resource timerServiceResource; + private final Resource executorResource; + @GuardedBy("this") + private boolean shutdown; + @GuardedBy("this") + private ScheduledExecutorService timerService; @GuardedBy("this") private ExecutorService executor; @GuardedBy("this") + private ScheduledFuture resolutionTask; + @GuardedBy("this") private boolean resolving; @GuardedBy("this") private Listener listener; - DnsNameResolver(@Nullable String nsAuthority, String name, Attributes params) { + DnsNameResolver(@Nullable String nsAuthority, String name, Attributes params, + Resource timerServiceResource, + Resource executorResource) { // TODO: if a DNS server is provided as nsAuthority, use it. // https://www.captechconsulting.com/blogs/accessing-the-dusty-corners-of-dns-with-java + this.timerServiceResource = timerServiceResource; + this.executorResource = executorResource; // Must prepend a "//" to the name when constructing a URI, otherwise it will be treated as an // opaque URI, thus the authority and host of the resulted URI would be null. URI nameUri = URI.create("//" + name); @@ -85,39 +102,55 @@ final class DnsNameResolver extends NameResolver { } @Override - public String getServiceAuthority() { + public final String getServiceAuthority() { return authority; } @Override - public synchronized void start(Listener listener) { - Preconditions.checkState(executor == null, "already started"); - executor = SharedResourceHolder.get(GrpcUtil.SHARED_CHANNEL_EXECUTOR); - this.listener = listener; + public final synchronized void start(Listener listener) { + Preconditions.checkState(this.listener == null, "already started"); + timerService = SharedResourceHolder.get(timerServiceResource); + executor = SharedResourceHolder.get(executorResource); + this.listener = Preconditions.checkNotNull(listener, "listener"); resolve(); } @Override - public synchronized void refresh() { - Preconditions.checkState(executor != null, "not started"); + public final synchronized void refresh() { + Preconditions.checkState(listener != null, "not started"); resolve(); } - @GuardedBy("this") - private void resolve() { - if (resolving) { - return; - } - resolving = true; - final Listener savedListener = Preconditions.checkNotNull(listener); - executor.execute(new Runnable() { + private final Runnable resolutionRunnable = new Runnable() { @Override public void run() { InetAddress[] inetAddrs; + Listener savedListener; + synchronized (DnsNameResolver.this) { + // If this task is started by refresh(), there might already be a scheduled task. + if (resolutionTask != null) { + resolutionTask.cancel(false); + resolutionTask = null; + } + if (shutdown) { + return; + } + savedListener = listener; + resolving = true; + } try { try { - inetAddrs = InetAddress.getAllByName(host); - } catch (Exception e) { + inetAddrs = getAllByName(host); + } catch (UnknownHostException e) { + synchronized (DnsNameResolver.this) { + if (shutdown) { + return; + } + // Because timerService is the single-threaded GrpcUtil.TIMER_SERVICE in production, + // we need to delegate the blocking work to the executor + resolutionTask = timerService.schedule(resolutionRunnableOnExecutor, + 1, TimeUnit.MINUTES); + } savedListener.onError(Status.UNAVAILABLE.withCause(e)); return; } @@ -135,17 +168,51 @@ final class DnsNameResolver extends NameResolver { } } } - }); + }; + + private final Runnable resolutionRunnableOnExecutor = new Runnable() { + @Override + public void run() { + synchronized (DnsNameResolver.this) { + if (!shutdown) { + executor.execute(resolutionRunnable); + } + } + } + }; + + // To be mocked out in tests + @VisibleForTesting + InetAddress[] getAllByName(String host) throws UnknownHostException { + return InetAddress.getAllByName(host); + } + + @GuardedBy("this") + private void resolve() { + if (resolving || shutdown) { + return; + } + executor.execute(resolutionRunnable); } @Override - public synchronized void shutdown() { + public final synchronized void shutdown() { + if (shutdown) { + return; + } + shutdown = true; + if (resolutionTask != null) { + resolutionTask.cancel(false); + } + if (timerService != null) { + timerService = SharedResourceHolder.release(timerServiceResource, timerService); + } if (executor != null) { - executor = SharedResourceHolder.release(GrpcUtil.SHARED_CHANNEL_EXECUTOR, executor); + executor = SharedResourceHolder.release(executorResource, executor); } } - int getPort() { + final int getPort() { return port; } } diff --git a/core/src/main/java/io/grpc/DnsNameResolverFactory.java b/core/src/main/java/io/grpc/DnsNameResolverFactory.java index 39c2238fc7..6ee236d091 100644 --- a/core/src/main/java/io/grpc/DnsNameResolverFactory.java +++ b/core/src/main/java/io/grpc/DnsNameResolverFactory.java @@ -33,6 +33,8 @@ package io.grpc; import com.google.common.base.Preconditions; +import io.grpc.internal.GrpcUtil; + import java.net.URI; /** @@ -63,7 +65,8 @@ public final class DnsNameResolverFactory extends NameResolver.Factory { Preconditions.checkArgument(targetPath.startsWith("/"), "the path component (%s) of the target (%s) must start with '/'", targetPath, targetUri); String name = targetPath.substring(1); - return new DnsNameResolver(targetUri.getAuthority(), name, params); + return new DnsNameResolver(targetUri.getAuthority(), name, params, GrpcUtil.TIMER_SERVICE, + GrpcUtil.SHARED_CHANNEL_EXECUTOR); } else { return null; } diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index f604a87e4d..0ed5d4e111 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -385,7 +385,7 @@ public final class GrpcUtil { }; /** - * Shared executor for managing channel timers. + * Shared single-threaded executor for managing channel timers. */ public static final Resource TIMER_SERVICE = new Resource() { diff --git a/core/src/test/java/io/grpc/DnsNameResolverTest.java b/core/src/test/java/io/grpc/DnsNameResolverTest.java index e034a6d011..92fad2b145 100644 --- a/core/src/test/java/io/grpc/DnsNameResolverTest.java +++ b/core/src/test/java/io/grpc/DnsNameResolverTest.java @@ -33,24 +33,82 @@ package io.grpc; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import io.grpc.internal.FakeClock; +import io.grpc.internal.SharedResourceHolder.Resource; + +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import java.net.InetAddress; +import java.net.InetSocketAddress; import java.net.URI; +import java.net.UnknownHostException; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; /** Unit tests for {@link DnsNameResolver}. */ @RunWith(JUnit4.class) public class DnsNameResolverTest { - - private DnsNameResolverFactory factory = DnsNameResolverFactory.getInstance(); - private static final int DEFAULT_PORT = 887; private static final Attributes NAME_RESOLVER_PARAMS = Attributes.newBuilder().set(NameResolver.Factory.PARAMS_DEFAULT_PORT, DEFAULT_PORT).build(); + private final DnsNameResolverFactory factory = DnsNameResolverFactory.getInstance(); + private final FakeClock fakeClock = new FakeClock(); + private final Resource fakeTimerService = + new Resource() { + @Override + public ScheduledExecutorService create() { + return fakeClock.scheduledExecutorService; + } + + @Override + public void close(ScheduledExecutorService instance) { + assertSame(fakeClock, instance); + } + }; + + private final Resource fakeExecutor = + new Resource() { + @Override + public ExecutorService create() { + return fakeClock.scheduledExecutorService; + } + + @Override + public void close(ExecutorService instance) { + assertSame(fakeClock, instance); + } + }; + + @Mock + private NameResolver.Listener mockListener; + @Captor + private ArgumentCaptor> resultCaptor; + @Captor + private ArgumentCaptor statusCaptor; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + @Test public void invalidDnsName() throws Exception { testInvalidUri(new URI("dns", null, "/[invalid]", null)); @@ -73,6 +131,116 @@ public class DnsNameResolverTest { "foo.googleapis.com:456", 456); } + @Test + public void resolve() throws Exception { + InetAddress[] answer1 = createAddressList(2); + InetAddress[] answer2 = createAddressList(1); + String name = "foo.googleapis.com"; + MockResolver resolver = new MockResolver(name, 81, answer1, answer2); + resolver.start(mockListener); + verify(mockListener).onUpdate(resultCaptor.capture(), any(Attributes.class)); + assertEquals(name, resolver.invocations.poll()); + assertAnswerMatches(answer1, 81, resultCaptor.getValue()); + assertEquals(0, fakeClock.numPendingTasks()); + + resolver.refresh(); + verify(mockListener, times(2)).onUpdate(resultCaptor.capture(), any(Attributes.class)); + assertEquals(name, resolver.invocations.poll()); + assertAnswerMatches(answer2, 81, resultCaptor.getValue()); + assertEquals(0, fakeClock.numPendingTasks()); + + resolver.shutdown(); + } + + @Test + public void retry() throws Exception { + String name = "foo.googleapis.com"; + UnknownHostException error = new UnknownHostException(name); + InetAddress[] answer = createAddressList(2); + MockResolver resolver = new MockResolver(name, 81, error, error, answer); + resolver.start(mockListener); + verify(mockListener).onError(statusCaptor.capture()); + assertEquals(name, resolver.invocations.poll()); + Status status = statusCaptor.getValue(); + assertEquals(Status.Code.UNAVAILABLE, status.getCode()); + assertSame(error, status.getCause()); + + // First retry scheduled + assertEquals(1, fakeClock.numPendingTasks()); + fakeClock.forwardMillis(TimeUnit.MINUTES.toMillis(1) - 1); + assertEquals(1, fakeClock.numPendingTasks()); + + // First retry + fakeClock.forwardMillis(1); + verify(mockListener, times(2)).onError(statusCaptor.capture()); + assertEquals(name, resolver.invocations.poll()); + status = statusCaptor.getValue(); + assertEquals(Status.Code.UNAVAILABLE, status.getCode()); + assertSame(error, status.getCause()); + + // Second retry scheduled + assertEquals(1, fakeClock.numPendingTasks()); + fakeClock.forwardMillis(TimeUnit.MINUTES.toMillis(1) - 1); + assertEquals(1, fakeClock.numPendingTasks()); + + // Second retry + fakeClock.forwardMillis(1); + assertEquals(0, fakeClock.numPendingTasks()); + verify(mockListener).onUpdate(resultCaptor.capture(), any(Attributes.class)); + assertEquals(name, resolver.invocations.poll()); + assertAnswerMatches(answer, 81, resultCaptor.getValue()); + + verifyNoMoreInteractions(mockListener); + } + + @Test + public void refreshCancelsScheduledRetry() throws Exception { + String name = "foo.googleapis.com"; + UnknownHostException error = new UnknownHostException(name); + InetAddress[] answer = createAddressList(2); + MockResolver resolver = new MockResolver(name, 81, error, answer); + resolver.start(mockListener); + verify(mockListener).onError(statusCaptor.capture()); + assertEquals(name, resolver.invocations.poll()); + Status status = statusCaptor.getValue(); + assertEquals(Status.Code.UNAVAILABLE, status.getCode()); + assertSame(error, status.getCause()); + + // First retry scheduled + assertEquals(1, fakeClock.numPendingTasks()); + + resolver.refresh(); + // Refresh cancelled the retry + assertEquals(0, fakeClock.numPendingTasks()); + verify(mockListener).onUpdate(resultCaptor.capture(), any(Attributes.class)); + assertEquals(name, resolver.invocations.poll()); + assertAnswerMatches(answer, 81, resultCaptor.getValue()); + + verifyNoMoreInteractions(mockListener); + } + + @Test + public void shutdownCancelsScheduledRetry() throws Exception { + String name = "foo.googleapis.com"; + UnknownHostException error = new UnknownHostException(name); + MockResolver resolver = new MockResolver(name, 81, error); + resolver.start(mockListener); + verify(mockListener).onError(statusCaptor.capture()); + assertEquals(name, resolver.invocations.poll()); + Status status = statusCaptor.getValue(); + assertEquals(Status.Code.UNAVAILABLE, status.getCode()); + assertSame(error, status.getCause()); + + // Retry scheduled + assertEquals(1, fakeClock.numPendingTasks()); + + // Shutdown cancelled the retry + resolver.shutdown(); + assertEquals(0, fakeClock.numPendingTasks()); + + verifyNoMoreInteractions(mockListener); + } + private void testInvalidUri(URI uri) { try { factory.newNameResolver(uri, NAME_RESOLVER_PARAMS); @@ -88,4 +256,48 @@ public class DnsNameResolverTest { assertEquals(expectedPort, resolver.getPort()); assertEquals(exportedAuthority, resolver.getServiceAuthority()); } + + private byte lastByte = 0; + + private InetAddress[] createAddressList(int n) throws UnknownHostException { + InetAddress[] list = new InetAddress[n]; + for (int i = 0; i < n; i++) { + list[i] = InetAddress.getByAddress(new byte[] {127, 0, 0, ++lastByte}); + } + return list; + } + + private static void assertAnswerMatches(InetAddress[] addrs, int port, + List result) { + assertEquals(addrs.length, result.size()); + for (int i = 0; i < addrs.length; i++) { + InetSocketAddress socketAddr = (InetSocketAddress) result.get(i).getAddress(); + assertEquals("Addr " + i, port, socketAddr.getPort()); + assertEquals("Addr " + i, addrs[i], socketAddr.getAddress()); + } + } + + private class MockResolver extends DnsNameResolver { + final LinkedList answers = new LinkedList(); + final LinkedList invocations = new LinkedList(); + + MockResolver(String name, int defaultPort, Object ... answers) { + super(null, name, Attributes.newBuilder().set( + NameResolver.Factory.PARAMS_DEFAULT_PORT, defaultPort).build(), fakeTimerService, + fakeExecutor); + for (Object answer : answers) { + this.answers.add(answer); + } + } + + @Override + InetAddress[] getAllByName(String host) throws UnknownHostException { + invocations.add(host); + Object answer = answers.poll(); + if (answer instanceof UnknownHostException) { + throw (UnknownHostException) answer; + } + return (InetAddress[]) answer; + } + } } diff --git a/core/src/test/java/io/grpc/internal/FakeClock.java b/core/src/test/java/io/grpc/internal/FakeClock.java index 9f5f8cc40c..d0db8f29e6 100644 --- a/core/src/test/java/io/grpc/internal/FakeClock.java +++ b/core/src/test/java/io/grpc/internal/FakeClock.java @@ -47,9 +47,9 @@ import java.util.concurrent.TimeUnit; /** * A manipulated clock that exports a {@link Ticker} and a {@link ScheduledExecutorService}. */ -final class FakeClock { +public final class FakeClock { - final ScheduledExecutorService scheduledExecutorService = new ScheduledExecutorImpl(); + public final ScheduledExecutorService scheduledExecutorService = new ScheduledExecutorImpl(); final Ticker ticker = new Ticker() { @Override public long read() { return TimeUnit.MILLISECONDS.toNanos(currentTimeNanos); @@ -183,12 +183,16 @@ final class FakeClock { } } - void forwardTime(long value, TimeUnit unit) { + public void forwardTime(long value, TimeUnit unit) { currentTimeNanos += unit.toNanos(value); runDueTasks(); } - void forwardMillis(long millis) { + public void forwardMillis(long millis) { forwardTime(millis, TimeUnit.MILLISECONDS); } + + public int numPendingTasks() { + return tasks.size(); + } }