diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index debc2d0fff..ab70e4ed25 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -33,6 +33,7 @@ package io.grpc.internal; import static io.grpc.internal.GrpcUtil.TIMER_SERVICE; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -145,38 +146,7 @@ public final class ManagedChannelImpl extends ManagedChannel { this.executor = executor; } this.backoffPolicyProvider = backoffPolicyProvider; - - // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending - // "dns:///". - NameResolver nameResolver = null; - URI targetUri; - StringBuilder uriSyntaxErrors = new StringBuilder(); - try { - targetUri = new URI(target); - nameResolver = nameResolverFactory.newNameResolver(targetUri, nameResolverParams); - // For "localhost:8080" this would likely return null, because "localhost" is parsed as the - // scheme. Will fall into the next branch and try "dns:///localhost:8080". - } catch (URISyntaxException e) { - // "foo.googleapis.com:8080" will trigger this exception, because "foo.googleapis.com" is an - // invalid scheme. Just fall through and will try "dns:///foo.googleapis.com:8080" - uriSyntaxErrors.append(e.getMessage()); - } - if (nameResolver == null) { - try { - targetUri = new URI("dns:///" + target); - nameResolver = nameResolverFactory.newNameResolver(targetUri, nameResolverParams); - } catch (URISyntaxException e) { - if (uriSyntaxErrors.length() > 0) { - uriSyntaxErrors.append("; "); - } - uriSyntaxErrors.append(e.getMessage()); - } - } - Preconditions.checkArgument(nameResolver != null, - "cannot find a NameResolver for %s%s", target, - uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors.toString() + ")" : ""); - this.nameResolver = nameResolver; - + this.nameResolver = getNameResolver(target, nameResolverFactory, nameResolverParams); this.loadBalancer = loadBalancerFactory.newLoadBalancer(nameResolver.getServiceAuthority(), tm); this.transportFactory = transportFactory; this.userAgent = userAgent; @@ -197,6 +167,51 @@ public final class ManagedChannelImpl extends ManagedChannel { }); } + @VisibleForTesting + static NameResolver getNameResolver(String target, NameResolver.Factory nameResolverFactory, + Attributes nameResolverParams) { + // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending + // "dns:///". + URI targetUri = null; + StringBuilder uriSyntaxErrors = new StringBuilder(); + try { + targetUri = new URI(target); + // For "localhost:8080" this would likely cause newNameResolver to return null, because + // "localhost" is parsed as the scheme. Will fall into the next branch and try + // "dns:///localhost:8080". + } catch (URISyntaxException e) { + // Can happen with ip addresses like "[::1]:1234" or 127.0.0.1:1234. Also can happen with + // bogus urls like "dns:///[::1]:1234", which are not properly uriencoded. + uriSyntaxErrors.append(e.getMessage()); + } + if (targetUri != null) { + NameResolver resolver = nameResolverFactory.newNameResolver(targetUri, nameResolverParams); + if (resolver != null) { + return resolver; + } + // "foo.googleapis.com:8080" cause resolver to be null, because "foo.googleapis.com" is an + // unmapped scheme. Just fall through and will try "dns:///foo.googleapis.com:8080" + } + + // If we reached here, the targetUri couldn't be used, so try again. + try { + targetUri = new URI("dns", null, "/" + target, null); + } catch (URISyntaxException e) { + // Should not be possible. + throw new IllegalArgumentException(e); + } + if (targetUri != null) { + NameResolver resolver = nameResolverFactory.newNameResolver(targetUri, nameResolverParams); + if (resolver != null) { + return resolver; + } + } + throw new IllegalArgumentException(String.format( + "cannot find a NameResolver for %s%s", + target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors.toString() + ")" : "")); + } + + /** * Sets the default compression method for this Channel. By default, new calls will use the * provided compressor. Each individual Call can override this by specifying it in CallOptions. diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index a09bc1b8ba..a4989c8f04 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -59,6 +59,7 @@ import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.NameResolver; +import io.grpc.NameResolver.Factory; import io.grpc.ResolvedServerInfo; import io.grpc.SimpleLoadBalancerFactory; import io.grpc.Status; @@ -66,7 +67,9 @@ import io.grpc.StringMarshaller; import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -77,6 +80,7 @@ import org.mockito.stubbing.Answer; import java.net.SocketAddress; import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -100,16 +104,16 @@ public class ManagedChannelImplTest { private final String serviceName = "fake.example.com"; private final String authority = serviceName; private final String target = "fake://" + serviceName; + private URI expectedUri; private final SocketAddress socketAddress = new SocketAddress() {}; private final ResolvedServerInfo server = new ResolvedServerInfo(socketAddress, Attributes.EMPTY); + @Rule public final ExpectedException thrown = ExpectedException.none(); + @Mock private ClientTransport mockTransport; @Mock private ClientTransportFactory mockTransportFactory; - - private ManagedChannel channel; - @Mock private ClientCall.Listener mockCallListener; @Mock @@ -132,6 +136,7 @@ public class ManagedChannelImplTest { @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); + expectedUri = new URI(target); when(mockTransportFactory.newClientTransport(any(SocketAddress.class), any(String.class))) .thenReturn(mockTransport); } @@ -144,7 +149,7 @@ public class ManagedChannelImplTest { @Test public void immediateDeadlineExceeded() { ManagedChannel channel = createChannel( - new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); + new FakeNameResolverFactory(server, expectedUri, true), NO_INTERCEPTOR); ClientCall call = channel.newCall(method, CallOptions.DEFAULT.withDeadlineNanoTime(System.nanoTime())); call.start(mockCallListener, new Metadata()); @@ -155,7 +160,7 @@ public class ManagedChannelImplTest { @Test public void shutdownWithNoTransportsEverCreated() { ManagedChannel channel = createChannel( - new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); + new FakeNameResolverFactory(server, expectedUri, true), NO_INTERCEPTOR); verifyNoMoreInteractions(mockTransportFactory); channel.shutdown(); assertTrue(channel.isShutdown()); @@ -165,7 +170,7 @@ public class ManagedChannelImplTest { @Test public void twoCallsAndGracefulShutdown() { ManagedChannel channel = createChannel( - new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); + new FakeNameResolverFactory(server, expectedUri, true), NO_INTERCEPTOR); verifyNoMoreInteractions(mockTransportFactory); ClientCall call = channel.newCall(method, CallOptions.DEFAULT); verifyNoMoreInteractions(mockTransportFactory); @@ -232,6 +237,98 @@ public class ManagedChannelImplTest { verifyNoMoreInteractions(mockStream); } + @Test + public void getNameResolver_invalidUriWithoutScheme() { + Factory nameResolverFactory = new FakeNameResolverFactory(server, expectedUri, true); + thrown.expect(IllegalArgumentException.class); + + ManagedChannelImpl.getNameResolver("[invalid", nameResolverFactory, Attributes.EMPTY); + } + + @Test + public void getNameResolver_invalidUriWithScheme() { + Factory nameResolverFactory = new FakeNameResolverFactory(server, expectedUri, true); + thrown.expect(IllegalArgumentException.class); + + ManagedChannelImpl.getNameResolver("scheme://[invalid", nameResolverFactory, Attributes.EMPTY); + } + + @Test + public void getNameResolver_validHost() { + Factory nameResolverFactory = new FakeNameResolverFactory(server, expectedUri, true); + + NameResolver res = ManagedChannelImpl.getNameResolver( + target, nameResolverFactory, NAME_RESOLVER_PARAMS); + + assertEquals(serviceName, res.getServiceAuthority()); + } + + @Test + public void getNameResolver_validHostWithoutSchema() throws URISyntaxException { + expectedUri = new URI("dns:///foo.googleapis.com:8080"); + Factory nameResolverFactory = new NameResolver.Factory() { + @Override + public NameResolver newNameResolver(URI targetUri, Attributes params) { + if (targetUri.equals(expectedUri)) { + NameResolver resolver = mock(NameResolver.class); + when(resolver.getServiceAuthority()).thenReturn("foo.googleapis.com:8080"); + return resolver; + } + return null; + } + }; + + NameResolver res = ManagedChannelImpl.getNameResolver( + "foo.googleapis.com:8080", nameResolverFactory, NAME_RESOLVER_PARAMS); + + assertEquals("foo.googleapis.com:8080", res.getServiceAuthority()); + } + + @Test + public void getNameResolver_validIpHostWithoutSchema() { + Factory nameResolverFactory = new NameResolver.Factory() { + @Override + public NameResolver newNameResolver(URI targetUri, Attributes params) { + NameResolver resolver = mock(NameResolver.class); + when(resolver.getServiceAuthority()).thenReturn("127.0.0.1:8080"); + return resolver; + } + }; + + NameResolver res = ManagedChannelImpl.getNameResolver( + "127.0.0.1:8080", nameResolverFactory, NAME_RESOLVER_PARAMS); + + assertEquals("127.0.0.1:8080", res.getServiceAuthority()); + } + + @Test + public void getNameResolver_validTargetNoResovler() { + Factory nameResolverFactory = new NameResolver.Factory() { + @Override + public NameResolver newNameResolver(URI targetUri, Attributes params) { + return null; + } + }; + thrown.expect(IllegalArgumentException.class); + + ManagedChannelImpl.getNameResolver(target, nameResolverFactory, NAME_RESOLVER_PARAMS); + } + + @Test + public void getNameResolver_validTargetDnsResovler() { + Factory nameResolverFactory = new NameResolver.Factory() { + @Override + public NameResolver newNameResolver(URI targetUri, Attributes params) { + if (targetUri.getScheme().equals("dns")) { + return mock(NameResolver.class); + } + return null; + } + }; + + ManagedChannelImpl.getNameResolver("[::1]:1234", nameResolverFactory, NAME_RESOLVER_PARAMS); + } + @Test public void interceptor() throws Exception { final AtomicLong atomic = new AtomicLong(); @@ -245,7 +342,7 @@ public class ManagedChannelImplTest { } }; ManagedChannel channel = createChannel( - new FakeNameResolverFactory(server, true), Arrays.asList(interceptor)); + new FakeNameResolverFactory(server, expectedUri, true), Arrays.asList(interceptor)); assertNotNull(channel.newCall(method, CallOptions.DEFAULT)); assertEquals(1, atomic.get()); } @@ -253,7 +350,7 @@ public class ManagedChannelImplTest { @Test public void testNoDeadlockOnShutdown() { ManagedChannel channel = createChannel( - new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); + new FakeNameResolverFactory(server, expectedUri, true), NO_INTERCEPTOR); // Force creation of transport ClientCall call = channel.newCall(method, CallOptions.DEFAULT); Metadata headers = new Metadata(); @@ -317,7 +414,8 @@ public class ManagedChannelImplTest { @Test public void nameResolvedAfterChannelShutdown() { - FakeNameResolverFactory nameResolverFactory = new FakeNameResolverFactory(server, false); + FakeNameResolverFactory nameResolverFactory = + new FakeNameResolverFactory(server, expectedUri, false); ManagedChannel channel = createChannel(nameResolverFactory, NO_INTERCEPTOR); ClientCall call = channel.newCall(method, CallOptions.DEFAULT); Metadata headers = new Metadata(); @@ -344,20 +442,23 @@ public class ManagedChannelImplTest { } } - private class FakeNameResolverFactory extends NameResolver.Factory { + private static class FakeNameResolverFactory extends NameResolver.Factory { final ResolvedServerInfo server; final boolean resolvedAtStart; final ArrayList resolvers = new ArrayList(); + final URI expectedUri; - FakeNameResolverFactory(ResolvedServerInfo server, boolean resolvedAtStart) { + FakeNameResolverFactory(ResolvedServerInfo server, URI expectedUri, boolean resolvedAtStart) { this.server = server; this.resolvedAtStart = resolvedAtStart; + this.expectedUri = expectedUri; } @Override public NameResolver newNameResolver(final URI targetUri, Attributes params) { - assertEquals("fake", targetUri.getScheme()); - assertEquals(serviceName, targetUri.getAuthority()); + if (!expectedUri.equals(targetUri)) { + return null; + } assertSame(NAME_RESOLVER_PARAMS, params); FakeNameResolver resolver = new FakeNameResolver(); resolvers.add(resolver); @@ -374,7 +475,7 @@ public class ManagedChannelImplTest { Listener listener; @Override public String getServiceAuthority() { - return serviceName; + return expectedUri.getAuthority(); } @Override public void start(final Listener listener) {