Plumb target to load balancer

gRFC A78 has WRR and pick-first include a `grpc.target` label, defined
in A66:

> `grpc.target` : Canonicalized target URI used when creating gRPC
> Channel, e.g. "dns:///pubsub.googleapis.com:443",
> "xds:///helloworld-gke:8000". Canonicalized target URI is the form
> with the scheme included if the user didn't mention the scheme
> (`scheme://[authority]/path`). For channels such as inprocess channels
> where a target URI is not available, implementations can synthesize a
> target URI.
This commit is contained in:
Eric Anderson 2024-04-30 08:57:55 -07:00
parent 27d57585cd
commit 4561bb5b80
5 changed files with 145 additions and 101 deletions

View File

@ -1185,6 +1185,13 @@ public abstract class LoadBalancer {
*/ */
public abstract String getAuthority(); public abstract String getAuthority();
/**
* Returns the target string of the channel, guaranteed to include its scheme.
*/
public String getChannelTarget() {
throw new UnsupportedOperationException();
}
/** /**
* Returns the ChannelCredentials used to construct the channel, without bearer tokens. * Returns the ChannelCredentials used to construct the channel, without bearer tokens.
* *

View File

@ -166,6 +166,8 @@ final class ManagedChannelImpl extends ManagedChannel implements
@Nullable @Nullable
private final String authorityOverride; private final String authorityOverride;
private final NameResolverRegistry nameResolverRegistry; private final NameResolverRegistry nameResolverRegistry;
private final URI targetUri;
private final NameResolverProvider nameResolverProvider;
private final NameResolver.Args nameResolverArgs; private final NameResolver.Args nameResolverArgs;
private final AutoConfiguredLoadBalancerFactory loadBalancerFactory; private final AutoConfiguredLoadBalancerFactory loadBalancerFactory;
private final ClientTransportFactory originalTransportFactory; private final ClientTransportFactory originalTransportFactory;
@ -383,8 +385,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
nameResolverStarted = false; nameResolverStarted = false;
if (channelIsActive) { if (channelIsActive) {
nameResolver = getNameResolver( nameResolver = getNameResolver(
target, authorityOverride, nameResolverRegistry, nameResolverArgs, targetUri, authorityOverride, nameResolverProvider, nameResolverArgs);
transportFactory.getSupportedSocketAddressTypes());
} else { } else {
nameResolver = null; nameResolver = null;
} }
@ -621,6 +622,10 @@ final class ManagedChannelImpl extends ManagedChannel implements
this.retryEnabled = builder.retryEnabled; this.retryEnabled = builder.retryEnabled;
this.loadBalancerFactory = new AutoConfiguredLoadBalancerFactory(builder.defaultLbPolicy); this.loadBalancerFactory = new AutoConfiguredLoadBalancerFactory(builder.defaultLbPolicy);
this.nameResolverRegistry = builder.nameResolverRegistry; this.nameResolverRegistry = builder.nameResolverRegistry;
ResolvedNameResolver resolvedResolver = getNameResolverProvider(
target, nameResolverRegistry, transportFactory.getSupportedSocketAddressTypes());
this.targetUri = resolvedResolver.targetUri;
this.nameResolverProvider = resolvedResolver.provider;
ScParser serviceConfigParser = ScParser serviceConfigParser =
new ScParser( new ScParser(
retryEnabled, retryEnabled,
@ -640,8 +645,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
.setOverrideAuthority(this.authorityOverride) .setOverrideAuthority(this.authorityOverride)
.build(); .build();
this.nameResolver = getNameResolver( this.nameResolver = getNameResolver(
target, authorityOverride, nameResolverRegistry, nameResolverArgs, targetUri, authorityOverride, nameResolverProvider, nameResolverArgs);
transportFactory.getSupportedSocketAddressTypes());
this.balancerRpcExecutorPool = checkNotNull(balancerRpcExecutorPool, "balancerRpcExecutorPool"); this.balancerRpcExecutorPool = checkNotNull(balancerRpcExecutorPool, "balancerRpcExecutorPool");
this.balancerRpcExecutorHolder = new ExecutorHolder(balancerRpcExecutorPool); this.balancerRpcExecutorHolder = new ExecutorHolder(balancerRpcExecutorPool);
this.delayedTransport = new DelayedClientTransport(this.executor, this.syncContext); this.delayedTransport = new DelayedClientTransport(this.executor, this.syncContext);
@ -713,8 +717,20 @@ final class ManagedChannelImpl extends ManagedChannel implements
} }
} }
private static NameResolver getNameResolver( @VisibleForTesting
String target, NameResolverRegistry nameResolverRegistry, NameResolver.Args nameResolverArgs, static class ResolvedNameResolver {
public final URI targetUri;
public final NameResolverProvider provider;
public ResolvedNameResolver(URI targetUri, NameResolverProvider provider) {
this.targetUri = checkNotNull(targetUri, "targetUri");
this.provider = checkNotNull(provider, "provider");
}
}
@VisibleForTesting
static ResolvedNameResolver getNameResolverProvider(
String target, NameResolverRegistry nameResolverRegistry,
Collection<Class<? extends SocketAddress>> channelTransportSocketAddressTypes) { Collection<Class<? extends SocketAddress>> channelTransportSocketAddressTypes) {
// Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending
// "dns:///". // "dns:///".
@ -761,23 +777,17 @@ final class ManagedChannelImpl extends ManagedChannel implements
} }
} }
NameResolver resolver = provider.newNameResolver(targetUri, nameResolverArgs); return new ResolvedNameResolver(targetUri, provider);
if (resolver != null) {
return resolver;
}
throw new IllegalArgumentException(String.format(
"cannot create a NameResolver for %s%s",
target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors + ")" : ""));
} }
@VisibleForTesting @VisibleForTesting
static NameResolver getNameResolver( static NameResolver getNameResolver(
String target, @Nullable final String overrideAuthority, URI targetUri, @Nullable final String overrideAuthority,
NameResolverRegistry nameResolverRegistry, NameResolver.Args nameResolverArgs, NameResolverProvider provider, NameResolver.Args nameResolverArgs) {
Collection<Class<? extends SocketAddress>> channelTransportSocketAddressTypes) { NameResolver resolver = provider.newNameResolver(targetUri, nameResolverArgs);
NameResolver resolver = getNameResolver(target, nameResolverRegistry, nameResolverArgs, if (resolver == null) {
channelTransportSocketAddressTypes); throw new IllegalArgumentException("cannot create a NameResolver for " + targetUri);
}
// We wrap the name resolver in a RetryingNameResolver to give it the ability to retry failures. // We wrap the name resolver in a RetryingNameResolver to give it the ability to retry failures.
// TODO: After a transition period, all NameResolver implementations that need retry should use // TODO: After a transition period, all NameResolver implementations that need retry should use
@ -1703,6 +1713,11 @@ final class ManagedChannelImpl extends ManagedChannel implements
return ManagedChannelImpl.this.authority(); return ManagedChannelImpl.this.authority();
} }
@Override
public String getChannelTarget() {
return targetUri.toString();
}
@Override @Override
public SynchronizationContext getSynchronizationContext() { public SynchronizationContext getSynchronizationContext() {
return syncContext; return syncContext;

View File

@ -17,21 +17,12 @@
package io.grpc.internal; package io.grpc.internal;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import io.grpc.ChannelLogger;
import io.grpc.NameResolver; import io.grpc.NameResolver;
import io.grpc.NameResolver.Args;
import io.grpc.NameResolver.ServiceConfigParser;
import io.grpc.NameResolverProvider; import io.grpc.NameResolverProvider;
import io.grpc.NameResolverRegistry; import io.grpc.NameResolverRegistry;
import io.grpc.ProxyDetector;
import io.grpc.SynchronizationContext;
import io.grpc.inprocess.InProcessSocketAddress; import io.grpc.inprocess.InProcessSocketAddress;
import java.lang.Thread.UncaughtExceptionHandler;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.URI; import java.net.URI;
import java.util.Collections; import java.util.Collections;
@ -39,18 +30,9 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
/** Unit tests for ManagedChannelImpl#getNameResolver(). */ /** Unit tests for ManagedChannelImpl#getNameResolverProvider(). */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class ManagedChannelImplGetNameResolverTest { public class ManagedChannelImplGetNameResolverTest {
private static final NameResolver.Args NAMERESOLVER_ARGS = NameResolver.Args.newBuilder()
.setDefaultPort(447)
.setProxyDetector(mock(ProxyDetector.class))
.setSynchronizationContext(new SynchronizationContext(mock(UncaughtExceptionHandler.class)))
.setServiceConfigParser(mock(ServiceConfigParser.class))
.setChannelLogger(mock(ChannelLogger.class))
.setScheduledExecutorService(new FakeClock().getScheduledExecutorService())
.build();
@Test @Test
public void invalidUriTarget() { public void invalidUriTarget() {
testInvalidTarget("defaultscheme:///[invalid]"); testInvalidTarget("defaultscheme:///[invalid]");
@ -68,18 +50,6 @@ public class ManagedChannelImplGetNameResolverTest {
new URI("defaultscheme", "", "/foo.googleapis.com:8080", null)); new URI("defaultscheme", "", "/foo.googleapis.com:8080", null));
} }
@Test
public void validAuthorityTarget_overrideAuthority() throws Exception {
String target = "foo.googleapis.com:8080";
String overrideAuthority = "override.authority";
URI expectedUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null);
NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme());
NameResolver nameResolver = ManagedChannelImpl.getNameResolver(
target, overrideAuthority, nameResolverRegistry, NAMERESOLVER_ARGS,
Collections.singleton(InetSocketAddress.class));
assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority);
}
@Test @Test
public void validUriTarget() throws Exception { public void validUriTarget() throws Exception {
testValidTarget("scheme:///foo.googleapis.com:8080", "scheme:///foo.googleapis.com:8080", testValidTarget("scheme:///foo.googleapis.com:8080", "scheme:///foo.googleapis.com:8080",
@ -121,47 +91,12 @@ public class ManagedChannelImplGetNameResolverTest {
new URI("defaultscheme", "", "//target", null)); new URI("defaultscheme", "", "//target", null));
} }
@Test
public void validTargetNoResolver() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
NameResolverProvider nameResolverProvider = new NameResolverProvider() {
@Override
protected boolean isAvailable() {
return true;
}
@Override
protected int priority() {
return 5;
}
@Override
public NameResolver newNameResolver(URI targetUri, Args args) {
return null;
}
@Override
public String getDefaultScheme() {
return "defaultscheme";
}
};
nameResolverRegistry.register(nameResolverProvider);
try {
ManagedChannelImpl.getNameResolver(
"foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS,
Collections.singleton(InetSocketAddress.class));
fail("Should fail");
} catch (IllegalArgumentException e) {
// expected
}
}
@Test @Test
public void validTargetNoProvider() { public void validTargetNoProvider() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
try { try {
ManagedChannelImpl.getNameResolver( ManagedChannelImpl.getNameResolverProvider(
"foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS, "foo.googleapis.com:8080", nameResolverRegistry,
Collections.singleton(InetSocketAddress.class)); Collections.singleton(InetSocketAddress.class));
fail("Should fail"); fail("Should fail");
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
@ -173,8 +108,8 @@ public class ManagedChannelImplGetNameResolverTest {
public void validTargetProviderAddrTypesNotSupported() { public void validTargetProviderAddrTypesNotSupported() {
NameResolverRegistry nameResolverRegistry = getTestRegistry("testscheme"); NameResolverRegistry nameResolverRegistry = getTestRegistry("testscheme");
try { try {
ManagedChannelImpl.getNameResolver( ManagedChannelImpl.getNameResolverProvider(
"testscheme:///foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS, "testscheme:///foo.googleapis.com:8080", nameResolverRegistry,
Collections.singleton(InProcessSocketAddress.class)); Collections.singleton(InProcessSocketAddress.class));
fail("Should fail"); fail("Should fail");
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
@ -184,26 +119,23 @@ public class ManagedChannelImplGetNameResolverTest {
} }
} }
private void testValidTarget(String target, String expectedUriString, URI expectedUri) { private void testValidTarget(String target, String expectedUriString, URI expectedUri) {
NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme()); NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme());
FakeNameResolver nameResolver ManagedChannelImpl.ResolvedNameResolver resolved = ManagedChannelImpl.getNameResolverProvider(
= (FakeNameResolver) ((RetryingNameResolver) ManagedChannelImpl.getNameResolver( target, nameResolverRegistry, Collections.singleton(InetSocketAddress.class));
target, null, nameResolverRegistry, NAMERESOLVER_ARGS, assertThat(resolved.provider).isInstanceOf(FakeNameResolverProvider.class);
Collections.singleton(InetSocketAddress.class))).getRetriedNameResolver(); assertThat(resolved.targetUri).isEqualTo(expectedUri);
assertNotNull(nameResolver); assertThat(resolved.targetUri.toString()).isEqualTo(expectedUriString);
assertEquals(expectedUri, nameResolver.uri);
assertEquals(expectedUriString, nameResolver.uri.toString());
} }
private void testInvalidTarget(String target) { private void testInvalidTarget(String target) {
NameResolverRegistry nameResolverRegistry = getTestRegistry("dns"); NameResolverRegistry nameResolverRegistry = getTestRegistry("dns");
try { try {
FakeNameResolver nameResolver = (FakeNameResolver) ManagedChannelImpl.getNameResolver( ManagedChannelImpl.ResolvedNameResolver resolved = ManagedChannelImpl.getNameResolverProvider(
target, null, nameResolverRegistry, NAMERESOLVER_ARGS, target, nameResolverRegistry, Collections.singleton(InetSocketAddress.class));
Collections.singleton(InetSocketAddress.class)); FakeNameResolverProvider nameResolverProvider = (FakeNameResolverProvider) resolved.provider;
fail("Should have failed, but got resolver with " + nameResolver.uri); fail("Should have failed, but got resolver provider " + nameResolverProvider);
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
// expected // expected
} }

View File

@ -104,6 +104,7 @@ import io.grpc.MethodDescriptor.MethodType;
import io.grpc.NameResolver; import io.grpc.NameResolver;
import io.grpc.NameResolver.ConfigOrError; import io.grpc.NameResolver.ConfigOrError;
import io.grpc.NameResolver.ResolutionResult; import io.grpc.NameResolver.ResolutionResult;
import io.grpc.NameResolverProvider;
import io.grpc.NameResolverRegistry; import io.grpc.NameResolverRegistry;
import io.grpc.ProxiedSocketAddress; import io.grpc.ProxiedSocketAddress;
import io.grpc.ProxyDetector; import io.grpc.ProxyDetector;
@ -112,6 +113,7 @@ import io.grpc.ServerMethodDefinition;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.Status.Code; import io.grpc.Status.Code;
import io.grpc.StringMarshaller; import io.grpc.StringMarshaller;
import io.grpc.SynchronizationContext;
import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; import io.grpc.internal.ClientTransportFactory.ClientTransportOptions;
import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.internal.InternalSubchannel.TransportLogger; import io.grpc.internal.InternalSubchannel.TransportLogger;
@ -188,6 +190,15 @@ public class ManagedChannelImplTest {
.setUserAgent(USER_AGENT); .setUserAgent(USER_AGENT);
private static final String TARGET = "fake://" + SERVICE_NAME; private static final String TARGET = "fake://" + SERVICE_NAME;
private static final String MOCK_POLICY_NAME = "mock_lb"; private static final String MOCK_POLICY_NAME = "mock_lb";
private static final NameResolver.Args NAMERESOLVER_ARGS = NameResolver.Args.newBuilder()
.setDefaultPort(447)
.setProxyDetector(mock(ProxyDetector.class))
.setSynchronizationContext(
new SynchronizationContext(mock(Thread.UncaughtExceptionHandler.class)))
.setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class))
.setScheduledExecutorService(new FakeClock().getScheduledExecutorService())
.build();
private URI expectedUri; private URI expectedUri;
private final SocketAddress socketAddress = private final SocketAddress socketAddress =
new SocketAddress() { new SocketAddress() {
@ -4306,6 +4317,80 @@ public class ManagedChannelImplTest {
assertEquals(1, terminationCallbackCalled.get()); assertEquals(1, terminationCallbackCalled.get());
} }
@Test
public void validAuthorityTarget_overrideAuthority() throws Exception {
String overrideAuthority = "override.authority";
String serviceAuthority = "fakeauthority";
NameResolverProvider nameResolverProvider = new NameResolverProvider() {
@Override protected boolean isAvailable() {
return true;
}
@Override protected int priority() {
return 5;
}
@Override public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) {
return new NameResolver() {
@Override public String getServiceAuthority() {
return serviceAuthority;
}
@Override public void start(final Listener2 listener) {}
@Override public void shutdown() {}
};
}
@Override public String getDefaultScheme() {
return "defaultscheme";
}
};
URI targetUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null);
NameResolver nameResolver = ManagedChannelImpl.getNameResolver(
targetUri, null, nameResolverProvider, NAMERESOLVER_ARGS);
assertThat(nameResolver.getServiceAuthority()).isEqualTo(serviceAuthority);
nameResolver = ManagedChannelImpl.getNameResolver(
targetUri, overrideAuthority, nameResolverProvider, NAMERESOLVER_ARGS);
assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority);
}
@Test
public void validTargetNoResolver_throws() {
NameResolverProvider nameResolverProvider = new NameResolverProvider() {
@Override
protected boolean isAvailable() {
return true;
}
@Override
protected int priority() {
return 5;
}
@Override
public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) {
return null;
}
@Override
public String getDefaultScheme() {
return "defaultscheme";
}
};
try {
ManagedChannelImpl.getNameResolver(
URI.create("defaultscheme:///foo.gogoleapis.com:8080"),
null, nameResolverProvider, NAMERESOLVER_ARGS);
fail("Should fail");
} catch (IllegalArgumentException e) {
// expected
}
}
private static final class FakeBackoffPolicyProvider implements BackoffPolicy.Provider { private static final class FakeBackoffPolicyProvider implements BackoffPolicy.Provider {
@Override @Override
public BackoffPolicy get() { public BackoffPolicy get() {

View File

@ -106,6 +106,11 @@ public abstract class ForwardingLoadBalancerHelper extends LoadBalancer.Helper {
return delegate().getAuthority(); return delegate().getAuthority();
} }
@Override
public String getChannelTarget() {
return delegate().getChannelTarget();
}
@Override @Override
public ChannelCredentials getChannelCredentials() { public ChannelCredentials getChannelCredentials() {
return delegate().getChannelCredentials(); return delegate().getChannelCredentials();