From 538db03d56630df042c44c2e81363be924a2be8f Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Fri, 22 Apr 2022 09:10:55 -0700 Subject: [PATCH] api: add support for SocketAddress types in ManagedChannelProvider (#9076) * api: add support for SocketAddress types in ManagedChannelProvider also add support for SocketAddress types in NameResolverProvider Use scheme in target URI to select a NameRseolverProvider and get that provider's supported SocketAddress types. implement selection in ManagedChannelRegistry of appropriate ManagedChannelProvider based on NameResolver's SocketAddress types --- .../java/io/grpc/ManagedChannelProvider.java | 7 + .../java/io/grpc/ManagedChannelRegistry.java | 36 +++ .../java/io/grpc/NameResolverProvider.java | 14 + .../io/grpc/ManagedChannelRegistryTest.java | 261 ++++++++++++++++++ .../internal/DnsNameResolverProvider.java | 9 + .../observability/LoggingChannelProvider.java | 9 + ...GoogleCloudToProdNameResolverProvider.java | 9 + .../SecretGrpclbNameResolverProvider.java | 9 + .../io/grpc/netty/NettyChannelProvider.java | 9 + .../io/grpc/okhttp/OkHttpChannelProvider.java | 9 + .../io/grpc/xds/XdsNameResolverProvider.java | 9 + 11 files changed, 381 insertions(+) diff --git a/api/src/main/java/io/grpc/ManagedChannelProvider.java b/api/src/main/java/io/grpc/ManagedChannelProvider.java index f57340d9ba..42941dfc80 100644 --- a/api/src/main/java/io/grpc/ManagedChannelProvider.java +++ b/api/src/main/java/io/grpc/ManagedChannelProvider.java @@ -17,6 +17,8 @@ package io.grpc; import com.google.common.base.Preconditions; +import java.net.SocketAddress; +import java.util.Collection; /** * Provider of managed channels for transport agnostic consumption. @@ -79,6 +81,11 @@ public abstract class ManagedChannelProvider { return NewChannelBuilderResult.error("ChannelCredentials are unsupported"); } + /** + * Returns the {@link SocketAddress} types this ManagedChannelProvider supports. + */ + protected abstract Collection> getSupportedSocketAddressTypes(); + public static final class NewChannelBuilderResult { private final ManagedChannelBuilder channelBuilder; private final String error; diff --git a/api/src/main/java/io/grpc/ManagedChannelRegistry.java b/api/src/main/java/io/grpc/ManagedChannelRegistry.java index 8eb1cce14a..677856ed8d 100644 --- a/api/src/main/java/io/grpc/ManagedChannelRegistry.java +++ b/api/src/main/java/io/grpc/ManagedChannelRegistry.java @@ -18,7 +18,12 @@ package io.grpc; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import java.net.SocketAddress; +import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedHashSet; @@ -144,6 +149,28 @@ public final class ManagedChannelRegistry { } ManagedChannelBuilder newChannelBuilder(String target, ChannelCredentials creds) { + return newChannelBuilder(NameResolverRegistry.getDefaultRegistry(), target, creds); + } + + @VisibleForTesting + ManagedChannelBuilder newChannelBuilder(NameResolverRegistry nameResolverRegistry, + String target, ChannelCredentials creds) { + NameResolverProvider nameResolverProvider = null; + try { + URI uri = new URI(target); + nameResolverProvider = nameResolverRegistry.providers().get(uri.getScheme()); + } catch (URISyntaxException ignore) { + // bad URI found, just ignore and continue + } + if (nameResolverProvider == null) { + nameResolverProvider = nameResolverRegistry.providers().get( + nameResolverRegistry.asFactory().getDefaultScheme()); + } + Collection> nameResolverSocketAddressTypes + = (nameResolverProvider != null) + ? nameResolverProvider.getProducedSocketAddressTypes() : + Collections.emptySet(); + List providers = providers(); if (providers.isEmpty()) { throw new ProviderNotFoundException("No functional channel service provider found. " @@ -152,6 +179,15 @@ public final class ManagedChannelRegistry { } StringBuilder error = new StringBuilder(); for (ManagedChannelProvider provider : providers()) { + Collection> channelProviderSocketAddressTypes + = provider.getSupportedSocketAddressTypes(); + if (!channelProviderSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) { + error.append("; "); + error.append(provider.getClass().getName()); + error.append(": does not support 1 or more of "); + error.append(Arrays.toString(nameResolverSocketAddressTypes.toArray())); + continue; + } ManagedChannelProvider.NewChannelBuilderResult result = provider.newChannelBuilder(target, creds); if (result.getChannelBuilder() != null) { diff --git a/api/src/main/java/io/grpc/NameResolverProvider.java b/api/src/main/java/io/grpc/NameResolverProvider.java index 2c337cd505..e7cddfc36d 100644 --- a/api/src/main/java/io/grpc/NameResolverProvider.java +++ b/api/src/main/java/io/grpc/NameResolverProvider.java @@ -17,6 +17,10 @@ package io.grpc; import io.grpc.NameResolver.Factory; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; /** * Provider of name resolvers for name agnostic consumption. @@ -62,4 +66,14 @@ public abstract class NameResolverProvider extends NameResolver.Factory { protected String getScheme() { return getDefaultScheme(); } + + /** + * Returns the {@link SocketAddress} types this provider's name-resolver is capable of producing. + * This enables selection of the appropriate {@link ManagedChannelProvider} for a channel. + * + * @return the {@link SocketAddress} types this provider's name-resolver is capable of producing. + */ + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java index 6f25f62057..283c179277 100644 --- a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java +++ b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java @@ -19,6 +19,12 @@ package io.grpc; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import com.google.common.collect.ImmutableSet; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.util.Collection; +import java.util.Collections; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -156,6 +162,256 @@ public class ManagedChannelRegistryTest { } } + @Test + public void newChannelBuilder_usesScheme() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + class SocketAddress2 extends SocketAddress { + } + + nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + }); + nameResolverRegistry.register(new BaseNameResolverProvider(true, 6, "sc2") { + @Override + protected Collection> getProducedSocketAddressTypes() { + fail("Should not be called"); + throw new AssertionError(); + } + }); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + registry.register(new BaseProvider(true, 5) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress2.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + fail("Should not be called"); + throw new AssertionError(); + } + }); + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat( + registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs( + mcb); + } + + @Test + public void newChannelBuilder_unsupportedSocketAddressTypes() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + class SocketAddress2 extends SocketAddress { + } + + nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { + @Override + protected Collection> getProducedSocketAddressTypes() { + return ImmutableSet.of(SocketAddress1.class, SocketAddress2.class); + } + }); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + registry.register(new BaseProvider(true, 5) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress2.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + fail("Should not be called"); + throw new AssertionError(); + } + }); + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + fail("Should not be called"); + throw new AssertionError(); + } + }); + try { + registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds); + fail("expected exception"); + } catch (ManagedChannelRegistry.ProviderNotFoundException ex) { + assertThat(ex).hasMessageThat().contains("does not support 1 or more of"); + assertThat(ex).hasMessageThat().contains("SocketAddress1"); + assertThat(ex).hasMessageThat().contains("SocketAddress2"); + } + } + + @Test + public void newChannelBuilder_emptySet_asDefault() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.emptySet(); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat( + registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs( + mcb); + } + + @Test + public void newChannelBuilder_noSchemeUsesDefaultScheme() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + }); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat(registry.newChannelBuilder(nameResolverRegistry, target, creds)).isSameInstanceAs( + mcb); + } + + @Test + public void newChannelBuilder_badUri() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat( + registry.newChannelBuilder(nameResolverRegistry, ":testing123", creds)).isSameInstanceAs( + mcb); + } + + private static class BaseNameResolverProvider extends NameResolverProvider { + private final boolean isAvailable; + private final int priority; + private final String defaultScheme; + + public BaseNameResolverProvider(boolean isAvailable, int priority, String defaultScheme) { + this.isAvailable = isAvailable; + this.priority = priority; + this.defaultScheme = defaultScheme; + } + + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return null; + } + + @Override + public String getDefaultScheme() { + return defaultScheme; + } + + @Override + protected boolean isAvailable() { + return isAvailable; + } + + @Override + protected int priority() { + return priority; + } + } + private static class BaseProvider extends ManagedChannelProvider { private final boolean isAvailable; private final int priority; @@ -184,5 +440,10 @@ public class ManagedChannelRegistryTest { protected ManagedChannelBuilder builderForTarget(String target) { throw new UnsupportedOperationException(); } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } } diff --git a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java index 1c9290d2fc..8078aa0d4c 100644 --- a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java +++ b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java @@ -21,7 +21,11 @@ import com.google.common.base.Stopwatch; import io.grpc.InternalServiceProviders; import io.grpc.NameResolver; import io.grpc.NameResolverProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.util.Collection; +import java.util.Collections; /** * A provider for {@link DnsNameResolver}. @@ -75,4 +79,9 @@ public final class DnsNameResolverProvider extends NameResolverProvider { public int priority() { return 5; } + + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/LoggingChannelProvider.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/LoggingChannelProvider.java index ffbc24be69..81c3501e1d 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/LoggingChannelProvider.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/LoggingChannelProvider.java @@ -24,6 +24,10 @@ import io.grpc.ManagedChannelBuilder; import io.grpc.ManagedChannelProvider; import io.grpc.ManagedChannelRegistry; import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; /** A channel provider that injects logging interceptor. */ final class LoggingChannelProvider extends ManagedChannelProvider { @@ -90,4 +94,9 @@ final class LoggingChannelProvider extends ManagedChannelProvider { checkNotNull(result.getError(), "Expected error to be set!"); return result; } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java index ac39ab1e62..b431cba4d2 100644 --- a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java @@ -22,7 +22,11 @@ import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; import io.grpc.internal.GrpcUtil; import io.grpc.xds.InternalSharedXdsClientPoolProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.util.Collection; +import java.util.Collections; import java.util.Map; /** @@ -58,6 +62,11 @@ public final class GoogleCloudToProdNameResolverProvider extends NameResolverPro return 4; } + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } + private static final class SharedXdsClientPoolProviderBootstrapSetter implements GoogleCloudToProdNameResolver.BootstrapSetter { @Override diff --git a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java index bc25f28f94..da5b7c3353 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java +++ b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java @@ -22,7 +22,11 @@ import io.grpc.InternalServiceProviders; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; import io.grpc.internal.GrpcUtil; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.util.Collection; +import java.util.Collections; /** * A provider for {@code io.grpc.grpclb.GrpclbNameResolver}. @@ -85,5 +89,10 @@ final class SecretGrpclbNameResolverProvider { // Must be higher than DnsNameResolverProvider#priority. return 6; } + + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } } diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java index bf3df4fa6a..7cc77c150a 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java @@ -19,6 +19,10 @@ package io.grpc.netty; import io.grpc.ChannelCredentials; import io.grpc.Internal; import io.grpc.ManagedChannelProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; /** Provider for {@link NettyChannelBuilder} instances. */ @Internal @@ -52,4 +56,9 @@ public final class NettyChannelProvider extends ManagedChannelProvider { return NewChannelBuilderResult.channelBuilder( new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator)); } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java index 19f99d0502..17a2512a66 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java @@ -20,6 +20,10 @@ import io.grpc.ChannelCredentials; import io.grpc.Internal; import io.grpc.InternalServiceProviders; import io.grpc.ManagedChannelProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; /** * Provider for {@link OkHttpChannelBuilder} instances. @@ -57,4 +61,9 @@ public final class OkHttpChannelProvider extends ManagedChannelProvider { return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder( target, creds, result.callCredentials, result.factory)); } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java index a02e27c37c..0eb51c9128 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java @@ -23,7 +23,11 @@ import io.grpc.Internal; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; import io.grpc.internal.ObjectPool; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.util.Collection; +import java.util.Collections; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nullable; @@ -99,6 +103,11 @@ public final class XdsNameResolverProvider extends NameResolverProvider { return 4; } + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } + interface XdsClientPoolFactory { void setBootstrapOverride(Map bootstrap);