api: add support for SocketAddress types in ManagedChannelProvider (#9076)

* api: add support for SocketAddress types in ManagedChannelProvider
also add support for SocketAddress types in NameResolverProvider
Use scheme in target URI to select a NameRseolverProvider and get
that provider's supported SocketAddress types.
implement selection in ManagedChannelRegistry of appropriate
ManagedChannelProvider based on NameResolver's SocketAddress types
This commit is contained in:
sanjaypujare 2022-04-22 09:10:55 -07:00 committed by GitHub
parent 8e65700edc
commit 538db03d56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 381 additions and 0 deletions

View File

@ -17,6 +17,8 @@
package io.grpc;
import com.google.common.base.Preconditions;
import java.net.SocketAddress;
import java.util.Collection;
/**
* Provider of managed channels for transport agnostic consumption.
@ -79,6 +81,11 @@ public abstract class ManagedChannelProvider {
return NewChannelBuilderResult.error("ChannelCredentials are unsupported");
}
/**
* Returns the {@link SocketAddress} types this ManagedChannelProvider supports.
*/
protected abstract Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes();
public static final class NewChannelBuilderResult {
private final ManagedChannelBuilder<?> channelBuilder;
private final String error;

View File

@ -18,7 +18,12 @@ package io.grpc;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.net.SocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashSet;
@ -144,6 +149,28 @@ public final class ManagedChannelRegistry {
}
ManagedChannelBuilder<?> newChannelBuilder(String target, ChannelCredentials creds) {
return newChannelBuilder(NameResolverRegistry.getDefaultRegistry(), target, creds);
}
@VisibleForTesting
ManagedChannelBuilder<?> newChannelBuilder(NameResolverRegistry nameResolverRegistry,
String target, ChannelCredentials creds) {
NameResolverProvider nameResolverProvider = null;
try {
URI uri = new URI(target);
nameResolverProvider = nameResolverRegistry.providers().get(uri.getScheme());
} catch (URISyntaxException ignore) {
// bad URI found, just ignore and continue
}
if (nameResolverProvider == null) {
nameResolverProvider = nameResolverRegistry.providers().get(
nameResolverRegistry.asFactory().getDefaultScheme());
}
Collection<Class<? extends SocketAddress>> nameResolverSocketAddressTypes
= (nameResolverProvider != null)
? nameResolverProvider.getProducedSocketAddressTypes() :
Collections.emptySet();
List<ManagedChannelProvider> providers = providers();
if (providers.isEmpty()) {
throw new ProviderNotFoundException("No functional channel service provider found. "
@ -152,6 +179,15 @@ public final class ManagedChannelRegistry {
}
StringBuilder error = new StringBuilder();
for (ManagedChannelProvider provider : providers()) {
Collection<Class<? extends SocketAddress>> channelProviderSocketAddressTypes
= provider.getSupportedSocketAddressTypes();
if (!channelProviderSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) {
error.append("; ");
error.append(provider.getClass().getName());
error.append(": does not support 1 or more of ");
error.append(Arrays.toString(nameResolverSocketAddressTypes.toArray()));
continue;
}
ManagedChannelProvider.NewChannelBuilderResult result
= provider.newChannelBuilder(target, creds);
if (result.getChannelBuilder() != null) {

View File

@ -17,6 +17,10 @@
package io.grpc;
import io.grpc.NameResolver.Factory;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;
/**
* Provider of name resolvers for name agnostic consumption.
@ -62,4 +66,14 @@ public abstract class NameResolverProvider extends NameResolver.Factory {
protected String getScheme() {
return getDefaultScheme();
}
/**
* Returns the {@link SocketAddress} types this provider's name-resolver is capable of producing.
* This enables selection of the appropriate {@link ManagedChannelProvider} for a channel.
*
* @return the {@link SocketAddress} types this provider's name-resolver is capable of producing.
*/
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
}

View File

@ -19,6 +19,12 @@ package io.grpc;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import com.google.common.collect.ImmutableSet;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@ -156,6 +162,256 @@ public class ManagedChannelRegistryTest {
}
}
@Test
public void newChannelBuilder_usesScheme() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}
class SocketAddress2 extends SocketAddress {
}
nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
});
nameResolverRegistry.register(new BaseNameResolverProvider(true, 6, "sc2") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
fail("Should not be called");
throw new AssertionError();
}
});
ManagedChannelRegistry registry = new ManagedChannelRegistry();
registry.register(new BaseProvider(true, 5) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress2.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}
final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs(
mcb);
}
@Test
public void newChannelBuilder_unsupportedSocketAddressTypes() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}
class SocketAddress2 extends SocketAddress {
}
nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return ImmutableSet.of(SocketAddress1.class, SocketAddress2.class);
}
});
ManagedChannelRegistry registry = new ManagedChannelRegistry();
registry.register(new BaseProvider(true, 5) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress2.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
try {
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds);
fail("expected exception");
} catch (ManagedChannelRegistry.ProviderNotFoundException ex) {
assertThat(ex).hasMessageThat().contains("does not support 1 or more of");
assertThat(ex).hasMessageThat().contains("SocketAddress1");
assertThat(ex).hasMessageThat().contains("SocketAddress2");
}
}
@Test
public void newChannelBuilder_emptySet_asDefault() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
ManagedChannelRegistry registry = new ManagedChannelRegistry();
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}
final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.emptySet();
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs(
mcb);
}
@Test
public void newChannelBuilder_noSchemeUsesDefaultScheme() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}
nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
});
ManagedChannelRegistry registry = new ManagedChannelRegistry();
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}
final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(registry.newChannelBuilder(nameResolverRegistry, target, creds)).isSameInstanceAs(
mcb);
}
@Test
public void newChannelBuilder_badUri() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}
ManagedChannelRegistry registry = new ManagedChannelRegistry();
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}
final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, ":testing123", creds)).isSameInstanceAs(
mcb);
}
private static class BaseNameResolverProvider extends NameResolverProvider {
private final boolean isAvailable;
private final int priority;
private final String defaultScheme;
public BaseNameResolverProvider(boolean isAvailable, int priority, String defaultScheme) {
this.isAvailable = isAvailable;
this.priority = priority;
this.defaultScheme = defaultScheme;
}
@Override
public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) {
return null;
}
@Override
public String getDefaultScheme() {
return defaultScheme;
}
@Override
protected boolean isAvailable() {
return isAvailable;
}
@Override
protected int priority() {
return priority;
}
}
private static class BaseProvider extends ManagedChannelProvider {
private final boolean isAvailable;
private final int priority;
@ -184,5 +440,10 @@ public class ManagedChannelRegistryTest {
protected ManagedChannelBuilder<?> builderForTarget(String target) {
throw new UnsupportedOperationException();
}
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
}
}

