api: add support for SocketAddress types in ManagedChannelProvider (#9076)

* api: add support for SocketAddress types in ManagedChannelProvider
also add support for SocketAddress types in NameResolverProvider
Use scheme in target URI to select a NameRseolverProvider and get
that provider's supported SocketAddress types.
implement selection in ManagedChannelRegistry of appropriate
ManagedChannelProvider based on NameResolver's SocketAddress types
This commit is contained in:
sanjaypujare 2022-04-22 09:10:55 -07:00 committed by GitHub
parent 8e65700edc
commit 538db03d56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 381 additions and 0 deletions

View File

@ -17,6 +17,8 @@
package io.grpc; package io.grpc;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import java.net.SocketAddress;
import java.util.Collection;
/** /**
* Provider of managed channels for transport agnostic consumption. * Provider of managed channels for transport agnostic consumption.
@ -79,6 +81,11 @@ public abstract class ManagedChannelProvider {
return NewChannelBuilderResult.error("ChannelCredentials are unsupported"); return NewChannelBuilderResult.error("ChannelCredentials are unsupported");
} }
/**
* Returns the {@link SocketAddress} types this ManagedChannelProvider supports.
*/
protected abstract Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes();
public static final class NewChannelBuilderResult { public static final class NewChannelBuilderResult {
private final ManagedChannelBuilder<?> channelBuilder; private final ManagedChannelBuilder<?> channelBuilder;
private final String error; private final String error;

View File

@ -18,7 +18,12 @@ package io.grpc;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import java.net.SocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
@ -144,6 +149,28 @@ public final class ManagedChannelRegistry {
} }
ManagedChannelBuilder<?> newChannelBuilder(String target, ChannelCredentials creds) { ManagedChannelBuilder<?> newChannelBuilder(String target, ChannelCredentials creds) {
return newChannelBuilder(NameResolverRegistry.getDefaultRegistry(), target, creds);
}
@VisibleForTesting
ManagedChannelBuilder<?> newChannelBuilder(NameResolverRegistry nameResolverRegistry,
String target, ChannelCredentials creds) {
NameResolverProvider nameResolverProvider = null;
try {
URI uri = new URI(target);
nameResolverProvider = nameResolverRegistry.providers().get(uri.getScheme());
} catch (URISyntaxException ignore) {
// bad URI found, just ignore and continue
}
if (nameResolverProvider == null) {
nameResolverProvider = nameResolverRegistry.providers().get(
nameResolverRegistry.asFactory().getDefaultScheme());
}
Collection<Class<? extends SocketAddress>> nameResolverSocketAddressTypes
= (nameResolverProvider != null)
? nameResolverProvider.getProducedSocketAddressTypes() :
Collections.emptySet();
List<ManagedChannelProvider> providers = providers(); List<ManagedChannelProvider> providers = providers();
if (providers.isEmpty()) { if (providers.isEmpty()) {
throw new ProviderNotFoundException("No functional channel service provider found. " throw new ProviderNotFoundException("No functional channel service provider found. "
@ -152,6 +179,15 @@ public final class ManagedChannelRegistry {
} }
StringBuilder error = new StringBuilder(); StringBuilder error = new StringBuilder();
for (ManagedChannelProvider provider : providers()) { for (ManagedChannelProvider provider : providers()) {
Collection<Class<? extends SocketAddress>> channelProviderSocketAddressTypes
= provider.getSupportedSocketAddressTypes();
if (!channelProviderSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) {
error.append("; ");
error.append(provider.getClass().getName());
error.append(": does not support 1 or more of ");
error.append(Arrays.toString(nameResolverSocketAddressTypes.toArray()));
continue;
}
ManagedChannelProvider.NewChannelBuilderResult result ManagedChannelProvider.NewChannelBuilderResult result
= provider.newChannelBuilder(target, creds); = provider.newChannelBuilder(target, creds);
if (result.getChannelBuilder() != null) { if (result.getChannelBuilder() != null) {

View File

@ -17,6 +17,10 @@
package io.grpc; package io.grpc;
import io.grpc.NameResolver.Factory; import io.grpc.NameResolver.Factory;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;
/** /**
* Provider of name resolvers for name agnostic consumption. * Provider of name resolvers for name agnostic consumption.
@ -62,4 +66,14 @@ public abstract class NameResolverProvider extends NameResolver.Factory {
protected String getScheme() { protected String getScheme() {
return getDefaultScheme(); return getDefaultScheme();
} }
/**
* Returns the {@link SocketAddress} types this provider's name-resolver is capable of producing.
* This enables selection of the appropriate {@link ManagedChannelProvider} for a channel.
*
* @return the {@link SocketAddress} types this provider's name-resolver is capable of producing.
*/
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
} }

View File

@ -19,6 +19,12 @@ package io.grpc;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import com.google.common.collect.ImmutableSet;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
@ -156,6 +162,256 @@ public class ManagedChannelRegistryTest {
} }
} }
@Test
public void newChannelBuilder_usesScheme() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}
class SocketAddress2 extends SocketAddress {
}
nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
});
nameResolverRegistry.register(new BaseNameResolverProvider(true, 6, "sc2") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
fail("Should not be called");
throw new AssertionError();
}
});
ManagedChannelRegistry registry = new ManagedChannelRegistry();
registry.register(new BaseProvider(true, 5) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress2.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}
final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs(
mcb);
}
@Test
public void newChannelBuilder_unsupportedSocketAddressTypes() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}
class SocketAddress2 extends SocketAddress {
}
nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return ImmutableSet.of(SocketAddress1.class, SocketAddress2.class);
}
});
ManagedChannelRegistry registry = new ManagedChannelRegistry();
registry.register(new BaseProvider(true, 5) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress2.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
try {
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds);
fail("expected exception");
} catch (ManagedChannelRegistry.ProviderNotFoundException ex) {
assertThat(ex).hasMessageThat().contains("does not support 1 or more of");
assertThat(ex).hasMessageThat().contains("SocketAddress1");
assertThat(ex).hasMessageThat().contains("SocketAddress2");
}
}
@Test
public void newChannelBuilder_emptySet_asDefault() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
ManagedChannelRegistry registry = new ManagedChannelRegistry();
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}
final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.emptySet();
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs(
mcb);
}
@Test
public void newChannelBuilder_noSchemeUsesDefaultScheme() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}
nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
});
ManagedChannelRegistry registry = new ManagedChannelRegistry();
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}
final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(registry.newChannelBuilder(nameResolverRegistry, target, creds)).isSameInstanceAs(
mcb);
}
@Test
public void newChannelBuilder_badUri() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}
ManagedChannelRegistry registry = new ManagedChannelRegistry();
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}
final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, ":testing123", creds)).isSameInstanceAs(
mcb);
}
private static class BaseNameResolverProvider extends NameResolverProvider {
private final boolean isAvailable;
private final int priority;
private final String defaultScheme;
public BaseNameResolverProvider(boolean isAvailable, int priority, String defaultScheme) {
this.isAvailable = isAvailable;
this.priority = priority;
this.defaultScheme = defaultScheme;
}
@Override
public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) {
return null;
}
@Override
public String getDefaultScheme() {
return defaultScheme;
}
@Override
protected boolean isAvailable() {
return isAvailable;
}
@Override
protected int priority() {
return priority;
}
}
private static class BaseProvider extends ManagedChannelProvider { private static class BaseProvider extends ManagedChannelProvider {
private final boolean isAvailable; private final boolean isAvailable;
private final int priority; private final int priority;
@ -184,5 +440,10 @@ public class ManagedChannelRegistryTest {
protected ManagedChannelBuilder<?> builderForTarget(String target) { protected ManagedChannelBuilder<?> builderForTarget(String target) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
} }
} }

