From 4561bb5b804ca242e0dd4047d8d2744c3eb9920f Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Tue, 30 Apr 2024 08:57:55 -0700 Subject: [PATCH] Plumb target to load balancer gRFC A78 has WRR and pick-first include a `grpc.target` label, defined in A66: > `grpc.target` : Canonicalized target URI used when creating gRPC > Channel, e.g. "dns:///pubsub.googleapis.com:443", > "xds:///helloworld-gke:8000". Canonicalized target URI is the form > with the scheme included if the user didn't mention the scheme > (`scheme://[authority]/path`). For channels such as inprocess channels > where a target URI is not available, implementations can synthesize a > target URI. --- api/src/main/java/io/grpc/LoadBalancer.java | 7 ++ .../io/grpc/internal/ManagedChannelImpl.java | 53 ++++++---- ...ManagedChannelImplGetNameResolverTest.java | 96 +++---------------- .../grpc/internal/ManagedChannelImplTest.java | 85 ++++++++++++++++ .../util/ForwardingLoadBalancerHelper.java | 5 + 5 files changed, 145 insertions(+), 101 deletions(-) diff --git a/api/src/main/java/io/grpc/LoadBalancer.java b/api/src/main/java/io/grpc/LoadBalancer.java index 00fcd82b0a..f1bab25d87 100644 --- a/api/src/main/java/io/grpc/LoadBalancer.java +++ b/api/src/main/java/io/grpc/LoadBalancer.java @@ -1185,6 +1185,13 @@ public abstract class LoadBalancer { */ public abstract String getAuthority(); + /** + * Returns the target string of the channel, guaranteed to include its scheme. + */ + public String getChannelTarget() { + throw new UnsupportedOperationException(); + } + /** * Returns the ChannelCredentials used to construct the channel, without bearer tokens. * diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 241fb216c2..2ba9e34999 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -166,6 +166,8 @@ final class ManagedChannelImpl extends ManagedChannel implements @Nullable private final String authorityOverride; private final NameResolverRegistry nameResolverRegistry; + private final URI targetUri; + private final NameResolverProvider nameResolverProvider; private final NameResolver.Args nameResolverArgs; private final AutoConfiguredLoadBalancerFactory loadBalancerFactory; private final ClientTransportFactory originalTransportFactory; @@ -383,8 +385,7 @@ final class ManagedChannelImpl extends ManagedChannel implements nameResolverStarted = false; if (channelIsActive) { nameResolver = getNameResolver( - target, authorityOverride, nameResolverRegistry, nameResolverArgs, - transportFactory.getSupportedSocketAddressTypes()); + targetUri, authorityOverride, nameResolverProvider, nameResolverArgs); } else { nameResolver = null; } @@ -621,6 +622,10 @@ final class ManagedChannelImpl extends ManagedChannel implements this.retryEnabled = builder.retryEnabled; this.loadBalancerFactory = new AutoConfiguredLoadBalancerFactory(builder.defaultLbPolicy); this.nameResolverRegistry = builder.nameResolverRegistry; + ResolvedNameResolver resolvedResolver = getNameResolverProvider( + target, nameResolverRegistry, transportFactory.getSupportedSocketAddressTypes()); + this.targetUri = resolvedResolver.targetUri; + this.nameResolverProvider = resolvedResolver.provider; ScParser serviceConfigParser = new ScParser( retryEnabled, @@ -640,8 +645,7 @@ final class ManagedChannelImpl extends ManagedChannel implements .setOverrideAuthority(this.authorityOverride) .build(); this.nameResolver = getNameResolver( - target, authorityOverride, nameResolverRegistry, nameResolverArgs, - transportFactory.getSupportedSocketAddressTypes()); + targetUri, authorityOverride, nameResolverProvider, nameResolverArgs); this.balancerRpcExecutorPool = checkNotNull(balancerRpcExecutorPool, "balancerRpcExecutorPool"); this.balancerRpcExecutorHolder = new ExecutorHolder(balancerRpcExecutorPool); this.delayedTransport = new DelayedClientTransport(this.executor, this.syncContext); @@ -713,8 +717,20 @@ final class ManagedChannelImpl extends ManagedChannel implements } } - private static NameResolver getNameResolver( - String target, NameResolverRegistry nameResolverRegistry, NameResolver.Args nameResolverArgs, + @VisibleForTesting + static class ResolvedNameResolver { + public final URI targetUri; + public final NameResolverProvider provider; + + public ResolvedNameResolver(URI targetUri, NameResolverProvider provider) { + this.targetUri = checkNotNull(targetUri, "targetUri"); + this.provider = checkNotNull(provider, "provider"); + } + } + + @VisibleForTesting + static ResolvedNameResolver getNameResolverProvider( + String target, NameResolverRegistry nameResolverRegistry, Collection> channelTransportSocketAddressTypes) { // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending // "dns:///". @@ -761,23 +777,17 @@ final class ManagedChannelImpl extends ManagedChannel implements } } - NameResolver resolver = provider.newNameResolver(targetUri, nameResolverArgs); - if (resolver != null) { - return resolver; - } - - throw new IllegalArgumentException(String.format( - "cannot create a NameResolver for %s%s", - target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors + ")" : "")); + return new ResolvedNameResolver(targetUri, provider); } @VisibleForTesting static NameResolver getNameResolver( - String target, @Nullable final String overrideAuthority, - NameResolverRegistry nameResolverRegistry, NameResolver.Args nameResolverArgs, - Collection> channelTransportSocketAddressTypes) { - NameResolver resolver = getNameResolver(target, nameResolverRegistry, nameResolverArgs, - channelTransportSocketAddressTypes); + URI targetUri, @Nullable final String overrideAuthority, + NameResolverProvider provider, NameResolver.Args nameResolverArgs) { + NameResolver resolver = provider.newNameResolver(targetUri, nameResolverArgs); + if (resolver == null) { + throw new IllegalArgumentException("cannot create a NameResolver for " + targetUri); + } // We wrap the name resolver in a RetryingNameResolver to give it the ability to retry failures. // TODO: After a transition period, all NameResolver implementations that need retry should use @@ -1703,6 +1713,11 @@ final class ManagedChannelImpl extends ManagedChannel implements return ManagedChannelImpl.this.authority(); } + @Override + public String getChannelTarget() { + return targetUri.toString(); + } + @Override public SynchronizationContext getSynchronizationContext() { return syncContext; diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java index 452e071912..d930045a13 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java @@ -17,21 +17,12 @@ package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; -import static org.mockito.Mockito.mock; -import io.grpc.ChannelLogger; import io.grpc.NameResolver; -import io.grpc.NameResolver.Args; -import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; -import io.grpc.ProxyDetector; -import io.grpc.SynchronizationContext; import io.grpc.inprocess.InProcessSocketAddress; -import java.lang.Thread.UncaughtExceptionHandler; import java.net.InetSocketAddress; import java.net.URI; import java.util.Collections; @@ -39,18 +30,9 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for ManagedChannelImpl#getNameResolver(). */ +/** Unit tests for ManagedChannelImpl#getNameResolverProvider(). */ @RunWith(JUnit4.class) public class ManagedChannelImplGetNameResolverTest { - private static final NameResolver.Args NAMERESOLVER_ARGS = NameResolver.Args.newBuilder() - .setDefaultPort(447) - .setProxyDetector(mock(ProxyDetector.class)) - .setSynchronizationContext(new SynchronizationContext(mock(UncaughtExceptionHandler.class))) - .setServiceConfigParser(mock(ServiceConfigParser.class)) - .setChannelLogger(mock(ChannelLogger.class)) - .setScheduledExecutorService(new FakeClock().getScheduledExecutorService()) - .build(); - @Test public void invalidUriTarget() { testInvalidTarget("defaultscheme:///[invalid]"); @@ -68,18 +50,6 @@ public class ManagedChannelImplGetNameResolverTest { new URI("defaultscheme", "", "/foo.googleapis.com:8080", null)); } - @Test - public void validAuthorityTarget_overrideAuthority() throws Exception { - String target = "foo.googleapis.com:8080"; - String overrideAuthority = "override.authority"; - URI expectedUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null); - NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme()); - NameResolver nameResolver = ManagedChannelImpl.getNameResolver( - target, overrideAuthority, nameResolverRegistry, NAMERESOLVER_ARGS, - Collections.singleton(InetSocketAddress.class)); - assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority); - } - @Test public void validUriTarget() throws Exception { testValidTarget("scheme:///foo.googleapis.com:8080", "scheme:///foo.googleapis.com:8080", @@ -121,47 +91,12 @@ public class ManagedChannelImplGetNameResolverTest { new URI("defaultscheme", "", "//target", null)); } - @Test - public void validTargetNoResolver() { - NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); - NameResolverProvider nameResolverProvider = new NameResolverProvider() { - @Override - protected boolean isAvailable() { - return true; - } - - @Override - protected int priority() { - return 5; - } - - @Override - public NameResolver newNameResolver(URI targetUri, Args args) { - return null; - } - - @Override - public String getDefaultScheme() { - return "defaultscheme"; - } - }; - nameResolverRegistry.register(nameResolverProvider); - try { - ManagedChannelImpl.getNameResolver( - "foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS, - Collections.singleton(InetSocketAddress.class)); - fail("Should fail"); - } catch (IllegalArgumentException e) { - // expected - } - } - @Test public void validTargetNoProvider() { NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); try { - ManagedChannelImpl.getNameResolver( - "foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS, + ManagedChannelImpl.getNameResolverProvider( + "foo.googleapis.com:8080", nameResolverRegistry, Collections.singleton(InetSocketAddress.class)); fail("Should fail"); } catch (IllegalArgumentException e) { @@ -173,8 +108,8 @@ public class ManagedChannelImplGetNameResolverTest { public void validTargetProviderAddrTypesNotSupported() { NameResolverRegistry nameResolverRegistry = getTestRegistry("testscheme"); try { - ManagedChannelImpl.getNameResolver( - "testscheme:///foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS, + ManagedChannelImpl.getNameResolverProvider( + "testscheme:///foo.googleapis.com:8080", nameResolverRegistry, Collections.singleton(InProcessSocketAddress.class)); fail("Should fail"); } catch (IllegalArgumentException e) { @@ -184,26 +119,23 @@ public class ManagedChannelImplGetNameResolverTest { } } - private void testValidTarget(String target, String expectedUriString, URI expectedUri) { NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme()); - FakeNameResolver nameResolver - = (FakeNameResolver) ((RetryingNameResolver) ManagedChannelImpl.getNameResolver( - target, null, nameResolverRegistry, NAMERESOLVER_ARGS, - Collections.singleton(InetSocketAddress.class))).getRetriedNameResolver(); - assertNotNull(nameResolver); - assertEquals(expectedUri, nameResolver.uri); - assertEquals(expectedUriString, nameResolver.uri.toString()); + ManagedChannelImpl.ResolvedNameResolver resolved = ManagedChannelImpl.getNameResolverProvider( + target, nameResolverRegistry, Collections.singleton(InetSocketAddress.class)); + assertThat(resolved.provider).isInstanceOf(FakeNameResolverProvider.class); + assertThat(resolved.targetUri).isEqualTo(expectedUri); + assertThat(resolved.targetUri.toString()).isEqualTo(expectedUriString); } private void testInvalidTarget(String target) { NameResolverRegistry nameResolverRegistry = getTestRegistry("dns"); try { - FakeNameResolver nameResolver = (FakeNameResolver) ManagedChannelImpl.getNameResolver( - target, null, nameResolverRegistry, NAMERESOLVER_ARGS, - Collections.singleton(InetSocketAddress.class)); - fail("Should have failed, but got resolver with " + nameResolver.uri); + ManagedChannelImpl.ResolvedNameResolver resolved = ManagedChannelImpl.getNameResolverProvider( + target, nameResolverRegistry, Collections.singleton(InetSocketAddress.class)); + FakeNameResolverProvider nameResolverProvider = (FakeNameResolverProvider) resolved.provider; + fail("Should have failed, but got resolver provider " + nameResolverProvider); } catch (IllegalArgumentException e) { // expected } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index dc2162bfce..9f7c043d72 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -104,6 +104,7 @@ import io.grpc.MethodDescriptor.MethodType; import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; import io.grpc.NameResolver.ResolutionResult; +import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; import io.grpc.ProxiedSocketAddress; import io.grpc.ProxyDetector; @@ -112,6 +113,7 @@ import io.grpc.ServerMethodDefinition; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StringMarshaller; +import io.grpc.SynchronizationContext; import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.internal.InternalSubchannel.TransportLogger; @@ -188,6 +190,15 @@ public class ManagedChannelImplTest { .setUserAgent(USER_AGENT); private static final String TARGET = "fake://" + SERVICE_NAME; private static final String MOCK_POLICY_NAME = "mock_lb"; + private static final NameResolver.Args NAMERESOLVER_ARGS = NameResolver.Args.newBuilder() + .setDefaultPort(447) + .setProxyDetector(mock(ProxyDetector.class)) + .setSynchronizationContext( + new SynchronizationContext(mock(Thread.UncaughtExceptionHandler.class))) + .setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class)) + .setScheduledExecutorService(new FakeClock().getScheduledExecutorService()) + .build(); + private URI expectedUri; private final SocketAddress socketAddress = new SocketAddress() { @@ -4306,6 +4317,80 @@ public class ManagedChannelImplTest { assertEquals(1, terminationCallbackCalled.get()); } + @Test + public void validAuthorityTarget_overrideAuthority() throws Exception { + String overrideAuthority = "override.authority"; + String serviceAuthority = "fakeauthority"; + NameResolverProvider nameResolverProvider = new NameResolverProvider() { + @Override protected boolean isAvailable() { + return true; + } + + @Override protected int priority() { + return 5; + } + + @Override public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return new NameResolver() { + @Override public String getServiceAuthority() { + return serviceAuthority; + } + + @Override public void start(final Listener2 listener) {} + + @Override public void shutdown() {} + }; + } + + @Override public String getDefaultScheme() { + return "defaultscheme"; + } + }; + + URI targetUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null); + NameResolver nameResolver = ManagedChannelImpl.getNameResolver( + targetUri, null, nameResolverProvider, NAMERESOLVER_ARGS); + assertThat(nameResolver.getServiceAuthority()).isEqualTo(serviceAuthority); + + nameResolver = ManagedChannelImpl.getNameResolver( + targetUri, overrideAuthority, nameResolverProvider, NAMERESOLVER_ARGS); + assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority); + } + + @Test + public void validTargetNoResolver_throws() { + NameResolverProvider nameResolverProvider = new NameResolverProvider() { + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return null; + } + + @Override + public String getDefaultScheme() { + return "defaultscheme"; + } + }; + try { + ManagedChannelImpl.getNameResolver( + URI.create("defaultscheme:///foo.gogoleapis.com:8080"), + null, nameResolverProvider, NAMERESOLVER_ARGS); + fail("Should fail"); + } catch (IllegalArgumentException e) { + // expected + } + } + + private static final class FakeBackoffPolicyProvider implements BackoffPolicy.Provider { @Override public BackoffPolicy get() { diff --git a/util/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java b/util/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java index b3b1d36159..338903fc5f 100644 --- a/util/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java +++ b/util/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java @@ -106,6 +106,11 @@ public abstract class ForwardingLoadBalancerHelper extends LoadBalancer.Helper { return delegate().getAuthority(); } + @Override + public String getChannelTarget() { + return delegate().getChannelTarget(); + } + @Override public ChannelCredentials getChannelCredentials() { return delegate().getChannelCredentials();