From 9437783838791489ea0efdf5a2ad991c68730d1d Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Fri, 29 Jan 2021 09:29:06 -0800 Subject: [PATCH] core: enhance ManagedChannelBuilder.overrideAuthority() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enhance `ManagedChannelBuilder.overrideAuthority()` to make it impossible to use a different authority to a backend by wrapping ClientTransportFactory.newClientTransport() and setting ClientTransportOptions’ authority. To avoid confusing the LB policy, it would need to keep the original addresses to return during `Subchannel.getAddresses()` The class `OverrideAuthorityNameResolverFactory` is deleted and its logic is moved into `ManagedChannelImpl`. --- .../java/io/grpc/EquivalentAddressGroup.java | 3 + .../java/io/grpc/ManagedChannelBuilder.java | 3 + .../io/grpc/internal/ManagedChannelImpl.java | 58 +++++++++-- .../internal/ManagedChannelImplBuilder.java | 21 +--- .../OverrideAuthorityNameResolverFactory.java | 63 ------------ .../ManagedChannelImplBuilderTest.java | 20 ++-- ...ManagedChannelImplGetNameResolverTest.java | 21 +++- .../grpc/internal/ManagedChannelImplTest.java | 59 ++++++++++++ .../OverrideAuthorityNameResolverTest.java | 95 ------------------- 9 files changed, 141 insertions(+), 202 deletions(-) delete mode 100644 core/src/main/java/io/grpc/internal/OverrideAuthorityNameResolverFactory.java delete mode 100644 core/src/test/java/io/grpc/internal/OverrideAuthorityNameResolverTest.java diff --git a/api/src/main/java/io/grpc/EquivalentAddressGroup.java b/api/src/main/java/io/grpc/EquivalentAddressGroup.java index 2fa1099e24..34b2957d83 100644 --- a/api/src/main/java/io/grpc/EquivalentAddressGroup.java +++ b/api/src/main/java/io/grpc/EquivalentAddressGroup.java @@ -37,6 +37,9 @@ public final class EquivalentAddressGroup { /** * The authority to be used when constructing Subchannels for this EquivalentAddressGroup. + * However, if the channel has overridden authority via + * {@link ManagedChannelBuilder#overrideAuthority(String)}, the transport will use the channel's + * authority override. */ @Attr @ExperimentalApi("https://github.com/grpc/grpc-java/issues/6138") diff --git a/api/src/main/java/io/grpc/ManagedChannelBuilder.java b/api/src/main/java/io/grpc/ManagedChannelBuilder.java index a340dc408a..e4a4611541 100644 --- a/api/src/main/java/io/grpc/ManagedChannelBuilder.java +++ b/api/src/main/java/io/grpc/ManagedChannelBuilder.java @@ -163,6 +163,9 @@ public abstract class ManagedChannelBuilder> * Overrides the authority used with TLS and HTTP virtual hosting. It does not change what host is * actually connected to. Is commonly in the form {@code host:port}. * + *

If the channel builder overrides authority, any authority override from name resolution + * result (via {@link EquivalentAddressGroup#ATTR_AUTHORITY_OVERRIDE}) will be discarded. + * *

This method is intended for testing, but may safely be used outside of tests as an * alternative to DNS overrides. * diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 5f6ef46769..79af752ddf 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -22,6 +22,7 @@ import static com.google.common.base.Preconditions.checkState; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.EquivalentAddressGroup.ATTR_AUTHORITY_OVERRIDE; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; @@ -152,6 +153,8 @@ final class ManagedChannelImpl extends ManagedChannel implements private final InternalLogId logId; private final String target; + @Nullable + private final String authorityOverride; private final NameResolverRegistry nameResolverRegistry; private final NameResolver.Factory nameResolverFactory; private final NameResolver.Args nameResolverArgs; @@ -361,7 +364,8 @@ final class ManagedChannelImpl extends ManagedChannel implements nameResolver.shutdown(); nameResolverStarted = false; if (channelIsActive) { - nameResolver = getNameResolver(target, nameResolverFactory, nameResolverArgs); + nameResolver = getNameResolver( + target, authorityOverride, nameResolverFactory, nameResolverArgs); } else { nameResolver = null; } @@ -612,7 +616,6 @@ final class ManagedChannelImpl extends ManagedChannel implements logId, builder.maxTraceEvents, timeProvider.currentTimeNanos(), "Channel for '" + target + "'"); channelLogger = new ChannelLoggerImpl(channelTracer, timeProvider); - this.nameResolverFactory = builder.getNameResolverFactory(); ProxyDetector proxyDetector = builder.proxyDetector != null ? builder.proxyDetector : GrpcUtil.DEFAULT_PROXY_DETECTOR; this.retryEnabled = builder.retryEnabled && !builder.temporarilyDisableRetry; @@ -644,7 +647,10 @@ final class ManagedChannelImpl extends ManagedChannel implements } }) .build(); - this.nameResolver = getNameResolver(target, nameResolverFactory, nameResolverArgs); + this.authorityOverride = builder.authorityOverride; + this.nameResolverFactory = builder.nameResolverFactory; + this.nameResolver = getNameResolver( + target, authorityOverride, nameResolverFactory, nameResolverArgs); this.balancerRpcExecutorPool = checkNotNull(balancerRpcExecutorPool, "balancerRpcExecutorPool"); this.balancerRpcExecutorHolder = new ExecutorHolder(balancerRpcExecutorPool); this.delayedTransport = new DelayedClientTransport(this.executor, this.syncContext); @@ -715,9 +721,8 @@ final class ManagedChannelImpl extends ManagedChannel implements } } - @VisibleForTesting - static NameResolver getNameResolver(String target, NameResolver.Factory nameResolverFactory, - NameResolver.Args nameResolverArgs) { + private static NameResolver getNameResolver( + String target, NameResolver.Factory nameResolverFactory, NameResolver.Args nameResolverArgs) { // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending // "dns:///". URI targetUri = null; @@ -760,6 +765,22 @@ final class ManagedChannelImpl extends ManagedChannel implements target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors + ")" : "")); } + @VisibleForTesting + static NameResolver getNameResolver( + String target, @Nullable final String overrideAuthority, + NameResolver.Factory nameResolverFactory, NameResolver.Args nameResolverArgs) { + NameResolver resolver = getNameResolver(target, nameResolverFactory, nameResolverArgs); + if (overrideAuthority == null) { + return resolver; + } + return new ForwardingNameResolver(resolver) { + @Override + public String getServiceAuthority() { + return overrideAuthority; + } + }; + } + @VisibleForTesting InternalConfigSelector getConfigSelector() { return realChannel.configSelector.get(); @@ -1850,12 +1871,19 @@ final class ManagedChannelImpl extends ManagedChannel implements final InternalLogId subchannelLogId; final ChannelLoggerImpl subchannelLogger; final ChannelTracer subchannelTracer; + List addressGroups; InternalSubchannel subchannel; boolean started; boolean shutdown; ScheduledHandle delayedShutdownTask; SubchannelImpl(CreateSubchannelArgs args, LbHelperImpl helper) { + addressGroups = args.getAddresses(); + if (authorityOverride != null) { + List eagsWithoutOverrideAttr = + stripOverrideAuthorityAttributes(args.getAddresses()); + args = args.toBuilder().setAddresses(eagsWithoutOverrideAttr).build(); + } this.args = checkNotNull(args, "args"); this.helper = checkNotNull(helper, "helper"); subchannelLogId = InternalLogId.allocate("Subchannel", /*details=*/ authority()); @@ -1992,7 +2020,7 @@ final class ManagedChannelImpl extends ManagedChannel implements public List getAllAddresses() { syncContext.throwIfNotInThisSynchronizationContext(); checkState(started, "not started"); - return subchannel.getAddressGroups(); + return addressGroups; } @Override @@ -2029,8 +2057,24 @@ final class ManagedChannelImpl extends ManagedChannel implements @Override public void updateAddresses(List addrs) { syncContext.throwIfNotInThisSynchronizationContext(); + addressGroups = addrs; + if (authorityOverride != null) { + addrs = stripOverrideAuthorityAttributes(addrs); + } subchannel.updateAddresses(addrs); } + + private List stripOverrideAuthorityAttributes( + List eags) { + List eagsWithoutOverrideAttr = new ArrayList<>(); + for (EquivalentAddressGroup eag : eags) { + EquivalentAddressGroup eagWithoutOverrideAttr = new EquivalentAddressGroup( + eag.getAddresses(), + eag.getAttributes().toBuilder().discard(ATTR_AUTHORITY_OVERRIDE).build()); + eagsWithoutOverrideAttr.add(eagWithoutOverrideAttr); + } + return Collections.unmodifiableList(eagsWithoutOverrideAttr); + } } @Override diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index 13f3672d43..c74eed6df7 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -108,7 +108,7 @@ public final class ManagedChannelImplBuilder final NameResolverRegistry nameResolverRegistry = NameResolverRegistry.getDefaultRegistry(); // Access via getter, which may perform authority override as needed - private NameResolver.Factory nameResolverFactory = nameResolverRegistry.asFactory(); + NameResolver.Factory nameResolverFactory = nameResolverRegistry.asFactory(); final String target; @Nullable @@ -123,7 +123,7 @@ public final class ManagedChannelImplBuilder String userAgent; @Nullable - private String authorityOverride; + String authorityOverride; String defaultLbPolicy = GrpcUtil.DEFAULT_LB_POLICY; @@ -393,12 +393,6 @@ public final class ManagedChannelImplBuilder return this; } - @Nullable - @VisibleForTesting - String getOverrideAuthority() { - return authorityOverride; - } - @Override public ManagedChannelImplBuilder idleTimeout(long value, TimeUnit unit) { checkArgument(value > 0, "idle timeout is %s, but must be positive", value); @@ -699,17 +693,6 @@ public final class ManagedChannelImplBuilder return channelBuilderDefaultPortProvider.getDefaultPort(); } - /** - * Returns a {@link NameResolver.Factory} for the channel. - */ - NameResolver.Factory getNameResolverFactory() { - if (authorityOverride == null) { - return nameResolverFactory; - } else { - return new OverrideAuthorityNameResolverFactory(nameResolverFactory, authorityOverride); - } - } - private static class DirectAddressNameResolverFactory extends NameResolver.Factory { final SocketAddress address; final String authority; diff --git a/core/src/main/java/io/grpc/internal/OverrideAuthorityNameResolverFactory.java b/core/src/main/java/io/grpc/internal/OverrideAuthorityNameResolverFactory.java deleted file mode 100644 index d379b04f8b..0000000000 --- a/core/src/main/java/io/grpc/internal/OverrideAuthorityNameResolverFactory.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright 2017 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.internal; - -import io.grpc.NameResolver; -import java.net.URI; -import javax.annotation.Nullable; - -/** - * A wrapper class that overrides the authority of a NameResolver, while preserving all other - * functionality. - */ -final class OverrideAuthorityNameResolverFactory extends NameResolver.Factory { - private final NameResolver.Factory delegate; - private final String authorityOverride; - - /** - * Constructor for the {@link NameResolver.Factory} - * - * @param delegate The actual underlying factory that will produce the a {@link NameResolver} - * @param authorityOverride The authority that will be returned by {@link - * NameResolver#getServiceAuthority()} - */ - OverrideAuthorityNameResolverFactory(NameResolver.Factory delegate, String authorityOverride) { - this.delegate = delegate; - this.authorityOverride = authorityOverride; - } - - @Nullable - @Override - public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { - final NameResolver resolver = delegate.newNameResolver(targetUri, args); - // Do not wrap null values. We do not want to impede error signaling. - if (resolver == null) { - return null; - } - return new ForwardingNameResolver(resolver) { - @Override - public String getServiceAuthority() { - return authorityOverride; - } - }; - } - - @Override - public String getDefaultScheme() { - return delegate.getDefaultScheme(); - } -} diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java index 9534c16fbb..bc5b7cde65 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java @@ -180,7 +180,7 @@ public class ManagedChannelImplBuilderTest { @Test public void nameResolverFactory_default() { - assertNotNull(builder.getNameResolverFactory()); + assertNotNull(builder.nameResolverFactory); } @Test @@ -188,16 +188,16 @@ public class ManagedChannelImplBuilderTest { public void nameResolverFactory_normal() { NameResolver.Factory nameResolverFactory = mock(NameResolver.Factory.class); assertEquals(builder, builder.nameResolverFactory(nameResolverFactory)); - assertEquals(nameResolverFactory, builder.getNameResolverFactory()); + assertEquals(nameResolverFactory, builder.nameResolverFactory); } @Test @SuppressWarnings("deprecation") public void nameResolverFactory_null() { - NameResolver.Factory defaultValue = builder.getNameResolverFactory(); + NameResolver.Factory defaultValue = builder.nameResolverFactory; builder.nameResolverFactory(mock(NameResolver.Factory.class)); assertEquals(builder, builder.nameResolverFactory(null)); - assertEquals(defaultValue, builder.getNameResolverFactory()); + assertEquals(defaultValue, builder.nameResolverFactory); } @Test(expected = IllegalStateException.class) @@ -334,14 +334,14 @@ public class ManagedChannelImplBuilderTest { @Test public void overrideAuthority_default() { - assertNull(builder.getOverrideAuthority()); + assertNull(builder.authorityOverride); } @Test public void overrideAuthority_normal() { String overrideAuthority = "best-authority"; assertEquals(builder, builder.overrideAuthority(overrideAuthority)); - assertEquals(overrideAuthority, builder.getOverrideAuthority()); + assertEquals(overrideAuthority, builder.authorityOverride); } @Test(expected = NullPointerException.class) @@ -354,14 +354,6 @@ public class ManagedChannelImplBuilderTest { builder.overrideAuthority("not_allowed"); } - @Test - public void overrideAuthority_getNameResolverFactory() { - assertNull(builder.getOverrideAuthority()); - assertFalse(builder.getNameResolverFactory() instanceof OverrideAuthorityNameResolverFactory); - builder.overrideAuthority("google.com"); - assertTrue(builder.getNameResolverFactory() instanceof OverrideAuthorityNameResolverFactory); - } - @Test public void checkAuthority_validAuthorityAllowed() { assertEquals(DUMMY_AUTHORITY_VALID, builder.checkAuthority(DUMMY_AUTHORITY_VALID)); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java index ff0659db16..481db99b11 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; @@ -32,7 +33,8 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for {@link ManagedChannelImpl#getNameResolver}. */ +/** Unit tests for {@link ManagedChannelImpl#getNameResolver( + * String, String,NameResolver.Factory, NameResolver.Args)}. */ @RunWith(JUnit4.class) public class ManagedChannelImplGetNameResolverTest { private static final NameResolver.Args NAMERESOLVER_ARGS = NameResolver.Args.newBuilder() @@ -60,6 +62,17 @@ public class ManagedChannelImplGetNameResolverTest { new URI("defaultscheme", "", "/foo.googleapis.com:8080", null)); } + @Test + public void validAuthorityTarget_overrideAuthority() throws Exception { + String target = "foo.googleapis.com:8080"; + String overrideAuthority = "override.authority"; + URI expectedUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null); + NameResolver.Factory nameResolverFactory = new FakeNameResolverFactory(expectedUri.getScheme()); + NameResolver nameResolver = ManagedChannelImpl.getNameResolver( + target, overrideAuthority, nameResolverFactory, NAMERESOLVER_ARGS); + assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority); + } + @Test public void validUriTarget() throws Exception { testValidTarget("scheme:///foo.googleapis.com:8080", "scheme:///foo.googleapis.com:8080", @@ -116,7 +129,7 @@ public class ManagedChannelImplGetNameResolverTest { }; try { ManagedChannelImpl.getNameResolver( - "foo.googleapis.com:8080", nameResolverFactory, NAMERESOLVER_ARGS); + "foo.googleapis.com:8080", null, nameResolverFactory, NAMERESOLVER_ARGS); fail("Should fail"); } catch (IllegalArgumentException e) { // expected @@ -126,7 +139,7 @@ public class ManagedChannelImplGetNameResolverTest { private void testValidTarget(String target, String expectedUriString, URI expectedUri) { NameResolver.Factory nameResolverFactory = new FakeNameResolverFactory(expectedUri.getScheme()); FakeNameResolver nameResolver = (FakeNameResolver) ManagedChannelImpl.getNameResolver( - target, nameResolverFactory, NAMERESOLVER_ARGS); + target, null, nameResolverFactory, NAMERESOLVER_ARGS); assertNotNull(nameResolver); assertEquals(expectedUri, nameResolver.uri); assertEquals(expectedUriString, nameResolver.uri.toString()); @@ -137,7 +150,7 @@ public class ManagedChannelImplGetNameResolverTest { try { FakeNameResolver nameResolver = (FakeNameResolver) ManagedChannelImpl.getNameResolver( - target, nameResolverFactory, NAMERESOLVER_ARGS); + target, null, nameResolverFactory, NAMERESOLVER_ARGS); fail("Should have failed, but got resolver with " + nameResolver.uri); } catch (IllegalArgumentException e) { // expected diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 99bdd1244f..bf72556757 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -24,6 +24,7 @@ import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.EquivalentAddressGroup.ATTR_AUTHORITY_OVERRIDE; import static junit.framework.TestCase.assertNotSame; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -380,6 +381,64 @@ public class ManagedChannelImplTest { } } + @Test + public void createSubchannel_resolverOverrideAuthority() { + EquivalentAddressGroup addressGroup = new EquivalentAddressGroup( + socketAddress, + Attributes.newBuilder() + .set(ATTR_AUTHORITY_OVERRIDE, "resolver.override.authority") + .build()); + channelBuilder.nameResolverFactory( + new FakeNameResolverFactory.Builder(expectedUri) + .setServers(Collections.singletonList(addressGroup)) + .build()); + createChannel(); + + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); + requestConnectionSafely(helper, subchannel); + ArgumentCaptor transportOptionCaptor = ArgumentCaptor.forClass(null); + verify(mockTransportFactory) + .newClientTransport( + any(SocketAddress.class), transportOptionCaptor.capture(), any(ChannelLogger.class)); + assertThat(transportOptionCaptor.getValue().getAuthority()) + .isEqualTo("resolver.override.authority"); + } + + @Test + public void createSubchannel_channelBuilderOverrideAuthority() { + channelBuilder.overrideAuthority("channel-builder.override.authority"); + EquivalentAddressGroup addressGroup = new EquivalentAddressGroup( + socketAddress, + Attributes.newBuilder() + .set(ATTR_AUTHORITY_OVERRIDE, "resolver.override.authority") + .build()); + channelBuilder.nameResolverFactory( + new FakeNameResolverFactory.Builder(expectedUri) + .setServers(Collections.singletonList(addressGroup)) + .build()); + createChannel(); + + final Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); + requestConnectionSafely(helper, subchannel); + ArgumentCaptor transportOptionCaptor = ArgumentCaptor.forClass(null); + verify(mockTransportFactory) + .newClientTransport( + any(SocketAddress.class), transportOptionCaptor.capture(), any(ChannelLogger.class)); + assertThat(transportOptionCaptor.getValue().getAuthority()) + .isEqualTo("channel-builder.override.authority"); + final List subchannelEags = new ArrayList<>(); + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + subchannelEags.addAll(subchannel.getAllAddresses()); + } + }); + assertThat(subchannelEags).isEqualTo(ImmutableList.of(addressGroup)); + } + @Test public void idleModeDisabled() { channelBuilder.nameResolverFactory( diff --git a/core/src/test/java/io/grpc/internal/OverrideAuthorityNameResolverTest.java b/core/src/test/java/io/grpc/internal/OverrideAuthorityNameResolverTest.java deleted file mode 100644 index 8d23ce3f88..0000000000 --- a/core/src/test/java/io/grpc/internal/OverrideAuthorityNameResolverTest.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright 2017 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.internal; - -import static junit.framework.TestCase.assertNotNull; -import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.grpc.ChannelLogger; -import io.grpc.NameResolver; -import io.grpc.NameResolver.ServiceConfigParser; -import io.grpc.ProxyDetector; -import io.grpc.SynchronizationContext; -import java.lang.Thread.UncaughtExceptionHandler; -import java.net.URI; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Unit tests for {@link OverrideAuthorityNameResolverFactory}. */ -@RunWith(JUnit4.class) -public class OverrideAuthorityNameResolverTest { - private static final NameResolver.Args ARGS = NameResolver.Args.newBuilder() - .setDefaultPort(8080) - .setProxyDetector(mock(ProxyDetector.class)) - .setSynchronizationContext(new SynchronizationContext(mock(UncaughtExceptionHandler.class))) - .setServiceConfigParser(mock(ServiceConfigParser.class)) - .setChannelLogger(mock(ChannelLogger.class)) - .build(); - - @Test - public void overridesAuthority() { - NameResolver nameResolverMock = mock(NameResolver.class); - NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class); - when(wrappedFactory.newNameResolver(any(URI.class), any(NameResolver.Args.class))) - .thenReturn(nameResolverMock); - String override = "override:5678"; - NameResolver.Factory factory = - new OverrideAuthorityNameResolverFactory(wrappedFactory, override); - NameResolver nameResolver = factory.newNameResolver(URI.create("dns:///localhost:443"), ARGS); - assertNotNull(nameResolver); - assertEquals(override, nameResolver.getServiceAuthority()); - } - - @Test - public void wontWrapNull() { - NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class); - when(wrappedFactory.newNameResolver(any(URI.class), any(NameResolver.Args.class))) - .thenReturn(null); - NameResolver.Factory factory = - new OverrideAuthorityNameResolverFactory(wrappedFactory, "override:5678"); - assertEquals(null, - factory.newNameResolver(URI.create("dns:///localhost:443"), ARGS)); - } - - @Test - public void forwardsNonOverridenCalls() { - NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class); - NameResolver mockResolver = mock(NameResolver.class); - when(wrappedFactory.newNameResolver(any(URI.class), any(NameResolver.Args.class))) - .thenReturn(mockResolver); - NameResolver.Factory factory = - new OverrideAuthorityNameResolverFactory(wrappedFactory, "override:5678"); - NameResolver overrideResolver = - factory.newNameResolver(URI.create("dns:///localhost:443"), ARGS); - assertNotNull(overrideResolver); - NameResolver.Listener2 listener = mock(NameResolver.Listener2.class); - - overrideResolver.start(listener); - verify(mockResolver).start(listener); - - overrideResolver.shutdown(); - verify(mockResolver).shutdown(); - - overrideResolver.refresh(); - verify(mockResolver).refresh(); - } -}