View File

@ -21,7 +21,11 @@ import com.google.common.base.Stopwatch;
import io.grpc.InternalServiceProviders; import io.grpc.InternalServiceProviders;
import io.grpc.NameResolver; import io.grpc.NameResolver;
import io.grpc.NameResolverProvider; import io.grpc.NameResolverProvider;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI; import java.net.URI;
import java.util.Collection;
import java.util.Collections;
/** /**
* A provider for {@link DnsNameResolver}. * A provider for {@link DnsNameResolver}.
@ -75,4 +79,9 @@ public final class DnsNameResolverProvider extends NameResolverProvider {
public int priority() { public int priority() {
return 5; return 5;
} }
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
} }

View File

@ -24,6 +24,10 @@ import io.grpc.ManagedChannelBuilder;
import io.grpc.ManagedChannelProvider; import io.grpc.ManagedChannelProvider;
import io.grpc.ManagedChannelRegistry; import io.grpc.ManagedChannelRegistry;
import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor; import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;
/** A channel provider that injects logging interceptor. */ /** A channel provider that injects logging interceptor. */
final class LoggingChannelProvider extends ManagedChannelProvider { final class LoggingChannelProvider extends ManagedChannelProvider {
@ -90,4 +94,9 @@ final class LoggingChannelProvider extends ManagedChannelProvider {
checkNotNull(result.getError(), "Expected error to be set!"); checkNotNull(result.getError(), "Expected error to be set!");
return result; return result;
} }
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
} }

