core: enhance ManagedChannelBuilder.overrideAuthority()

Enhance `ManagedChannelBuilder.overrideAuthority()` to make it impossible to use a different authority to a backend by wrapping ClientTransportFactory.newClientTransport() and setting ClientTransportOptions’ authority. To avoid confusing the LB policy, it would need to keep the original addresses to return during `Subchannel.getAddresses()`

The class `OverrideAuthorityNameResolverFactory` is deleted and its logic is moved into `ManagedChannelImpl`.
This commit is contained in:
ZHANG Dapeng 2021-01-29 09:29:06 -08:00 committed by GitHub
parent 64676198c5
commit 9437783838
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 141 additions and 202 deletions

View File

@ -37,6 +37,9 @@ public final class EquivalentAddressGroup {
/**
* The authority to be used when constructing Subchannels for this EquivalentAddressGroup.
* However, if the channel has overridden authority via
* {@link ManagedChannelBuilder#overrideAuthority(String)}, the transport will use the channel's
* authority override.
*/
@Attr
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/6138")

View File

@ -163,6 +163,9 @@ public abstract class ManagedChannelBuilder<T extends ManagedChannelBuilder<T>>
* Overrides the authority used with TLS and HTTP virtual hosting. It does not change what host is
* actually connected to. Is commonly in the form {@code host:port}.
*
* <p>If the channel builder overrides authority, any authority override from name resolution
* result (via {@link EquivalentAddressGroup#ATTR_AUTHORITY_OVERRIDE}) will be discarded.
*
* <p>This method is intended for testing, but may safely be used outside of tests as an
* alternative to DNS overrides.
*

View File

@ -22,6 +22,7 @@ import static com.google.common.base.Preconditions.checkState;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.EquivalentAddressGroup.ATTR_AUTHORITY_OVERRIDE;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
@ -152,6 +153,8 @@ final class ManagedChannelImpl extends ManagedChannel implements
private final InternalLogId logId;
private final String target;
@Nullable
private final String authorityOverride;
private final NameResolverRegistry nameResolverRegistry;
private final NameResolver.Factory nameResolverFactory;
private final NameResolver.Args nameResolverArgs;
@ -361,7 +364,8 @@ final class ManagedChannelImpl extends ManagedChannel implements
nameResolver.shutdown();
nameResolverStarted = false;
if (channelIsActive) {
nameResolver = getNameResolver(target, nameResolverFactory, nameResolverArgs);
nameResolver = getNameResolver(
target, authorityOverride, nameResolverFactory, nameResolverArgs);
} else {
nameResolver = null;
}
@ -612,7 +616,6 @@ final class ManagedChannelImpl extends ManagedChannel implements
logId, builder.maxTraceEvents, timeProvider.currentTimeNanos(),
"Channel for '" + target + "'");
channelLogger = new ChannelLoggerImpl(channelTracer, timeProvider);
this.nameResolverFactory = builder.getNameResolverFactory();
ProxyDetector proxyDetector =
builder.proxyDetector != null ? builder.proxyDetector : GrpcUtil.DEFAULT_PROXY_DETECTOR;
this.retryEnabled = builder.retryEnabled && !builder.temporarilyDisableRetry;
@ -644,7 +647,10 @@ final class ManagedChannelImpl extends ManagedChannel implements
}
})
.build();
this.nameResolver = getNameResolver(target, nameResolverFactory, nameResolverArgs);
this.authorityOverride = builder.authorityOverride;
this.nameResolverFactory = builder.nameResolverFactory;
this.nameResolver = getNameResolver(
target, authorityOverride, nameResolverFactory, nameResolverArgs);
this.balancerRpcExecutorPool = checkNotNull(balancerRpcExecutorPool, "balancerRpcExecutorPool");
this.balancerRpcExecutorHolder = new ExecutorHolder(balancerRpcExecutorPool);
this.delayedTransport = new DelayedClientTransport(this.executor, this.syncContext);
@ -715,9 +721,8 @@ final class ManagedChannelImpl extends ManagedChannel implements
}
}
@VisibleForTesting
static NameResolver getNameResolver(String target, NameResolver.Factory nameResolverFactory,
NameResolver.Args nameResolverArgs) {
private static NameResolver getNameResolver(
String target, NameResolver.Factory nameResolverFactory, NameResolver.Args nameResolverArgs) {
// Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending
// "dns:///".
URI targetUri = null;
@ -760,6 +765,22 @@ final class ManagedChannelImpl extends ManagedChannel implements
target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors + ")" : ""));
}
@VisibleForTesting
static NameResolver getNameResolver(
String target, @Nullable final String overrideAuthority,
NameResolver.Factory nameResolverFactory, NameResolver.Args nameResolverArgs) {
NameResolver resolver = getNameResolver(target, nameResolverFactory, nameResolverArgs);
if (overrideAuthority == null) {
return resolver;
}
return new ForwardingNameResolver(resolver) {
@Override
public String getServiceAuthority() {
return overrideAuthority;
}
};
}
@VisibleForTesting
InternalConfigSelector getConfigSelector() {
return realChannel.configSelector.get();
@ -1850,12 +1871,19 @@ final class ManagedChannelImpl extends ManagedChannel implements
final InternalLogId subchannelLogId;
final ChannelLoggerImpl subchannelLogger;
final ChannelTracer subchannelTracer;
List<EquivalentAddressGroup> addressGroups;
InternalSubchannel subchannel;
boolean started;
boolean shutdown;
ScheduledHandle delayedShutdownTask;
SubchannelImpl(CreateSubchannelArgs args, LbHelperImpl helper) {
addressGroups = args.getAddresses();
if (authorityOverride != null) {
List<EquivalentAddressGroup> eagsWithoutOverrideAttr =
stripOverrideAuthorityAttributes(args.getAddresses());
args = args.toBuilder().setAddresses(eagsWithoutOverrideAttr).build();
}
this.args = checkNotNull(args, "args");
this.helper = checkNotNull(helper, "helper");
subchannelLogId = InternalLogId.allocate("Subchannel", /*details=*/ authority());
@ -1992,7 +2020,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
public List<EquivalentAddressGroup> getAllAddresses() {
syncContext.throwIfNotInThisSynchronizationContext();
checkState(started, "not started");
return subchannel.getAddressGroups();
return addressGroups;
}
@Override
@ -2029,8 +2057,24 @@ final class ManagedChannelImpl extends ManagedChannel implements
@Override
public void updateAddresses(List<EquivalentAddressGroup> addrs) {
syncContext.throwIfNotInThisSynchronizationContext();
addressGroups = addrs;
if (authorityOverride != null) {
addrs = stripOverrideAuthorityAttributes(addrs);
}
subchannel.updateAddresses(addrs);
}
private List<EquivalentAddressGroup> stripOverrideAuthorityAttributes(
List<EquivalentAddressGroup> eags) {
List<EquivalentAddressGroup> eagsWithoutOverrideAttr = new ArrayList<>();
for (EquivalentAddressGroup eag : eags) {
EquivalentAddressGroup eagWithoutOverrideAttr = new EquivalentAddressGroup(
eag.getAddresses(),
eag.getAttributes().toBuilder().discard(ATTR_AUTHORITY_OVERRIDE).build());
eagsWithoutOverrideAttr.add(eagWithoutOverrideAttr);
}
return Collections.unmodifiableList(eagsWithoutOverrideAttr);
}
}
@Override

View File

@ -108,7 +108,7 @@ public final class ManagedChannelImplBuilder
final NameResolverRegistry nameResolverRegistry = NameResolverRegistry.getDefaultRegistry();
// Access via getter, which may perform authority override as needed
private NameResolver.Factory nameResolverFactory = nameResolverRegistry.asFactory();
NameResolver.Factory nameResolverFactory = nameResolverRegistry.asFactory();
final String target;
@Nullable
@ -123,7 +123,7 @@ public final class ManagedChannelImplBuilder
String userAgent;
@Nullable
private String authorityOverride;
String authorityOverride;
String defaultLbPolicy = GrpcUtil.DEFAULT_LB_POLICY;
@ -393,12 +393,6 @@ public final class ManagedChannelImplBuilder
return this;
}
@Nullable
@VisibleForTesting
String getOverrideAuthority() {
return authorityOverride;
}
@Override
public ManagedChannelImplBuilder idleTimeout(long value, TimeUnit unit) {
checkArgument(value > 0, "idle timeout is %s, but must be positive", value);
@ -699,17 +693,6 @@ public final class ManagedChannelImplBuilder
return channelBuilderDefaultPortProvider.getDefaultPort();
}
/**
* Returns a {@link NameResolver.Factory} for the channel.
*/
NameResolver.Factory getNameResolverFactory() {
if (authorityOverride == null) {
return nameResolverFactory;
} else {
return new OverrideAuthorityNameResolverFactory(nameResolverFactory, authorityOverride);
}
}
private static class DirectAddressNameResolverFactory extends NameResolver.Factory {
final SocketAddress address;
final String authority;

View File

@ -1,63 +0,0 @@
/*
* Copyright 2017 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal;
import io.grpc.NameResolver;
import java.net.URI;
import javax.annotation.Nullable;
/**
* A wrapper class that overrides the authority of a NameResolver, while preserving all other
* functionality.
*/
final class OverrideAuthorityNameResolverFactory extends NameResolver.Factory {
private final NameResolver.Factory delegate;
private final String authorityOverride;
/**
* Constructor for the {@link NameResolver.Factory}
*
* @param delegate The actual underlying factory that will produce the a {@link NameResolver}
* @param authorityOverride The authority that will be returned by {@link
* NameResolver#getServiceAuthority()}
*/
OverrideAuthorityNameResolverFactory(NameResolver.Factory delegate, String authorityOverride) {
this.delegate = delegate;
this.authorityOverride = authorityOverride;
}
@Nullable
@Override
public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) {
final NameResolver resolver = delegate.newNameResolver(targetUri, args);
// Do not wrap null values. We do not want to impede error signaling.
if (resolver == null) {
return null;
}
return new ForwardingNameResolver(resolver) {
@Override
public String getServiceAuthority() {
return authorityOverride;
}
};
}
@Override
public String getDefaultScheme() {
return delegate.getDefaultScheme();
}
}

