From 45a151810c64a4f096823b891009ff8152d77a87 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Thu, 28 Jan 2021 09:49:53 -0800 Subject: [PATCH] all: implement Helper.createResolvingOobChannelBuilder(target, creds) - Add APIs to `ClientTransportFactory`: ```java public interface ClientTransportFactory { /** * Swaps to a new ChannelCredentials with all other settings unchanged. Returns null if the * ChannelCredentials is not supported by the current ClientTransportFactory settings. */ SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds); final class SwapChannelCredentialsResult { final ClientTransportFactory transportFactory; @Nullable final CallCredentials callCredentials; } } ``` - Add `ChannelCredentials` to constructor args of `ManagedChannelImplBuilder`: ```java public ManagedChannelImplBuilder( String target, @Nullable ChannelCredentials channelCreds, @Nullable CallCredentials callCreds, ...) ``` --- .../inprocess/InProcessChannelBuilder.java | 6 ++ ...llCredentialsApplyingTransportFactory.java | 8 +- .../grpc/internal/ClientTransportFactory.java | 22 +++++ .../io/grpc/internal/ManagedChannelImpl.java | 97 ++++++++++++++----- .../internal/ManagedChannelImplBuilder.java | 13 ++- .../InProcessChannelBuilderTest.java | 10 ++ .../grpc/internal/ManagedChannelImplTest.java | 89 ++++++++++++++++- .../io/grpc/cronet/CronetChannelBuilder.java | 6 ++ .../io/grpc/netty/NettyChannelBuilder.java | 40 +++++--- .../io/grpc/netty/NettyChannelProvider.java | 4 +- .../grpc/netty/NettyChannelBuilderTest.java | 19 ++++ .../io/grpc/okhttp/OkHttpChannelBuilder.java | 46 +++++++-- .../io/grpc/okhttp/OkHttpChannelProvider.java | 2 +- .../grpc/okhttp/OkHttpChannelBuilderTest.java | 15 +++ 14 files changed, 328 insertions(+), 49 deletions(-) diff --git a/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java b/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java index 63069f2ea7..8d285897fc 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java @@ -19,6 +19,7 @@ package io.grpc.inprocess; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.ExperimentalApi; import io.grpc.Internal; @@ -246,6 +247,11 @@ public final class InProcessChannelBuilder extends return timerService; } + @Override + public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) { + return null; + } + @Override public void close() { if (closed) { diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index de96a3306b..0b1ce3514a 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -20,9 +20,10 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkNotNull; import io.grpc.Attributes; -import io.grpc.CallCredentials.RequestInfo; import io.grpc.CallCredentials; +import io.grpc.CallCredentials.RequestInfo; import io.grpc.CallOptions; +import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.CompositeCallCredentials; import io.grpc.Metadata; @@ -61,6 +62,11 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa return delegate.getScheduledExecutorService(); } + @Override + public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) { + throw new UnsupportedOperationException(); + } + @Override public void close() { delegate.close(); diff --git a/core/src/main/java/io/grpc/internal/ClientTransportFactory.java b/core/src/main/java/io/grpc/internal/ClientTransportFactory.java index 9be5431144..4d2ee92a0a 100644 --- a/core/src/main/java/io/grpc/internal/ClientTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/ClientTransportFactory.java @@ -19,11 +19,14 @@ package io.grpc.internal; import com.google.common.base.Objects; import com.google.common.base.Preconditions; import io.grpc.Attributes; +import io.grpc.CallCredentials; +import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.HttpConnectProxiedSocketAddress; import java.io.Closeable; import java.net.SocketAddress; import java.util.concurrent.ScheduledExecutorService; +import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; /** Pre-configured factory for creating {@link ConnectionClientTransport} instances. */ @@ -53,6 +56,14 @@ public interface ClientTransportFactory extends Closeable { */ ScheduledExecutorService getScheduledExecutorService(); + /** + * Swaps to a new ChannelCredentials with all other settings unchanged. Returns null if the + * ChannelCredentials is not supported by the current ClientTransportFactory settings. + */ + @CheckReturnValue + @Nullable + SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds); + /** * Releases any resources. * @@ -143,4 +154,15 @@ public interface ClientTransportFactory extends Closeable { && Objects.equal(this.connectProxiedSocketAddr, that.connectProxiedSocketAddr); } } + + final class SwapChannelCredentialsResult { + final ClientTransportFactory transportFactory; + @Nullable final CallCredentials callCredentials; + + public SwapChannelCredentialsResult( + ClientTransportFactory transportFactory, @Nullable CallCredentials callCredentials) { + this.transportFactory = Preconditions.checkNotNull(transportFactory, "transportFactory"); + this.callCredentials = callCredentials; + } + } } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 5bba60220e..5f6ef46769 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -30,8 +30,10 @@ import com.google.common.base.Supplier; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; +import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ClientCall; @@ -46,6 +48,7 @@ import io.grpc.DecompressorRegistry; import io.grpc.EquivalentAddressGroup; import io.grpc.ForwardingChannelBuilder; import io.grpc.ForwardingClientCall; +import io.grpc.Grpc; import io.grpc.InternalChannelz; import io.grpc.InternalChannelz.ChannelStats; import io.grpc.InternalChannelz.ChannelTrace; @@ -74,8 +77,9 @@ import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.AutoConfiguredLoadBalancerFactory.AutoConfiguredLoadBalancer; import io.grpc.internal.ClientCallImpl.ClientStreamProvider; +import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; +import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider; -import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder; import io.grpc.internal.ManagedChannelServiceConfig.MethodInfo; import io.grpc.internal.ManagedChannelServiceConfig.ServiceConfigConvertedSelector; import io.grpc.internal.RetriableStream.ChannelBufferMeter; @@ -153,6 +157,8 @@ final class ManagedChannelImpl extends ManagedChannel implements private final NameResolver.Args nameResolverArgs; private final AutoConfiguredLoadBalancerFactory loadBalancerFactory; private final ClientTransportFactory originalTransportFactory; + @Nullable + private final ChannelCredentials originalChannelCreds; private final ClientTransportFactory transportFactory; private final ClientTransportFactory oobTransportFactory; private final RestrictedScheduledExecutor scheduledExecutor; @@ -593,6 +599,7 @@ final class ManagedChannelImpl extends ManagedChannel implements this.timeProvider = checkNotNull(timeProvider, "timeProvider"); this.executorPool = checkNotNull(builder.executorPool, "executorPool"); this.executor = checkNotNull(executorPool.getObject(), "executor"); + this.originalChannelCreds = builder.channelCredentials; this.originalTransportFactory = clientTransportFactory; this.transportFactory = new CallCredentialsApplyingTransportFactory( clientTransportFactory, builder.callCredentials, this.executor); @@ -1516,50 +1523,82 @@ final class ManagedChannelImpl extends ManagedChannel implements @Override public ManagedChannelBuilder createResolvingOobChannelBuilder(String target) { + return createResolvingOobChannelBuilder(target, new DefaultChannelCreds()); + } + + // TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated + // TODO(zdapeng) register the channel as a subchannel of the parent channel in channelz. + @Override + public ManagedChannelBuilder createResolvingOobChannelBuilder( + final String target, final ChannelCredentials channelCreds) { + checkNotNull(channelCreds, "channelCreds"); + final class ResolvingOobChannelBuilder extends ForwardingChannelBuilder { - private final ManagedChannelImplBuilder managedChannelImplBuilder; + final ManagedChannelBuilder delegate; - ResolvingOobChannelBuilder(String target) { - managedChannelImplBuilder = new ManagedChannelImplBuilder(target, - new UnsupportedClientTransportFactoryBuilder(), + ResolvingOobChannelBuilder() { + final ClientTransportFactory transportFactory; + CallCredentials callCredentials; + if (channelCreds instanceof DefaultChannelCreds) { + transportFactory = originalTransportFactory; + callCredentials = null; + } else { + SwapChannelCredentialsResult swapResult = + originalTransportFactory.swapChannelCredentials(channelCreds); + if (swapResult == null) { + delegate = Grpc.newChannelBuilder(target, channelCreds); + return; + } else { + transportFactory = swapResult.transportFactory; + callCredentials = swapResult.callCredentials; + } + } + ClientTransportFactoryBuilder transportFactoryBuilder = + new ClientTransportFactoryBuilder() { + @Override + public ClientTransportFactory buildClientTransportFactory() { + return transportFactory; + } + }; + delegate = new ManagedChannelImplBuilder( + target, + channelCreds, + callCredentials, + transportFactoryBuilder, new FixedPortProvider(nameResolverArgs.getDefaultPort())); - managedChannelImplBuilder.executorPool = executorPool; - managedChannelImplBuilder.offloadExecutorPool = offloadExecutorHolder.pool; } @Override protected ManagedChannelBuilder delegate() { - return managedChannelImplBuilder; - } - - @Override - public ManagedChannel build() { - // TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated - return new ManagedChannelImpl( - managedChannelImplBuilder, - originalTransportFactory, - backoffPolicyProvider, - balancerRpcExecutorPool, - stopwatchSupplier, - Collections.emptyList(), - timeProvider); + return delegate; } } checkState(!terminated, "Channel is terminated"); @SuppressWarnings("deprecation") - ResolvingOobChannelBuilder builder = new ResolvingOobChannelBuilder(target) + ResolvingOobChannelBuilder builder = new ResolvingOobChannelBuilder() .nameResolverFactory(nameResolverFactory); return builder - .overrideAuthority(getAuthority()) + .overrideAuthority(ManagedChannelImpl.this.authority()) + // TODO(zdapeng): executors should not outlive the parent channel. + .executor(executor) + .offloadExecutor(offloadExecutorHolder.getExecutor()) .maxTraceEvents(maxTraceEvents) .proxyDetector(nameResolverArgs.getProxyDetector()) .userAgent(userAgent); } + @Override + public ChannelCredentials getUnsafeChannelCredentials() { + if (originalChannelCreds == null) { + return new DefaultChannelCreds(); + } + return originalChannelCreds; + } + @Override public void updateOobChannelAddresses(ManagedChannel channel, EquivalentAddressGroup eag) { checkArgument(channel instanceof OobChannel, @@ -1596,6 +1635,18 @@ final class ManagedChannelImpl extends ManagedChannel implements public NameResolverRegistry getNameResolverRegistry() { return nameResolverRegistry; } + + /** + * A placeholder for channel creds if user did not specify channel creds for the channel. + */ + // TODO(zdapeng): get rid of this class and let all ChannelBuilders always provide a non-null + // channel creds. + final class DefaultChannelCreds extends ChannelCredentials { + @Override + public ChannelCredentials withoutBearerTokens() { + return this; + } + } } private final class NameResolverListener extends NameResolver.Listener2 { diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index e7806f56ec..13f3672d43 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -24,6 +24,7 @@ import com.google.common.util.concurrent.MoreExecutors; import io.grpc.Attributes; import io.grpc.BinaryLog; import io.grpc.CallCredentials; +import io.grpc.ChannelCredentials; import io.grpc.ClientInterceptor; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; @@ -111,6 +112,8 @@ public final class ManagedChannelImplBuilder final String target; @Nullable + final ChannelCredentials channelCredentials; + @Nullable final CallCredentials callCredentials; @Nullable @@ -225,18 +228,23 @@ public final class ManagedChannelImplBuilder public ManagedChannelImplBuilder(String target, ClientTransportFactoryBuilder clientTransportFactoryBuilder, @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { - this(target, null, clientTransportFactoryBuilder, channelBuilderDefaultPortProvider); + this(target, null, null, clientTransportFactoryBuilder, channelBuilderDefaultPortProvider); } /** * Creates a new managed channel builder with a target string, which can be either a valid {@link * io.grpc.NameResolver}-compliant URI, or an authority string. Transport implementors must * provide client transport factory builder, and may set custom channel default port provider. + * + * @param channelCreds The ChannelCredentials provided by the user. These may be used when + * creating derivative channels. */ - public ManagedChannelImplBuilder(String target, @Nullable CallCredentials callCreds, + public ManagedChannelImplBuilder( + String target, @Nullable ChannelCredentials channelCreds, @Nullable CallCredentials callCreds, ClientTransportFactoryBuilder clientTransportFactoryBuilder, @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { this.target = Preconditions.checkNotNull(target, "target"); + this.channelCredentials = channelCreds; this.callCredentials = callCreds; this.clientTransportFactoryBuilder = Preconditions .checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder"); @@ -273,6 +281,7 @@ public final class ManagedChannelImplBuilder ClientTransportFactoryBuilder clientTransportFactoryBuilder, @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { this.target = makeTargetStringForDirectAddress(directServerAddress); + this.channelCredentials = null; this.callCredentials = null; this.clientTransportFactoryBuilder = Preconditions .checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder"); diff --git a/core/src/test/java/io/grpc/inprocess/InProcessChannelBuilderTest.java b/core/src/test/java/io/grpc/inprocess/InProcessChannelBuilderTest.java index 58efb7fa26..714d19cc74 100644 --- a/core/src/test/java/io/grpc/inprocess/InProcessChannelBuilderTest.java +++ b/core/src/test/java/io/grpc/inprocess/InProcessChannelBuilderTest.java @@ -16,9 +16,12 @@ package io.grpc.inprocess; +import static com.google.common.truth.Truth.assertThat; import static io.grpc.internal.GrpcUtil.TIMER_SERVICE; import static org.junit.Assert.assertSame; +import static org.mockito.Mockito.mock; +import io.grpc.ChannelCredentials; import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.FakeClock; import io.grpc.internal.SharedResourceHolder; @@ -60,4 +63,11 @@ public class InProcessChannelBuilderTest { clientTransportFactory.close(); } + + @Test + public void transportFactoryDoesNotSupportSwapChannelCreds() { + InProcessChannelBuilder builder = InProcessChannelBuilder.forName("foo"); + ClientTransportFactory transportFactory = builder.buildTransportFactory(); + assertThat(transportFactory.swapChannelCredentials(mock(ChannelCredentials.class))).isNull(); + } } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index a57e19759e..99bdd1244f 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -63,15 +63,18 @@ import io.grpc.CallCredentials; import io.grpc.CallCredentials.RequestInfo; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; +import io.grpc.CompositeChannelCredentials; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Context; import io.grpc.EquivalentAddressGroup; +import io.grpc.InsecureChannelCredentials; import io.grpc.IntegerMarshaller; import io.grpc.InternalChannelz; import io.grpc.InternalChannelz.ChannelStats; @@ -105,6 +108,7 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StringMarshaller; import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; +import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.internal.InternalSubchannel.TransportLogger; import io.grpc.internal.ManagedChannelImpl.ScParser; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; @@ -1662,7 +1666,8 @@ public class ManagedChannelImplTest { Metadata.Key metadataKey = Metadata.Key.of("token", Metadata.ASCII_STRING_MARSHALLER); String channelCredValue = "channel-provided call cred"; - channelBuilder = new ManagedChannelImplBuilder(TARGET, + channelBuilder = new ManagedChannelImplBuilder( + TARGET, InsecureChannelCredentials.create(), new FakeCallCredentials(metadataKey, channelCredValue), new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); configureBuilder(channelBuilder); @@ -1733,11 +1738,91 @@ public class ManagedChannelImplTest { call = oob.newCall(method, callOptions); call.start(mockCallListener2, headers); - verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + // CallOptions may contain StreamTracerFactory for census that is added by default. + verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class)); assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); oob.shutdownNow(); } + @Test + public void oobChannelWithOobChannelCredsHasChannelCallCredentials() { + Metadata.Key metadataKey = + Metadata.Key.of("token", Metadata.ASCII_STRING_MARSHALLER); + String channelCredValue = "channel-provided call cred"; + when(mockTransportFactory.swapChannelCredentials(any(CompositeChannelCredentials.class))) + .thenAnswer(new Answer() { + @Override + public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { + CompositeChannelCredentials c = + invocation.getArgument(0, CompositeChannelCredentials.class); + return new SwapChannelCredentialsResult(mockTransportFactory, c.getCallCredentials()); + } + }); + channelBuilder = new ManagedChannelImplBuilder( + TARGET, InsecureChannelCredentials.create(), + new FakeCallCredentials(metadataKey, channelCredValue), + new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); + configureBuilder(channelBuilder); + createChannel(); + + // Verify that the normal channel has call creds, to validate configuration + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); + requestConnectionSafely(helper, subchannel); + MockClientTransportInfo transportInfo = transports.poll(); + transportInfo.listener.transportReady(); + when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( + PickResult.withSubchannel(subchannel)); + updateBalancingStateSafely(helper, READY, mockPicker); + + String callCredValue = "per-RPC call cred"; + CallOptions callOptions = CallOptions.DEFAULT + .withCallCredentials(new FakeCallCredentials(metadataKey, callCredValue)); + Metadata headers = new Metadata(); + ClientCall call = channel.newCall(method, callOptions); + call.start(mockCallListener, headers); + + verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + assertThat(headers.getAll(metadataKey)) + .containsExactly(channelCredValue, callCredValue).inOrder(); + + // Verify that resolving oob channel with oob channel creds provides call creds + String oobChannelCredValue = "oob-channel-provided call cred"; + ChannelCredentials oobChannelCreds = CompositeChannelCredentials.create( + InsecureChannelCredentials.create(), + new FakeCallCredentials(metadataKey, oobChannelCredValue)); + ManagedChannel oob = helper.createResolvingOobChannelBuilder("oobauthority", oobChannelCreds) + .nameResolverFactory( + new FakeNameResolverFactory.Builder(URI.create("oobauthority")).build()) + .defaultLoadBalancingPolicy(MOCK_POLICY_NAME) + .idleTimeout(ManagedChannelImplBuilder.IDLE_MODE_MAX_TIMEOUT_DAYS, TimeUnit.DAYS) + .build(); + oob.getState(true); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); + verify(mockLoadBalancerProvider, times(2)).newLoadBalancer(helperCaptor.capture()); + Helper oobHelper = helperCaptor.getValue(); + + subchannel = + createSubchannelSafely(oobHelper, addressGroup, Attributes.EMPTY, subchannelStateListener); + requestConnectionSafely(oobHelper, subchannel); + transportInfo = transports.poll(); + transportInfo.listener.transportReady(); + SubchannelPicker mockPicker2 = mock(SubchannelPicker.class); + when(mockPicker2.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( + PickResult.withSubchannel(subchannel)); + updateBalancingStateSafely(oobHelper, READY, mockPicker2); + + headers = new Metadata(); + call = oob.newCall(method, callOptions); + call.start(mockCallListener2, headers); + + // CallOptions may contain StreamTracerFactory for census that is added by default. + verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class)); + assertThat(headers.getAll(metadataKey)) + .containsExactly(oobChannelCredValue, callCredValue).inOrder(); + oob.shutdownNow(); + } + @Test public void oobChannelsWhenChannelShutdownNow() { createChannel(); diff --git a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java index 6818242133..217928ae94 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java @@ -24,6 +24,7 @@ import android.util.Log; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.ExperimentalApi; import io.grpc.Internal; @@ -269,6 +270,11 @@ public final class CronetChannelBuilder return timeoutService; } + @Override + public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) { + return null; + } + @Override public void close() { if (usingSharedScheduler) { diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 0160a0a03b..bb203e2906 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -46,6 +46,7 @@ import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourcePool; import io.grpc.internal.TransportTracer; +import io.grpc.netty.ProtocolNegotiators.FromChannelCredentialsResult; import io.netty.channel.Channel; import io.netty.channel.ChannelFactory; import io.netty.channel.ChannelOption; @@ -157,12 +158,11 @@ public final class NettyChannelBuilder extends */ @CheckReturnValue public static NettyChannelBuilder forTarget(String target, ChannelCredentials creds) { - ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(creds); + FromChannelCredentialsResult result = ProtocolNegotiators.from(creds); if (result.error != null) { throw new IllegalArgumentException(result.error); } - return new NettyChannelBuilder( - target, result.negotiator, result.callCredentials); + return new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator); } private final class NettyChannelTransportFactoryBuilder implements ClientTransportFactoryBuilder { @@ -187,11 +187,11 @@ public final class NettyChannelBuilder extends this.freezeProtocolNegotiatorFactory = false; } - @CheckReturnValue NettyChannelBuilder( - String target, ProtocolNegotiator.ClientFactory negotiator, - @Nullable CallCredentials callCreds) { - managedChannelImplBuilder = new ManagedChannelImplBuilder(target, callCreds, + String target, ChannelCredentials channelCreds, CallCredentials callCreds, + ProtocolNegotiator.ClientFactory negotiator) { + managedChannelImplBuilder = new ManagedChannelImplBuilder( + target, channelCreds, callCreds, new NettyChannelTransportFactoryBuilder(), new NettyChannelDefaultPortProvider()); this.protocolNegotiatorFactory = checkNotNull(negotiator, "negotiator"); @@ -628,7 +628,8 @@ public final class NettyChannelBuilder extends private final int flowControlWindow; private final int maxMessageSize; private final int maxHeaderListSize; - private final AtomicBackoff keepAliveTimeNanos; + private final long keepAliveTimeNanos; + private final AtomicBackoff keepAliveBackoff; private final long keepAliveTimeoutNanos; private final boolean keepAliveWithoutCalls; private final TransportTracer.Factory transportTracerFactory; @@ -637,7 +638,8 @@ public final class NettyChannelBuilder extends private boolean closed; - NettyTransportFactory(ProtocolNegotiator protocolNegotiator, + NettyTransportFactory( + ProtocolNegotiator protocolNegotiator, ChannelFactory channelFactory, Map, ?> channelOptions, ObjectPool groupPool, boolean autoFlowControl, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, @@ -653,7 +655,8 @@ public final class NettyChannelBuilder extends this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; - this.keepAliveTimeNanos = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos); + this.keepAliveTimeNanos = keepAliveTimeNanos; + this.keepAliveBackoff = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos); this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; this.keepAliveWithoutCalls = keepAliveWithoutCalls; this.transportTracerFactory = transportTracerFactory; @@ -678,7 +681,7 @@ public final class NettyChannelBuilder extends protocolNegotiator); } - final AtomicBackoff.State keepAliveTimeNanosState = keepAliveTimeNanos.getState(); + final AtomicBackoff.State keepAliveTimeNanosState = keepAliveBackoff.getState(); Runnable tooManyPingsRunnable = new Runnable() { @Override public void run() { @@ -702,6 +705,21 @@ public final class NettyChannelBuilder extends return group; } + @Override + public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) { + checkNotNull(channelCreds, "channelCreds"); + FromChannelCredentialsResult result = ProtocolNegotiators.from(channelCreds); + if (result.error != null) { + return null; + } + ClientTransportFactory factory = new NettyTransportFactory( + result.negotiator.newNegotiator(), channelFactory, channelOptions, groupPool, + autoFlowControl, flowControlWindow, maxMessageSize, maxHeaderListSize, keepAliveTimeNanos, + keepAliveTimeoutNanos, keepAliveWithoutCalls, transportTracerFactory, localSocketPicker, + useGetForSafeMethods); + return new SwapChannelCredentialsResult(factory, result.callCredentials); + } + @Override public void close() { if (closed) { diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java index 6bed584053..bf3df4fa6a 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java @@ -49,7 +49,7 @@ public final class NettyChannelProvider extends ManagedChannelProvider { if (result.error != null) { return NewChannelBuilderResult.error(result.error); } - return NewChannelBuilderResult.channelBuilder(new NettyChannelBuilder( - target, result.negotiator, result.callCredentials)); + return NewChannelBuilderResult.channelBuilder( + new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator)); } } diff --git a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java index a4184eb47d..b255923abc 100644 --- a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java @@ -16,13 +16,18 @@ package io.grpc.netty; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; +import io.grpc.ChannelCredentials; import io.grpc.ManagedChannel; +import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.netty.NettyTestUtil.TrackingObjectPoolForTest; +import io.grpc.netty.ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory; import io.netty.channel.Channel; import io.netty.channel.ChannelFactory; import io.netty.channel.EventLoopGroup; @@ -282,4 +287,18 @@ public class NettyChannelBuilderTest { builder.assertEventLoopAndChannelType(); } + + @Test + public void transportFactorySupportsNettyChannelCreds() { + NettyChannelBuilder builder = NettyChannelBuilder.forTarget("foo"); + ClientTransportFactory transportFactory = builder.buildTransportFactory(); + + SwapChannelCredentialsResult result = transportFactory.swapChannelCredentials( + mock(ChannelCredentials.class)); + assertThat(result).isNull(); + + result = transportFactory.swapChannelCredentials( + NettyChannelCredentials.create(new PlaintextProtocolNegotiatorClientFactory())); + assertThat(result).isNotNull(); + } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index 6fed60d769..d003f735d2 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -59,6 +59,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; @@ -145,7 +146,7 @@ public final class OkHttpChannelBuilder extends if (result.error != null) { throw new IllegalArgumentException(result.error); } - return new OkHttpChannelBuilder(target, result.factory, result.callCredentials); + return new OkHttpChannelBuilder(target, creds, result.callCredentials, result.factory); } private Executor transportExecutor; @@ -181,9 +182,11 @@ public final class OkHttpChannelBuilder extends this.freezeSecurityConfiguration = false; } - OkHttpChannelBuilder(String target, @Nullable SSLSocketFactory factory, - @Nullable CallCredentials callCredentials) { - managedChannelImplBuilder = new ManagedChannelImplBuilder(target, callCredentials, + OkHttpChannelBuilder( + String target, ChannelCredentials channelCreds, CallCredentials callCreds, + SSLSocketFactory factory) { + managedChannelImplBuilder = new ManagedChannelImplBuilder( + target, channelCreds, callCreds, new OkHttpChannelTransportFactoryBuilder(), new OkHttpChannelDefaultPortProvider()); this.sslSocketFactory = factory; @@ -631,7 +634,8 @@ public final class OkHttpChannelBuilder extends private final ConnectionSpec connectionSpec; private final int maxMessageSize; private final boolean enableKeepAlive; - private final AtomicBackoff keepAliveTimeNanos; + private final long keepAliveTimeNanos; + private final AtomicBackoff keepAliveBackoff; private final long keepAliveTimeoutNanos; private final int flowControlWindow; private final boolean keepAliveWithoutCalls; @@ -665,7 +669,8 @@ public final class OkHttpChannelBuilder extends this.connectionSpec = connectionSpec; this.maxMessageSize = maxMessageSize; this.enableKeepAlive = enableKeepAlive; - this.keepAliveTimeNanos = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos); + this.keepAliveTimeNanos = keepAliveTimeNanos; + this.keepAliveBackoff = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos); this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; this.flowControlWindow = flowControlWindow; this.keepAliveWithoutCalls = keepAliveWithoutCalls; @@ -689,7 +694,7 @@ public final class OkHttpChannelBuilder extends if (closed) { throw new IllegalStateException("The transport factory is closed."); } - final AtomicBackoff.State keepAliveTimeNanosState = keepAliveTimeNanos.getState(); + final AtomicBackoff.State keepAliveTimeNanosState = keepAliveBackoff.getState(); Runnable tooManyPingsRunnable = new Runnable() { @Override public void run() { @@ -727,6 +732,33 @@ public final class OkHttpChannelBuilder extends return timeoutService; } + @Nullable + @CheckReturnValue + @Override + public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) { + SslSocketFactoryResult result = sslSocketFactoryFrom(channelCreds); + if (result.error != null) { + return null; + } + ClientTransportFactory factory = new OkHttpTransportFactory( + executor, + timeoutService, + socketFactory, + result.factory, + hostnameVerifier, + connectionSpec, + maxMessageSize, + enableKeepAlive, + keepAliveTimeNanos, + keepAliveTimeoutNanos, + flowControlWindow, + keepAliveWithoutCalls, + maxInboundMetadataSize, + transportTracerFactory, + useGetForSafeMethods); + return new SwapChannelCredentialsResult(factory, result.callCredentials); + } + @Override public void close() { if (closed) { diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java index f8caaea512..19f99d0502 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java @@ -55,6 +55,6 @@ public final class OkHttpChannelProvider extends ManagedChannelProvider { return NewChannelBuilderResult.error(result.error); } return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder( - target, result.factory, result.callCredentials)); + target, creds, result.callCredentials, result.factory)); } } diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java index a2ef6911db..23d140ec9b 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java @@ -34,6 +34,7 @@ import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.TlsChannelCredentials; import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourceHolder; @@ -358,6 +359,20 @@ public class OkHttpChannelBuilderTest { transportFactory.close(); } + @Test + public void transportFactorySupportsOkHttpChannelCreds() { + OkHttpChannelBuilder builder = OkHttpChannelBuilder.forTarget("foo"); + ClientTransportFactory transportFactory = builder.buildTransportFactory(); + + SwapChannelCredentialsResult result = transportFactory.swapChannelCredentials( + mock(ChannelCredentials.class)); + assertThat(result).isNull(); + + result = transportFactory.swapChannelCredentials( + SslSocketFactoryChannelCredentials.create(mock(SSLSocketFactory.class))); + assertThat(result).isNotNull(); + } + private static final class FakeChannelLogger extends ChannelLogger { @Override