View File

@ -22,7 +22,11 @@ import io.grpc.NameResolver.Args;
import io.grpc.NameResolverProvider; import io.grpc.NameResolverProvider;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.xds.InternalSharedXdsClientPoolProvider; import io.grpc.xds.InternalSharedXdsClientPoolProvider;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI; import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import java.util.Map; import java.util.Map;
/** /**
@ -58,6 +62,11 @@ public final class GoogleCloudToProdNameResolverProvider extends NameResolverPro
return 4; return 4;
} }
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
private static final class SharedXdsClientPoolProviderBootstrapSetter private static final class SharedXdsClientPoolProviderBootstrapSetter
implements GoogleCloudToProdNameResolver.BootstrapSetter { implements GoogleCloudToProdNameResolver.BootstrapSetter {
@Override @Override

View File

@ -22,7 +22,11 @@ import io.grpc.InternalServiceProviders;
import io.grpc.NameResolver.Args; import io.grpc.NameResolver.Args;
import io.grpc.NameResolverProvider; import io.grpc.NameResolverProvider;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI; import java.net.URI;
import java.util.Collection;
import java.util.Collections;
/** /**
* A provider for {@code io.grpc.grpclb.GrpclbNameResolver}. * A provider for {@code io.grpc.grpclb.GrpclbNameResolver}.
@ -85,5 +89,10 @@ final class SecretGrpclbNameResolverProvider {
// Must be higher than DnsNameResolverProvider#priority. // Must be higher than DnsNameResolverProvider#priority.
return 6; return 6;
} }
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
} }
} }

View File

@ -19,6 +19,10 @@ package io.grpc.netty;
import io.grpc.ChannelCredentials; import io.grpc.ChannelCredentials;
import io.grpc.Internal; import io.grpc.Internal;
import io.grpc.ManagedChannelProvider; import io.grpc.ManagedChannelProvider;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;
/** Provider for {@link NettyChannelBuilder} instances. */ /** Provider for {@link NettyChannelBuilder} instances. */
@Internal @Internal
@ -52,4 +56,9 @@ public final class NettyChannelProvider extends ManagedChannelProvider {
return NewChannelBuilderResult.channelBuilder( return NewChannelBuilderResult.channelBuilder(
new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator)); new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator));
} }
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
} }

View File

@ -20,6 +20,10 @@ import io.grpc.ChannelCredentials;
import io.grpc.Internal; import io.grpc.Internal;
import io.grpc.InternalServiceProviders; import io.grpc.InternalServiceProviders;
import io.grpc.ManagedChannelProvider; import io.grpc.ManagedChannelProvider;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;
/** /**
* Provider for {@link OkHttpChannelBuilder} instances. * Provider for {@link OkHttpChannelBuilder} instances.
@ -57,4 +61,9 @@ public final class OkHttpChannelProvider extends ManagedChannelProvider {
return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder( return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder(
target, creds, result.callCredentials, result.factory)); target, creds, result.callCredentials, result.factory));
} }
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
} }

View File

@ -23,7 +23,11 @@ import io.grpc.Internal;
import io.grpc.NameResolver.Args; import io.grpc.NameResolver.Args;
import io.grpc.NameResolverProvider; import io.grpc.NameResolverProvider;
import io.grpc.internal.ObjectPool; import io.grpc.internal.ObjectPool;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI; import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -99,6 +103,11 @@ public final class XdsNameResolverProvider extends NameResolverProvider {
return 4; return 4;
} }
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
interface XdsClientPoolFactory { interface XdsClientPoolFactory {
void setBootstrapOverride(Map<String, ?> bootstrap); void setBootstrapOverride(Map<String, ?> bootstrap);