View File

@ -180,7 +180,7 @@ public class ManagedChannelImplBuilderTest {
@Test
public void nameResolverFactory_default() {
assertNotNull(builder.getNameResolverFactory());
assertNotNull(builder.nameResolverFactory);
}
@Test
@ -188,16 +188,16 @@ public class ManagedChannelImplBuilderTest {
public void nameResolverFactory_normal() {
NameResolver.Factory nameResolverFactory = mock(NameResolver.Factory.class);
assertEquals(builder, builder.nameResolverFactory(nameResolverFactory));
assertEquals(nameResolverFactory, builder.getNameResolverFactory());
assertEquals(nameResolverFactory, builder.nameResolverFactory);
}
@Test
@SuppressWarnings("deprecation")
public void nameResolverFactory_null() {
NameResolver.Factory defaultValue = builder.getNameResolverFactory();
NameResolver.Factory defaultValue = builder.nameResolverFactory;
builder.nameResolverFactory(mock(NameResolver.Factory.class));
assertEquals(builder, builder.nameResolverFactory(null));
assertEquals(defaultValue, builder.getNameResolverFactory());
assertEquals(defaultValue, builder.nameResolverFactory);
}
@Test(expected = IllegalStateException.class)
@ -334,14 +334,14 @@ public class ManagedChannelImplBuilderTest {
@Test
public void overrideAuthority_default() {
assertNull(builder.getOverrideAuthority());
assertNull(builder.authorityOverride);
}
@Test
public void overrideAuthority_normal() {
String overrideAuthority = "best-authority";
assertEquals(builder, builder.overrideAuthority(overrideAuthority));
assertEquals(overrideAuthority, builder.getOverrideAuthority());
assertEquals(overrideAuthority, builder.authorityOverride);
}
@Test(expected = NullPointerException.class)
@ -354,14 +354,6 @@ public class ManagedChannelImplBuilderTest {
builder.overrideAuthority("not_allowed");
}
@Test
public void overrideAuthority_getNameResolverFactory() {
assertNull(builder.getOverrideAuthority());
assertFalse(builder.getNameResolverFactory() instanceof OverrideAuthorityNameResolverFactory);
builder.overrideAuthority("google.com");
assertTrue(builder.getNameResolverFactory() instanceof OverrideAuthorityNameResolverFactory);
}
@Test
public void checkAuthority_validAuthorityAllowed() {
assertEquals(DUMMY_AUTHORITY_VALID, builder.checkAuthority(DUMMY_AUTHORITY_VALID));

View File

@ -16,6 +16,7 @@
package io.grpc.internal;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
@ -32,7 +33,8 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link ManagedChannelImpl#getNameResolver}. */
/** Unit tests for {@link ManagedChannelImpl#getNameResolver(
* String, String,NameResolver.Factory, NameResolver.Args)}. */
@RunWith(JUnit4.class)
public class ManagedChannelImplGetNameResolverTest {
private static final NameResolver.Args NAMERESOLVER_ARGS = NameResolver.Args.newBuilder()
@ -60,6 +62,17 @@ public class ManagedChannelImplGetNameResolverTest {
new URI("defaultscheme", "", "/foo.googleapis.com:8080", null));
}
@Test
public void validAuthorityTarget_overrideAuthority() throws Exception {
String target = "foo.googleapis.com:8080";
String overrideAuthority = "override.authority";
URI expectedUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null);
NameResolver.Factory nameResolverFactory = new FakeNameResolverFactory(expectedUri.getScheme());
NameResolver nameResolver = ManagedChannelImpl.getNameResolver(
target, overrideAuthority, nameResolverFactory, NAMERESOLVER_ARGS);
assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority);
}
@Test
public void validUriTarget() throws Exception {
testValidTarget("scheme:///foo.googleapis.com:8080", "scheme:///foo.googleapis.com:8080",
@ -116,7 +129,7 @@ public class ManagedChannelImplGetNameResolverTest {
};
try {
ManagedChannelImpl.getNameResolver(
"foo.googleapis.com:8080", nameResolverFactory, NAMERESOLVER_ARGS);
"foo.googleapis.com:8080", null, nameResolverFactory, NAMERESOLVER_ARGS);
fail("Should fail");
} catch (IllegalArgumentException e) {
// expected
@ -126,7 +139,7 @@ public class ManagedChannelImplGetNameResolverTest {
private void testValidTarget(String target, String expectedUriString, URI expectedUri) {
NameResolver.Factory nameResolverFactory = new FakeNameResolverFactory(expectedUri.getScheme());
FakeNameResolver nameResolver = (FakeNameResolver) ManagedChannelImpl.getNameResolver(
target, nameResolverFactory, NAMERESOLVER_ARGS);
target, null, nameResolverFactory, NAMERESOLVER_ARGS);
assertNotNull(nameResolver);
assertEquals(expectedUri, nameResolver.uri);
assertEquals(expectedUriString, nameResolver.uri.toString());
@ -137,7 +150,7 @@ public class ManagedChannelImplGetNameResolverTest {
try {
FakeNameResolver nameResolver = (FakeNameResolver) ManagedChannelImpl.getNameResolver(
target, nameResolverFactory, NAMERESOLVER_ARGS);
target, null, nameResolverFactory, NAMERESOLVER_ARGS);
fail("Should have failed, but got resolver with " + nameResolver.uri);
} catch (IllegalArgumentException e) {
// expected

View File

@ -24,6 +24,7 @@ import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.EquivalentAddressGroup.ATTR_AUTHORITY_OVERRIDE;
import static junit.framework.TestCase.assertNotSame;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@ -380,6 +381,64 @@ public class ManagedChannelImplTest {
}
}
@Test
public void createSubchannel_resolverOverrideAuthority() {
EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(
socketAddress,
Attributes.newBuilder()
.set(ATTR_AUTHORITY_OVERRIDE, "resolver.override.authority")
.build());
channelBuilder.nameResolverFactory(
new FakeNameResolverFactory.Builder(expectedUri)
.setServers(Collections.singletonList(addressGroup))
.build());
createChannel();
Subchannel subchannel =
createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener);
requestConnectionSafely(helper, subchannel);
ArgumentCaptor<ClientTransportOptions> transportOptionCaptor = ArgumentCaptor.forClass(null);
verify(mockTransportFactory)
.newClientTransport(
any(SocketAddress.class), transportOptionCaptor.capture(), any(ChannelLogger.class));
assertThat(transportOptionCaptor.getValue().getAuthority())
.isEqualTo("resolver.override.authority");
}
@Test
public void createSubchannel_channelBuilderOverrideAuthority() {
channelBuilder.overrideAuthority("channel-builder.override.authority");
EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(
socketAddress,
Attributes.newBuilder()
.set(ATTR_AUTHORITY_OVERRIDE, "resolver.override.authority")
.build());
channelBuilder.nameResolverFactory(
new FakeNameResolverFactory.Builder(expectedUri)
.setServers(Collections.singletonList(addressGroup))
.build());
createChannel();
final Subchannel subchannel =
createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener);
requestConnectionSafely(helper, subchannel);
ArgumentCaptor<ClientTransportOptions> transportOptionCaptor = ArgumentCaptor.forClass(null);
verify(mockTransportFactory)
.newClientTransport(
any(SocketAddress.class), transportOptionCaptor.capture(), any(ChannelLogger.class));
assertThat(transportOptionCaptor.getValue().getAuthority())
.isEqualTo("channel-builder.override.authority");
final List<EquivalentAddressGroup> subchannelEags = new ArrayList<>();
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
subchannelEags.addAll(subchannel.getAllAddresses());
}
});
assertThat(subchannelEags).isEqualTo(ImmutableList.of(addressGroup));
}
@Test
public void idleModeDisabled() {
channelBuilder.nameResolverFactory(

View File

@ -1,95 +0,0 @@
/*
* Copyright 2017 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal;
import static junit.framework.TestCase.assertNotNull;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.ChannelLogger;
import io.grpc.NameResolver;
import io.grpc.NameResolver.ServiceConfigParser;
import io.grpc.ProxyDetector;
import io.grpc.SynchronizationContext;
import java.lang.Thread.UncaughtExceptionHandler;
import java.net.URI;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link OverrideAuthorityNameResolverFactory}. */
@RunWith(JUnit4.class)
public class OverrideAuthorityNameResolverTest {
private static final NameResolver.Args ARGS = NameResolver.Args.newBuilder()
.setDefaultPort(8080)
.setProxyDetector(mock(ProxyDetector.class))
.setSynchronizationContext(new SynchronizationContext(mock(UncaughtExceptionHandler.class)))
.setServiceConfigParser(mock(ServiceConfigParser.class))
.setChannelLogger(mock(ChannelLogger.class))
.build();
@Test
public void overridesAuthority() {
NameResolver nameResolverMock = mock(NameResolver.class);
NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class);
when(wrappedFactory.newNameResolver(any(URI.class), any(NameResolver.Args.class)))
.thenReturn(nameResolverMock);
String override = "override:5678";
NameResolver.Factory factory =
new OverrideAuthorityNameResolverFactory(wrappedFactory, override);
NameResolver nameResolver = factory.newNameResolver(URI.create("dns:///localhost:443"), ARGS);
assertNotNull(nameResolver);
assertEquals(override, nameResolver.getServiceAuthority());
}
@Test
public void wontWrapNull() {
NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class);
when(wrappedFactory.newNameResolver(any(URI.class), any(NameResolver.Args.class)))
.thenReturn(null);
NameResolver.Factory factory =
new OverrideAuthorityNameResolverFactory(wrappedFactory, "override:5678");
assertEquals(null,
factory.newNameResolver(URI.create("dns:///localhost:443"), ARGS));
}
@Test
public void forwardsNonOverridenCalls() {
NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class);
NameResolver mockResolver = mock(NameResolver.class);
when(wrappedFactory.newNameResolver(any(URI.class), any(NameResolver.Args.class)))
.thenReturn(mockResolver);
NameResolver.Factory factory =
new OverrideAuthorityNameResolverFactory(wrappedFactory, "override:5678");
NameResolver overrideResolver =
factory.newNameResolver(URI.create("dns:///localhost:443"), ARGS);
assertNotNull(overrideResolver);
NameResolver.Listener2 listener = mock(NameResolver.Listener2.class);
overrideResolver.start(listener);
verify(mockResolver).start(listener);
overrideResolver.shutdown();
verify(mockResolver).shutdown();
overrideResolver.refresh();
verify(mockResolver).refresh();
}
}