View File

@ -21,7 +21,11 @@ import com.google.common.base.Stopwatch;
import io.grpc.InternalServiceProviders;
import io.grpc.NameResolver;
import io.grpc.NameResolverProvider;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
/**
* A provider for {@link DnsNameResolver}.
@ -75,4 +79,9 @@ public final class DnsNameResolverProvider extends NameResolverProvider {
public int priority() {
return 5;
}
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
}

View File

@ -24,6 +24,10 @@ import io.grpc.ManagedChannelBuilder;
import io.grpc.ManagedChannelProvider;
import io.grpc.ManagedChannelRegistry;
import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;
/** A channel provider that injects logging interceptor. */
final class LoggingChannelProvider extends ManagedChannelProvider {
@ -90,4 +94,9 @@ final class LoggingChannelProvider extends ManagedChannelProvider {
checkNotNull(result.getError(), "Expected error to be set!");
return result;
}
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
}

View File

@ -22,7 +22,11 @@ import io.grpc.NameResolver.Args;
import io.grpc.NameResolverProvider;
import io.grpc.internal.GrpcUtil;
import io.grpc.xds.InternalSharedXdsClientPoolProvider;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
/**
@ -58,6 +62,11 @@ public final class GoogleCloudToProdNameResolverProvider extends NameResolverPro
return 4;
}
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
private static final class SharedXdsClientPoolProviderBootstrapSetter
implements GoogleCloudToProdNameResolver.BootstrapSetter {
@Override

View File

@ -22,7 +22,11 @@ import io.grpc.InternalServiceProviders;
import io.grpc.NameResolver.Args;
import io.grpc.NameResolverProvider;
import io.grpc.internal.GrpcUtil;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
/**
* A provider for {@code io.grpc.grpclb.GrpclbNameResolver}.
@ -85,5 +89,10 @@ final class SecretGrpclbNameResolverProvider {
// Must be higher than DnsNameResolverProvider#priority.
return 6;
}
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
}
}

View File

@ -19,6 +19,10 @@ package io.grpc.netty;
import io.grpc.ChannelCredentials;
import io.grpc.Internal;
import io.grpc.ManagedChannelProvider;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;
/** Provider for {@link NettyChannelBuilder} instances. */
@Internal
@ -52,4 +56,9 @@ public final class NettyChannelProvider extends ManagedChannelProvider {
return NewChannelBuilderResult.channelBuilder(
new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator));
}
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
}

View File

@ -20,6 +20,10 @@ import io.grpc.ChannelCredentials;
import io.grpc.Internal;
import io.grpc.InternalServiceProviders;
import io.grpc.ManagedChannelProvider;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;
/**
* Provider for {@link OkHttpChannelBuilder} instances.
@ -57,4 +61,9 @@ public final class OkHttpChannelProvider extends ManagedChannelProvider {
return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder(
target, creds, result.callCredentials, result.factory));
}
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
}

View File

@ -23,7 +23,11 @@ import io.grpc.Internal;
import io.grpc.NameResolver.Args;
import io.grpc.NameResolverProvider;
import io.grpc.internal.ObjectPool;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.Nullable;
@ -99,6 +103,11 @@ public final class XdsNameResolverProvider extends NameResolverProvider {
return 4;
}
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
interface XdsClientPoolFactory {
void setBootstrapOverride(Map<String, ?> bootstrap);