Update Load balancing URI parsing and add tests

This commit is contained in:
Carl Mastrangelo 2015-11-10 15:58:43 -08:00
parent 5529b6489b
commit b0c626c359
2 changed files with 162 additions and 46 deletions

View File

@ -33,6 +33,7 @@ package io.grpc.internal;
import static io.grpc.internal.GrpcUtil.TIMER_SERVICE; import static io.grpc.internal.GrpcUtil.TIMER_SERVICE;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
@ -145,38 +146,7 @@ public final class ManagedChannelImpl extends ManagedChannel {
this.executor = executor; this.executor = executor;
} }
this.backoffPolicyProvider = backoffPolicyProvider; this.backoffPolicyProvider = backoffPolicyProvider;
this.nameResolver = getNameResolver(target, nameResolverFactory, nameResolverParams);
// Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending
// "dns:///".
NameResolver nameResolver = null;
URI targetUri;
StringBuilder uriSyntaxErrors = new StringBuilder();
try {
targetUri = new URI(target);
nameResolver = nameResolverFactory.newNameResolver(targetUri, nameResolverParams);
// For "localhost:8080" this would likely return null, because "localhost" is parsed as the
// scheme. Will fall into the next branch and try "dns:///localhost:8080".
} catch (URISyntaxException e) {
// "foo.googleapis.com:8080" will trigger this exception, because "foo.googleapis.com" is an
// invalid scheme. Just fall through and will try "dns:///foo.googleapis.com:8080"
uriSyntaxErrors.append(e.getMessage());
}
if (nameResolver == null) {
try {
targetUri = new URI("dns:///" + target);
nameResolver = nameResolverFactory.newNameResolver(targetUri, nameResolverParams);
} catch (URISyntaxException e) {
if (uriSyntaxErrors.length() > 0) {
uriSyntaxErrors.append("; ");
}
uriSyntaxErrors.append(e.getMessage());
}
}
Preconditions.checkArgument(nameResolver != null,
"cannot find a NameResolver for %s%s", target,
uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors.toString() + ")" : "");
this.nameResolver = nameResolver;
this.loadBalancer = loadBalancerFactory.newLoadBalancer(nameResolver.getServiceAuthority(), tm); this.loadBalancer = loadBalancerFactory.newLoadBalancer(nameResolver.getServiceAuthority(), tm);
this.transportFactory = transportFactory; this.transportFactory = transportFactory;
this.userAgent = userAgent; this.userAgent = userAgent;
@ -197,6 +167,51 @@ public final class ManagedChannelImpl extends ManagedChannel {
}); });
} }
@VisibleForTesting
static NameResolver getNameResolver(String target, NameResolver.Factory nameResolverFactory,
Attributes nameResolverParams) {
// Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending
// "dns:///".
URI targetUri = null;
StringBuilder uriSyntaxErrors = new StringBuilder();
try {
targetUri = new URI(target);
// For "localhost:8080" this would likely cause newNameResolver to return null, because
// "localhost" is parsed as the scheme. Will fall into the next branch and try
// "dns:///localhost:8080".
} catch (URISyntaxException e) {
// Can happen with ip addresses like "[::1]:1234" or 127.0.0.1:1234. Also can happen with
// bogus urls like "dns:///[::1]:1234", which are not properly uriencoded.
uriSyntaxErrors.append(e.getMessage());
}
if (targetUri != null) {
NameResolver resolver = nameResolverFactory.newNameResolver(targetUri, nameResolverParams);
if (resolver != null) {
return resolver;
}
// "foo.googleapis.com:8080" cause resolver to be null, because "foo.googleapis.com" is an
// unmapped scheme. Just fall through and will try "dns:///foo.googleapis.com:8080"
}
// If we reached here, the targetUri couldn't be used, so try again.
try {
targetUri = new URI("dns", null, "/" + target, null);
} catch (URISyntaxException e) {
// Should not be possible.
throw new IllegalArgumentException(e);
}
if (targetUri != null) {
NameResolver resolver = nameResolverFactory.newNameResolver(targetUri, nameResolverParams);
if (resolver != null) {
return resolver;
}
}
throw new IllegalArgumentException(String.format(
"cannot find a NameResolver for %s%s",
target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors.toString() + ")" : ""));
}
/** /**
* Sets the default compression method for this Channel. By default, new calls will use the * Sets the default compression method for this Channel. By default, new calls will use the
* provided compressor. Each individual Call can override this by specifying it in CallOptions. * provided compressor. Each individual Call can override this by specifying it in CallOptions.

View File

@ -59,6 +59,7 @@ import io.grpc.ManagedChannel;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.NameResolver; import io.grpc.NameResolver;
import io.grpc.NameResolver.Factory;
import io.grpc.ResolvedServerInfo; import io.grpc.ResolvedServerInfo;
import io.grpc.SimpleLoadBalancerFactory; import io.grpc.SimpleLoadBalancerFactory;
import io.grpc.Status; import io.grpc.Status;
@ -66,7 +67,9 @@ import io.grpc.StringMarshaller;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
@ -77,6 +80,7 @@ import org.mockito.stubbing.Answer;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -100,16 +104,16 @@ public class ManagedChannelImplTest {
private final String serviceName = "fake.example.com"; private final String serviceName = "fake.example.com";
private final String authority = serviceName; private final String authority = serviceName;
private final String target = "fake://" + serviceName; private final String target = "fake://" + serviceName;
private URI expectedUri;
private final SocketAddress socketAddress = new SocketAddress() {}; private final SocketAddress socketAddress = new SocketAddress() {};
private final ResolvedServerInfo server = new ResolvedServerInfo(socketAddress, Attributes.EMPTY); private final ResolvedServerInfo server = new ResolvedServerInfo(socketAddress, Attributes.EMPTY);
@Rule public final ExpectedException thrown = ExpectedException.none();
@Mock @Mock
private ClientTransport mockTransport; private ClientTransport mockTransport;
@Mock @Mock
private ClientTransportFactory mockTransportFactory; private ClientTransportFactory mockTransportFactory;
private ManagedChannel channel;
@Mock @Mock
private ClientCall.Listener<Integer> mockCallListener; private ClientCall.Listener<Integer> mockCallListener;
@Mock @Mock
@ -132,6 +136,7 @@ public class ManagedChannelImplTest {
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
expectedUri = new URI(target);
when(mockTransportFactory.newClientTransport(any(SocketAddress.class), any(String.class))) when(mockTransportFactory.newClientTransport(any(SocketAddress.class), any(String.class)))
.thenReturn(mockTransport); .thenReturn(mockTransport);
} }
@ -144,7 +149,7 @@ public class ManagedChannelImplTest {
@Test @Test
public void immediateDeadlineExceeded() { public void immediateDeadlineExceeded() {
ManagedChannel channel = createChannel( ManagedChannel channel = createChannel(
new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); new FakeNameResolverFactory(server, expectedUri, true), NO_INTERCEPTOR);
ClientCall<String, Integer> call = ClientCall<String, Integer> call =
channel.newCall(method, CallOptions.DEFAULT.withDeadlineNanoTime(System.nanoTime())); channel.newCall(method, CallOptions.DEFAULT.withDeadlineNanoTime(System.nanoTime()));
call.start(mockCallListener, new Metadata()); call.start(mockCallListener, new Metadata());
@ -155,7 +160,7 @@ public class ManagedChannelImplTest {
@Test @Test
public void shutdownWithNoTransportsEverCreated() { public void shutdownWithNoTransportsEverCreated() {
ManagedChannel channel = createChannel( ManagedChannel channel = createChannel(
new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); new FakeNameResolverFactory(server, expectedUri, true), NO_INTERCEPTOR);
verifyNoMoreInteractions(mockTransportFactory); verifyNoMoreInteractions(mockTransportFactory);
channel.shutdown(); channel.shutdown();
assertTrue(channel.isShutdown()); assertTrue(channel.isShutdown());
@ -165,7 +170,7 @@ public class ManagedChannelImplTest {
@Test @Test
public void twoCallsAndGracefulShutdown() { public void twoCallsAndGracefulShutdown() {
ManagedChannel channel = createChannel( ManagedChannel channel = createChannel(
new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); new FakeNameResolverFactory(server, expectedUri, true), NO_INTERCEPTOR);
verifyNoMoreInteractions(mockTransportFactory); verifyNoMoreInteractions(mockTransportFactory);
ClientCall<String, Integer> call = channel.newCall(method, CallOptions.DEFAULT); ClientCall<String, Integer> call = channel.newCall(method, CallOptions.DEFAULT);
verifyNoMoreInteractions(mockTransportFactory); verifyNoMoreInteractions(mockTransportFactory);
@ -232,6 +237,98 @@ public class ManagedChannelImplTest {
verifyNoMoreInteractions(mockStream); verifyNoMoreInteractions(mockStream);
} }
@Test
public void getNameResolver_invalidUriWithoutScheme() {
Factory nameResolverFactory = new FakeNameResolverFactory(server, expectedUri, true);
thrown.expect(IllegalArgumentException.class);
ManagedChannelImpl.getNameResolver("[invalid", nameResolverFactory, Attributes.EMPTY);
}
@Test
public void getNameResolver_invalidUriWithScheme() {
Factory nameResolverFactory = new FakeNameResolverFactory(server, expectedUri, true);
thrown.expect(IllegalArgumentException.class);
ManagedChannelImpl.getNameResolver("scheme://[invalid", nameResolverFactory, Attributes.EMPTY);
}
@Test
public void getNameResolver_validHost() {
Factory nameResolverFactory = new FakeNameResolverFactory(server, expectedUri, true);
NameResolver res = ManagedChannelImpl.getNameResolver(
target, nameResolverFactory, NAME_RESOLVER_PARAMS);
assertEquals(serviceName, res.getServiceAuthority());
}
@Test
public void getNameResolver_validHostWithoutSchema() throws URISyntaxException {
expectedUri = new URI("dns:///foo.googleapis.com:8080");
Factory nameResolverFactory = new NameResolver.Factory() {
@Override
public NameResolver newNameResolver(URI targetUri, Attributes params) {
if (targetUri.equals(expectedUri)) {
NameResolver resolver = mock(NameResolver.class);
when(resolver.getServiceAuthority()).thenReturn("foo.googleapis.com:8080");
return resolver;
}
return null;
}
};
NameResolver res = ManagedChannelImpl.getNameResolver(
"foo.googleapis.com:8080", nameResolverFactory, NAME_RESOLVER_PARAMS);
assertEquals("foo.googleapis.com:8080", res.getServiceAuthority());
}
@Test
public void getNameResolver_validIpHostWithoutSchema() {
Factory nameResolverFactory = new NameResolver.Factory() {
@Override
public NameResolver newNameResolver(URI targetUri, Attributes params) {
NameResolver resolver = mock(NameResolver.class);
when(resolver.getServiceAuthority()).thenReturn("127.0.0.1:8080");
return resolver;
}
};
NameResolver res = ManagedChannelImpl.getNameResolver(
"127.0.0.1:8080", nameResolverFactory, NAME_RESOLVER_PARAMS);
assertEquals("127.0.0.1:8080", res.getServiceAuthority());
}
@Test
public void getNameResolver_validTargetNoResovler() {
Factory nameResolverFactory = new NameResolver.Factory() {
@Override
public NameResolver newNameResolver(URI targetUri, Attributes params) {
return null;
}
};
thrown.expect(IllegalArgumentException.class);
ManagedChannelImpl.getNameResolver(target, nameResolverFactory, NAME_RESOLVER_PARAMS);
}
@Test
public void getNameResolver_validTargetDnsResovler() {
Factory nameResolverFactory = new NameResolver.Factory() {
@Override
public NameResolver newNameResolver(URI targetUri, Attributes params) {
if (targetUri.getScheme().equals("dns")) {
return mock(NameResolver.class);
}
return null;
}
};
ManagedChannelImpl.getNameResolver("[::1]:1234", nameResolverFactory, NAME_RESOLVER_PARAMS);
}
@Test @Test
public void interceptor() throws Exception { public void interceptor() throws Exception {
final AtomicLong atomic = new AtomicLong(); final AtomicLong atomic = new AtomicLong();
@ -245,7 +342,7 @@ public class ManagedChannelImplTest {
} }
}; };
ManagedChannel channel = createChannel( ManagedChannel channel = createChannel(
new FakeNameResolverFactory(server, true), Arrays.asList(interceptor)); new FakeNameResolverFactory(server, expectedUri, true), Arrays.asList(interceptor));
assertNotNull(channel.newCall(method, CallOptions.DEFAULT)); assertNotNull(channel.newCall(method, CallOptions.DEFAULT));
assertEquals(1, atomic.get()); assertEquals(1, atomic.get());
} }
@ -253,7 +350,7 @@ public class ManagedChannelImplTest {
@Test @Test
public void testNoDeadlockOnShutdown() { public void testNoDeadlockOnShutdown() {
ManagedChannel channel = createChannel( ManagedChannel channel = createChannel(
new FakeNameResolverFactory(server, true), NO_INTERCEPTOR); new FakeNameResolverFactory(server, expectedUri, true), NO_INTERCEPTOR);
// Force creation of transport // Force creation of transport
ClientCall<String, Integer> call = channel.newCall(method, CallOptions.DEFAULT); ClientCall<String, Integer> call = channel.newCall(method, CallOptions.DEFAULT);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
@ -317,7 +414,8 @@ public class ManagedChannelImplTest {
@Test @Test
public void nameResolvedAfterChannelShutdown() { public void nameResolvedAfterChannelShutdown() {
FakeNameResolverFactory nameResolverFactory = new FakeNameResolverFactory(server, false); FakeNameResolverFactory nameResolverFactory =
new FakeNameResolverFactory(server, expectedUri, false);
ManagedChannel channel = createChannel(nameResolverFactory, NO_INTERCEPTOR); ManagedChannel channel = createChannel(nameResolverFactory, NO_INTERCEPTOR);
ClientCall<String, Integer> call = channel.newCall(method, CallOptions.DEFAULT); ClientCall<String, Integer> call = channel.newCall(method, CallOptions.DEFAULT);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
@ -344,20 +442,23 @@ public class ManagedChannelImplTest {
} }
} }
private class FakeNameResolverFactory extends NameResolver.Factory { private static class FakeNameResolverFactory extends NameResolver.Factory {
final ResolvedServerInfo server; final ResolvedServerInfo server;
final boolean resolvedAtStart; final boolean resolvedAtStart;
final ArrayList<FakeNameResolver> resolvers = new ArrayList<FakeNameResolver>(); final ArrayList<FakeNameResolver> resolvers = new ArrayList<FakeNameResolver>();
final URI expectedUri;
FakeNameResolverFactory(ResolvedServerInfo server, boolean resolvedAtStart) { FakeNameResolverFactory(ResolvedServerInfo server, URI expectedUri, boolean resolvedAtStart) {
this.server = server; this.server = server;
this.resolvedAtStart = resolvedAtStart; this.resolvedAtStart = resolvedAtStart;
this.expectedUri = expectedUri;
} }
@Override @Override
public NameResolver newNameResolver(final URI targetUri, Attributes params) { public NameResolver newNameResolver(final URI targetUri, Attributes params) {
assertEquals("fake", targetUri.getScheme()); if (!expectedUri.equals(targetUri)) {
assertEquals(serviceName, targetUri.getAuthority()); return null;
}
assertSame(NAME_RESOLVER_PARAMS, params); assertSame(NAME_RESOLVER_PARAMS, params);
FakeNameResolver resolver = new FakeNameResolver(); FakeNameResolver resolver = new FakeNameResolver();
resolvers.add(resolver); resolvers.add(resolver);
@ -374,7 +475,7 @@ public class ManagedChannelImplTest {
Listener listener; Listener listener;
@Override public String getServiceAuthority() { @Override public String getServiceAuthority() {
return serviceName; return expectedUri.getAuthority();
} }
@Override public void start(final Listener listener) { @Override public void start(final Listener listener) {