all: implement Helper.createResolvingOobChannelBuilder(target, creds)

- Add APIs to `ClientTransportFactory`:
```java
public interface ClientTransportFactory {
  /**
   * Swaps to a new ChannelCredentials with all other settings unchanged. Returns null if the
   * ChannelCredentials is not supported by the current ClientTransportFactory settings.
   */
  SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds);

  final class SwapChannelCredentialsResult {
    final ClientTransportFactory transportFactory;
    @Nullable final CallCredentials callCredentials;
  }
}
```

- Add `ChannelCredentials` to constructor args of `ManagedChannelImplBuilder`:
 ```java
public ManagedChannelImplBuilder(
      String target, @Nullable ChannelCredentials channelCreds, @Nullable CallCredentials callCreds, ...)
  ```
This commit is contained in:
ZHANG Dapeng 2021-01-28 09:49:53 -08:00 committed by GitHub
parent a6df2b2ff4
commit 45a151810c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 328 additions and 49 deletions

View File

@ -19,6 +19,7 @@ package io.grpc.inprocess;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
import io.grpc.ExperimentalApi; import io.grpc.ExperimentalApi;
import io.grpc.Internal; import io.grpc.Internal;
@ -246,6 +247,11 @@ public final class InProcessChannelBuilder extends
return timerService; return timerService;
} }
@Override
public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) {
return null;
}
@Override @Override
public void close() { public void close() {
if (closed) { if (closed) {

View File

@ -20,9 +20,10 @@ import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.CallCredentials.RequestInfo;
import io.grpc.CallCredentials; import io.grpc.CallCredentials;
import io.grpc.CallCredentials.RequestInfo;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
import io.grpc.CompositeCallCredentials; import io.grpc.CompositeCallCredentials;
import io.grpc.Metadata; import io.grpc.Metadata;
@ -61,6 +62,11 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa
return delegate.getScheduledExecutorService(); return delegate.getScheduledExecutorService();
} }
@Override
public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) {
throw new UnsupportedOperationException();
}
@Override @Override
public void close() { public void close() {
delegate.close(); delegate.close();

View File

@ -19,11 +19,14 @@ package io.grpc.internal;
import com.google.common.base.Objects; import com.google.common.base.Objects;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.HttpConnectProxiedSocketAddress;
import java.io.Closeable; import java.io.Closeable;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable; import javax.annotation.Nullable;
/** Pre-configured factory for creating {@link ConnectionClientTransport} instances. */ /** Pre-configured factory for creating {@link ConnectionClientTransport} instances. */
@ -53,6 +56,14 @@ public interface ClientTransportFactory extends Closeable {
*/ */
ScheduledExecutorService getScheduledExecutorService(); ScheduledExecutorService getScheduledExecutorService();
/**
* Swaps to a new ChannelCredentials with all other settings unchanged. Returns null if the
* ChannelCredentials is not supported by the current ClientTransportFactory settings.
*/
@CheckReturnValue
@Nullable
SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds);
/** /**
* Releases any resources. * Releases any resources.
* *
@ -143,4 +154,15 @@ public interface ClientTransportFactory extends Closeable {
&& Objects.equal(this.connectProxiedSocketAddr, that.connectProxiedSocketAddr); && Objects.equal(this.connectProxiedSocketAddr, that.connectProxiedSocketAddr);
} }
} }
final class SwapChannelCredentialsResult {
final ClientTransportFactory transportFactory;
@Nullable final CallCredentials callCredentials;
public SwapChannelCredentialsResult(
ClientTransportFactory transportFactory, @Nullable CallCredentials callCredentials) {
this.transportFactory = Preconditions.checkNotNull(transportFactory, "transportFactory");
this.callCredentials = callCredentials;
}
}
} }

View File

@ -30,8 +30,10 @@ import com.google.common.base.Supplier;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.CallCredentials;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.Channel; import io.grpc.Channel;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.ClientCall; import io.grpc.ClientCall;
@ -46,6 +48,7 @@ import io.grpc.DecompressorRegistry;
import io.grpc.EquivalentAddressGroup; import io.grpc.EquivalentAddressGroup;
import io.grpc.ForwardingChannelBuilder; import io.grpc.ForwardingChannelBuilder;
import io.grpc.ForwardingClientCall; import io.grpc.ForwardingClientCall;
import io.grpc.Grpc;
import io.grpc.InternalChannelz; import io.grpc.InternalChannelz;
import io.grpc.InternalChannelz.ChannelStats; import io.grpc.InternalChannelz.ChannelStats;
import io.grpc.InternalChannelz.ChannelTrace; import io.grpc.InternalChannelz.ChannelTrace;
@ -74,8 +77,9 @@ import io.grpc.SynchronizationContext;
import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.SynchronizationContext.ScheduledHandle;
import io.grpc.internal.AutoConfiguredLoadBalancerFactory.AutoConfiguredLoadBalancer; import io.grpc.internal.AutoConfiguredLoadBalancerFactory.AutoConfiguredLoadBalancer;
import io.grpc.internal.ClientCallImpl.ClientStreamProvider; import io.grpc.internal.ClientCallImpl.ClientStreamProvider;
import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder;
import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider; import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider;
import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder;
import io.grpc.internal.ManagedChannelServiceConfig.MethodInfo; import io.grpc.internal.ManagedChannelServiceConfig.MethodInfo;
import io.grpc.internal.ManagedChannelServiceConfig.ServiceConfigConvertedSelector; import io.grpc.internal.ManagedChannelServiceConfig.ServiceConfigConvertedSelector;
import io.grpc.internal.RetriableStream.ChannelBufferMeter; import io.grpc.internal.RetriableStream.ChannelBufferMeter;
@ -153,6 +157,8 @@ final class ManagedChannelImpl extends ManagedChannel implements
private final NameResolver.Args nameResolverArgs; private final NameResolver.Args nameResolverArgs;
private final AutoConfiguredLoadBalancerFactory loadBalancerFactory; private final AutoConfiguredLoadBalancerFactory loadBalancerFactory;
private final ClientTransportFactory originalTransportFactory; private final ClientTransportFactory originalTransportFactory;
@Nullable
private final ChannelCredentials originalChannelCreds;
private final ClientTransportFactory transportFactory; private final ClientTransportFactory transportFactory;
private final ClientTransportFactory oobTransportFactory; private final ClientTransportFactory oobTransportFactory;
private final RestrictedScheduledExecutor scheduledExecutor; private final RestrictedScheduledExecutor scheduledExecutor;
@ -593,6 +599,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
this.timeProvider = checkNotNull(timeProvider, "timeProvider"); this.timeProvider = checkNotNull(timeProvider, "timeProvider");
this.executorPool = checkNotNull(builder.executorPool, "executorPool"); this.executorPool = checkNotNull(builder.executorPool, "executorPool");
this.executor = checkNotNull(executorPool.getObject(), "executor"); this.executor = checkNotNull(executorPool.getObject(), "executor");
this.originalChannelCreds = builder.channelCredentials;
this.originalTransportFactory = clientTransportFactory; this.originalTransportFactory = clientTransportFactory;
this.transportFactory = new CallCredentialsApplyingTransportFactory( this.transportFactory = new CallCredentialsApplyingTransportFactory(
clientTransportFactory, builder.callCredentials, this.executor); clientTransportFactory, builder.callCredentials, this.executor);
@ -1516,50 +1523,82 @@ final class ManagedChannelImpl extends ManagedChannel implements
@Override @Override
public ManagedChannelBuilder<?> createResolvingOobChannelBuilder(String target) { public ManagedChannelBuilder<?> createResolvingOobChannelBuilder(String target) {
return createResolvingOobChannelBuilder(target, new DefaultChannelCreds());
}
// TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated
// TODO(zdapeng) register the channel as a subchannel of the parent channel in channelz.
@Override
public ManagedChannelBuilder<?> createResolvingOobChannelBuilder(
final String target, final ChannelCredentials channelCreds) {
checkNotNull(channelCreds, "channelCreds");
final class ResolvingOobChannelBuilder final class ResolvingOobChannelBuilder
extends ForwardingChannelBuilder<ResolvingOobChannelBuilder> { extends ForwardingChannelBuilder<ResolvingOobChannelBuilder> {
private final ManagedChannelImplBuilder managedChannelImplBuilder; final ManagedChannelBuilder<?> delegate;
ResolvingOobChannelBuilder(String target) { ResolvingOobChannelBuilder() {
managedChannelImplBuilder = new ManagedChannelImplBuilder(target, final ClientTransportFactory transportFactory;
new UnsupportedClientTransportFactoryBuilder(), CallCredentials callCredentials;
if (channelCreds instanceof DefaultChannelCreds) {
transportFactory = originalTransportFactory;
callCredentials = null;
} else {
SwapChannelCredentialsResult swapResult =
originalTransportFactory.swapChannelCredentials(channelCreds);
if (swapResult == null) {
delegate = Grpc.newChannelBuilder(target, channelCreds);
return;
} else {
transportFactory = swapResult.transportFactory;
callCredentials = swapResult.callCredentials;
}
}
ClientTransportFactoryBuilder transportFactoryBuilder =
new ClientTransportFactoryBuilder() {
@Override
public ClientTransportFactory buildClientTransportFactory() {
return transportFactory;
}
};
delegate = new ManagedChannelImplBuilder(
target,
channelCreds,
callCredentials,
transportFactoryBuilder,
new FixedPortProvider(nameResolverArgs.getDefaultPort())); new FixedPortProvider(nameResolverArgs.getDefaultPort()));
managedChannelImplBuilder.executorPool = executorPool;
managedChannelImplBuilder.offloadExecutorPool = offloadExecutorHolder.pool;
} }
@Override @Override
protected ManagedChannelBuilder<?> delegate() { protected ManagedChannelBuilder<?> delegate() {
return managedChannelImplBuilder; return delegate;
}
@Override
public ManagedChannel build() {
// TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated
return new ManagedChannelImpl(
managedChannelImplBuilder,
originalTransportFactory,
backoffPolicyProvider,
balancerRpcExecutorPool,
stopwatchSupplier,
Collections.<ClientInterceptor>emptyList(),
timeProvider);
} }
} }
checkState(!terminated, "Channel is terminated"); checkState(!terminated, "Channel is terminated");
@SuppressWarnings("deprecation") @SuppressWarnings("deprecation")
ResolvingOobChannelBuilder builder = new ResolvingOobChannelBuilder(target) ResolvingOobChannelBuilder builder = new ResolvingOobChannelBuilder()
.nameResolverFactory(nameResolverFactory); .nameResolverFactory(nameResolverFactory);
return builder return builder
.overrideAuthority(getAuthority()) .overrideAuthority(ManagedChannelImpl.this.authority())
// TODO(zdapeng): executors should not outlive the parent channel.
.executor(executor)
.offloadExecutor(offloadExecutorHolder.getExecutor())
.maxTraceEvents(maxTraceEvents) .maxTraceEvents(maxTraceEvents)
.proxyDetector(nameResolverArgs.getProxyDetector()) .proxyDetector(nameResolverArgs.getProxyDetector())
.userAgent(userAgent); .userAgent(userAgent);
} }
@Override
public ChannelCredentials getUnsafeChannelCredentials() {
if (originalChannelCreds == null) {
return new DefaultChannelCreds();
}
return originalChannelCreds;
}
@Override @Override
public void updateOobChannelAddresses(ManagedChannel channel, EquivalentAddressGroup eag) { public void updateOobChannelAddresses(ManagedChannel channel, EquivalentAddressGroup eag) {
checkArgument(channel instanceof OobChannel, checkArgument(channel instanceof OobChannel,
@ -1596,6 +1635,18 @@ final class ManagedChannelImpl extends ManagedChannel implements
public NameResolverRegistry getNameResolverRegistry() { public NameResolverRegistry getNameResolverRegistry() {
return nameResolverRegistry; return nameResolverRegistry;
} }
/**
* A placeholder for channel creds if user did not specify channel creds for the channel.
*/
// TODO(zdapeng): get rid of this class and let all ChannelBuilders always provide a non-null
// channel creds.
final class DefaultChannelCreds extends ChannelCredentials {
@Override
public ChannelCredentials withoutBearerTokens() {
return this;
}
}
} }
private final class NameResolverListener extends NameResolver.Listener2 { private final class NameResolverListener extends NameResolver.Listener2 {

View File

@ -24,6 +24,7 @@ import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.BinaryLog; import io.grpc.BinaryLog;
import io.grpc.CallCredentials; import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptor;
import io.grpc.CompressorRegistry; import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry; import io.grpc.DecompressorRegistry;
@ -111,6 +112,8 @@ public final class ManagedChannelImplBuilder
final String target; final String target;
@Nullable @Nullable
final ChannelCredentials channelCredentials;
@Nullable
final CallCredentials callCredentials; final CallCredentials callCredentials;
@Nullable @Nullable
@ -225,18 +228,23 @@ public final class ManagedChannelImplBuilder
public ManagedChannelImplBuilder(String target, public ManagedChannelImplBuilder(String target,
ClientTransportFactoryBuilder clientTransportFactoryBuilder, ClientTransportFactoryBuilder clientTransportFactoryBuilder,
@Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) {
this(target, null, clientTransportFactoryBuilder, channelBuilderDefaultPortProvider); this(target, null, null, clientTransportFactoryBuilder, channelBuilderDefaultPortProvider);
} }
/** /**
* Creates a new managed channel builder with a target string, which can be either a valid {@link * Creates a new managed channel builder with a target string, which can be either a valid {@link
* io.grpc.NameResolver}-compliant URI, or an authority string. Transport implementors must * io.grpc.NameResolver}-compliant URI, or an authority string. Transport implementors must
* provide client transport factory builder, and may set custom channel default port provider. * provide client transport factory builder, and may set custom channel default port provider.
*
* @param channelCreds The ChannelCredentials provided by the user. These may be used when
* creating derivative channels.
*/ */
public ManagedChannelImplBuilder(String target, @Nullable CallCredentials callCreds, public ManagedChannelImplBuilder(
String target, @Nullable ChannelCredentials channelCreds, @Nullable CallCredentials callCreds,
ClientTransportFactoryBuilder clientTransportFactoryBuilder, ClientTransportFactoryBuilder clientTransportFactoryBuilder,
@Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) {
this.target = Preconditions.checkNotNull(target, "target"); this.target = Preconditions.checkNotNull(target, "target");
this.channelCredentials = channelCreds;
this.callCredentials = callCreds; this.callCredentials = callCreds;
this.clientTransportFactoryBuilder = Preconditions this.clientTransportFactoryBuilder = Preconditions
.checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder"); .checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder");
@ -273,6 +281,7 @@ public final class ManagedChannelImplBuilder
ClientTransportFactoryBuilder clientTransportFactoryBuilder, ClientTransportFactoryBuilder clientTransportFactoryBuilder,
@Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) {
this.target = makeTargetStringForDirectAddress(directServerAddress); this.target = makeTargetStringForDirectAddress(directServerAddress);
this.channelCredentials = null;
this.callCredentials = null; this.callCredentials = null;
this.clientTransportFactoryBuilder = Preconditions this.clientTransportFactoryBuilder = Preconditions
.checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder"); .checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder");

View File

@ -16,9 +16,12 @@
package io.grpc.inprocess; package io.grpc.inprocess;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.GrpcUtil.TIMER_SERVICE; import static io.grpc.internal.GrpcUtil.TIMER_SERVICE;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.mock;
import io.grpc.ChannelCredentials;
import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock;
import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder;
@ -60,4 +63,11 @@ public class InProcessChannelBuilderTest {
clientTransportFactory.close(); clientTransportFactory.close();
} }
@Test
public void transportFactoryDoesNotSupportSwapChannelCreds() {
InProcessChannelBuilder builder = InProcessChannelBuilder.forName("foo");
ClientTransportFactory transportFactory = builder.buildTransportFactory();
assertThat(transportFactory.swapChannelCredentials(mock(ChannelCredentials.class))).isNull();
}
} }

View File

@ -63,15 +63,18 @@ import io.grpc.CallCredentials;
import io.grpc.CallCredentials.RequestInfo; import io.grpc.CallCredentials.RequestInfo;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.Channel; import io.grpc.Channel;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
import io.grpc.ClientCall; import io.grpc.ClientCall;
import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptor;
import io.grpc.ClientInterceptors; import io.grpc.ClientInterceptors;
import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer;
import io.grpc.CompositeChannelCredentials;
import io.grpc.ConnectivityState; import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo; import io.grpc.ConnectivityStateInfo;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.EquivalentAddressGroup; import io.grpc.EquivalentAddressGroup;
import io.grpc.InsecureChannelCredentials;
import io.grpc.IntegerMarshaller; import io.grpc.IntegerMarshaller;
import io.grpc.InternalChannelz; import io.grpc.InternalChannelz;
import io.grpc.InternalChannelz.ChannelStats; import io.grpc.InternalChannelz.ChannelStats;
@ -105,6 +108,7 @@ import io.grpc.Status;
import io.grpc.Status.Code; import io.grpc.Status.Code;
import io.grpc.StringMarshaller; import io.grpc.StringMarshaller;
import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; import io.grpc.internal.ClientTransportFactory.ClientTransportOptions;
import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.internal.InternalSubchannel.TransportLogger; import io.grpc.internal.InternalSubchannel.TransportLogger;
import io.grpc.internal.ManagedChannelImpl.ScParser; import io.grpc.internal.ManagedChannelImpl.ScParser;
import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder;
@ -1662,7 +1666,8 @@ public class ManagedChannelImplTest {
Metadata.Key<String> metadataKey = Metadata.Key<String> metadataKey =
Metadata.Key.of("token", Metadata.ASCII_STRING_MARSHALLER); Metadata.Key.of("token", Metadata.ASCII_STRING_MARSHALLER);
String channelCredValue = "channel-provided call cred"; String channelCredValue = "channel-provided call cred";
channelBuilder = new ManagedChannelImplBuilder(TARGET, channelBuilder = new ManagedChannelImplBuilder(
TARGET, InsecureChannelCredentials.create(),
new FakeCallCredentials(metadataKey, channelCredValue), new FakeCallCredentials(metadataKey, channelCredValue),
new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT));
configureBuilder(channelBuilder); configureBuilder(channelBuilder);
@ -1733,11 +1738,91 @@ public class ManagedChannelImplTest {
call = oob.newCall(method, callOptions); call = oob.newCall(method, callOptions);
call.start(mockCallListener2, headers); call.start(mockCallListener2, headers);
verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); // CallOptions may contain StreamTracerFactory for census that is added by default.
verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class));
assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue);
oob.shutdownNow(); oob.shutdownNow();
} }
@Test
public void oobChannelWithOobChannelCredsHasChannelCallCredentials() {
Metadata.Key<String> metadataKey =
Metadata.Key.of("token", Metadata.ASCII_STRING_MARSHALLER);
String channelCredValue = "channel-provided call cred";
when(mockTransportFactory.swapChannelCredentials(any(CompositeChannelCredentials.class)))
.thenAnswer(new Answer<SwapChannelCredentialsResult>() {
@Override
public SwapChannelCredentialsResult answer(InvocationOnMock invocation) {
CompositeChannelCredentials c =
invocation.getArgument(0, CompositeChannelCredentials.class);
return new SwapChannelCredentialsResult(mockTransportFactory, c.getCallCredentials());
}
});
channelBuilder = new ManagedChannelImplBuilder(
TARGET, InsecureChannelCredentials.create(),
new FakeCallCredentials(metadataKey, channelCredValue),
new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT));
configureBuilder(channelBuilder);
createChannel();
// Verify that the normal channel has call creds, to validate configuration
Subchannel subchannel =
createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener);
requestConnectionSafely(helper, subchannel);
MockClientTransportInfo transportInfo = transports.poll();
transportInfo.listener.transportReady();
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(
PickResult.withSubchannel(subchannel));
updateBalancingStateSafely(helper, READY, mockPicker);
String callCredValue = "per-RPC call cred";
CallOptions callOptions = CallOptions.DEFAULT
.withCallCredentials(new FakeCallCredentials(metadataKey, callCredValue));
Metadata headers = new Metadata();
ClientCall<String, Integer> call = channel.newCall(method, callOptions);
call.start(mockCallListener, headers);
verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions));
assertThat(headers.getAll(metadataKey))
.containsExactly(channelCredValue, callCredValue).inOrder();
// Verify that resolving oob channel with oob channel creds provides call creds
String oobChannelCredValue = "oob-channel-provided call cred";
ChannelCredentials oobChannelCreds = CompositeChannelCredentials.create(
InsecureChannelCredentials.create(),
new FakeCallCredentials(metadataKey, oobChannelCredValue));
ManagedChannel oob = helper.createResolvingOobChannelBuilder("oobauthority", oobChannelCreds)
.nameResolverFactory(
new FakeNameResolverFactory.Builder(URI.create("oobauthority")).build())
.defaultLoadBalancingPolicy(MOCK_POLICY_NAME)
.idleTimeout(ManagedChannelImplBuilder.IDLE_MODE_MAX_TIMEOUT_DAYS, TimeUnit.DAYS)
.build();
oob.getState(true);
ArgumentCaptor<Helper> helperCaptor = ArgumentCaptor.forClass(Helper.class);
verify(mockLoadBalancerProvider, times(2)).newLoadBalancer(helperCaptor.capture());
Helper oobHelper = helperCaptor.getValue();
subchannel =
createSubchannelSafely(oobHelper, addressGroup, Attributes.EMPTY, subchannelStateListener);
requestConnectionSafely(oobHelper, subchannel);
transportInfo = transports.poll();
transportInfo.listener.transportReady();
SubchannelPicker mockPicker2 = mock(SubchannelPicker.class);
when(mockPicker2.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(
PickResult.withSubchannel(subchannel));
updateBalancingStateSafely(oobHelper, READY, mockPicker2);
headers = new Metadata();
call = oob.newCall(method, callOptions);
call.start(mockCallListener2, headers);
// CallOptions may contain StreamTracerFactory for census that is added by default.
verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class));
assertThat(headers.getAll(metadataKey))
.containsExactly(oobChannelCredValue, callCredValue).inOrder();
oob.shutdownNow();
}
@Test @Test
public void oobChannelsWhenChannelShutdownNow() { public void oobChannelsWhenChannelShutdownNow() {
createChannel(); createChannel();

View File

@ -24,6 +24,7 @@ import android.util.Log;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
import io.grpc.ExperimentalApi; import io.grpc.ExperimentalApi;
import io.grpc.Internal; import io.grpc.Internal;
@ -269,6 +270,11 @@ public final class CronetChannelBuilder
return timeoutService; return timeoutService;
} }
@Override
public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) {
return null;
}
@Override @Override
public void close() { public void close() {
if (usingSharedScheduler) { if (usingSharedScheduler) {

View File

@ -46,6 +46,7 @@ import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder;
import io.grpc.internal.ObjectPool; import io.grpc.internal.ObjectPool;
import io.grpc.internal.SharedResourcePool; import io.grpc.internal.SharedResourcePool;
import io.grpc.internal.TransportTracer; import io.grpc.internal.TransportTracer;
import io.grpc.netty.ProtocolNegotiators.FromChannelCredentialsResult;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFactory; import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOption;
@ -157,12 +158,11 @@ public final class NettyChannelBuilder extends
*/ */
@CheckReturnValue @CheckReturnValue
public static NettyChannelBuilder forTarget(String target, ChannelCredentials creds) { public static NettyChannelBuilder forTarget(String target, ChannelCredentials creds) {
ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(creds); FromChannelCredentialsResult result = ProtocolNegotiators.from(creds);
if (result.error != null) { if (result.error != null) {
throw new IllegalArgumentException(result.error); throw new IllegalArgumentException(result.error);
} }
return new NettyChannelBuilder( return new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator);
target, result.negotiator, result.callCredentials);
} }
private final class NettyChannelTransportFactoryBuilder implements ClientTransportFactoryBuilder { private final class NettyChannelTransportFactoryBuilder implements ClientTransportFactoryBuilder {
@ -187,11 +187,11 @@ public final class NettyChannelBuilder extends
this.freezeProtocolNegotiatorFactory = false; this.freezeProtocolNegotiatorFactory = false;
} }
@CheckReturnValue
NettyChannelBuilder( NettyChannelBuilder(
String target, ProtocolNegotiator.ClientFactory negotiator, String target, ChannelCredentials channelCreds, CallCredentials callCreds,
@Nullable CallCredentials callCreds) { ProtocolNegotiator.ClientFactory negotiator) {
managedChannelImplBuilder = new ManagedChannelImplBuilder(target, callCreds, managedChannelImplBuilder = new ManagedChannelImplBuilder(
target, channelCreds, callCreds,
new NettyChannelTransportFactoryBuilder(), new NettyChannelTransportFactoryBuilder(),
new NettyChannelDefaultPortProvider()); new NettyChannelDefaultPortProvider());
this.protocolNegotiatorFactory = checkNotNull(negotiator, "negotiator"); this.protocolNegotiatorFactory = checkNotNull(negotiator, "negotiator");
@ -628,7 +628,8 @@ public final class NettyChannelBuilder extends
private final int flowControlWindow; private final int flowControlWindow;
private final int maxMessageSize; private final int maxMessageSize;
private final int maxHeaderListSize; private final int maxHeaderListSize;
private final AtomicBackoff keepAliveTimeNanos; private final long keepAliveTimeNanos;
private final AtomicBackoff keepAliveBackoff;
private final long keepAliveTimeoutNanos; private final long keepAliveTimeoutNanos;
private final boolean keepAliveWithoutCalls; private final boolean keepAliveWithoutCalls;
private final TransportTracer.Factory transportTracerFactory; private final TransportTracer.Factory transportTracerFactory;
@ -637,7 +638,8 @@ public final class NettyChannelBuilder extends
private boolean closed; private boolean closed;
NettyTransportFactory(ProtocolNegotiator protocolNegotiator, NettyTransportFactory(
ProtocolNegotiator protocolNegotiator,
ChannelFactory<? extends Channel> channelFactory, ChannelFactory<? extends Channel> channelFactory,
Map<ChannelOption<?>, ?> channelOptions, ObjectPool<? extends EventLoopGroup> groupPool, Map<ChannelOption<?>, ?> channelOptions, ObjectPool<? extends EventLoopGroup> groupPool,
boolean autoFlowControl, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, boolean autoFlowControl, int flowControlWindow, int maxMessageSize, int maxHeaderListSize,
@ -653,7 +655,8 @@ public final class NettyChannelBuilder extends
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
this.maxHeaderListSize = maxHeaderListSize; this.maxHeaderListSize = maxHeaderListSize;
this.keepAliveTimeNanos = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos); this.keepAliveTimeNanos = keepAliveTimeNanos;
this.keepAliveBackoff = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos);
this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; this.keepAliveTimeoutNanos = keepAliveTimeoutNanos;
this.keepAliveWithoutCalls = keepAliveWithoutCalls; this.keepAliveWithoutCalls = keepAliveWithoutCalls;
this.transportTracerFactory = transportTracerFactory; this.transportTracerFactory = transportTracerFactory;
@ -678,7 +681,7 @@ public final class NettyChannelBuilder extends
protocolNegotiator); protocolNegotiator);
} }
final AtomicBackoff.State keepAliveTimeNanosState = keepAliveTimeNanos.getState(); final AtomicBackoff.State keepAliveTimeNanosState = keepAliveBackoff.getState();
Runnable tooManyPingsRunnable = new Runnable() { Runnable tooManyPingsRunnable = new Runnable() {
@Override @Override
public void run() { public void run() {
@ -702,6 +705,21 @@ public final class NettyChannelBuilder extends
return group; return group;
} }
@Override
public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) {
checkNotNull(channelCreds, "channelCreds");
FromChannelCredentialsResult result = ProtocolNegotiators.from(channelCreds);
if (result.error != null) {
return null;
}
ClientTransportFactory factory = new NettyTransportFactory(
result.negotiator.newNegotiator(), channelFactory, channelOptions, groupPool,
autoFlowControl, flowControlWindow, maxMessageSize, maxHeaderListSize, keepAliveTimeNanos,
keepAliveTimeoutNanos, keepAliveWithoutCalls, transportTracerFactory, localSocketPicker,
useGetForSafeMethods);
return new SwapChannelCredentialsResult(factory, result.callCredentials);
}
@Override @Override
public void close() { public void close() {
if (closed) { if (closed) {

View File

@ -49,7 +49,7 @@ public final class NettyChannelProvider extends ManagedChannelProvider {
if (result.error != null) { if (result.error != null) {
return NewChannelBuilderResult.error(result.error); return NewChannelBuilderResult.error(result.error);
} }
return NewChannelBuilderResult.channelBuilder(new NettyChannelBuilder( return NewChannelBuilderResult.channelBuilder(
target, result.negotiator, result.callCredentials)); new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator));
} }
} }

View File

@ -16,13 +16,18 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import io.grpc.ChannelCredentials;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.netty.NettyTestUtil.TrackingObjectPoolForTest; import io.grpc.netty.NettyTestUtil.TrackingObjectPoolForTest;
import io.grpc.netty.ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFactory; import io.netty.channel.ChannelFactory;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
@ -282,4 +287,18 @@ public class NettyChannelBuilderTest {
builder.assertEventLoopAndChannelType(); builder.assertEventLoopAndChannelType();
} }
@Test
public void transportFactorySupportsNettyChannelCreds() {
NettyChannelBuilder builder = NettyChannelBuilder.forTarget("foo");
ClientTransportFactory transportFactory = builder.buildTransportFactory();
SwapChannelCredentialsResult result = transportFactory.swapChannelCredentials(
mock(ChannelCredentials.class));
assertThat(result).isNull();
result = transportFactory.swapChannelCredentials(
NettyChannelCredentials.create(new PlaintextProtocolNegotiatorClientFactory()));
assertThat(result).isNotNull();
}
} }

View File

@ -59,6 +59,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.net.SocketFactory; import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HostnameVerifier;
@ -145,7 +146,7 @@ public final class OkHttpChannelBuilder extends
if (result.error != null) { if (result.error != null) {
throw new IllegalArgumentException(result.error); throw new IllegalArgumentException(result.error);
} }
return new OkHttpChannelBuilder(target, result.factory, result.callCredentials); return new OkHttpChannelBuilder(target, creds, result.callCredentials, result.factory);
} }
private Executor transportExecutor; private Executor transportExecutor;
@ -181,9 +182,11 @@ public final class OkHttpChannelBuilder extends
this.freezeSecurityConfiguration = false; this.freezeSecurityConfiguration = false;
} }
OkHttpChannelBuilder(String target, @Nullable SSLSocketFactory factory, OkHttpChannelBuilder(
@Nullable CallCredentials callCredentials) { String target, ChannelCredentials channelCreds, CallCredentials callCreds,
managedChannelImplBuilder = new ManagedChannelImplBuilder(target, callCredentials, SSLSocketFactory factory) {
managedChannelImplBuilder = new ManagedChannelImplBuilder(
target, channelCreds, callCreds,
new OkHttpChannelTransportFactoryBuilder(), new OkHttpChannelTransportFactoryBuilder(),
new OkHttpChannelDefaultPortProvider()); new OkHttpChannelDefaultPortProvider());
this.sslSocketFactory = factory; this.sslSocketFactory = factory;
@ -631,7 +634,8 @@ public final class OkHttpChannelBuilder extends
private final ConnectionSpec connectionSpec; private final ConnectionSpec connectionSpec;
private final int maxMessageSize; private final int maxMessageSize;
private final boolean enableKeepAlive; private final boolean enableKeepAlive;
private final AtomicBackoff keepAliveTimeNanos; private final long keepAliveTimeNanos;
private final AtomicBackoff keepAliveBackoff;
private final long keepAliveTimeoutNanos; private final long keepAliveTimeoutNanos;
private final int flowControlWindow; private final int flowControlWindow;
private final boolean keepAliveWithoutCalls; private final boolean keepAliveWithoutCalls;
@ -665,7 +669,8 @@ public final class OkHttpChannelBuilder extends
this.connectionSpec = connectionSpec; this.connectionSpec = connectionSpec;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
this.enableKeepAlive = enableKeepAlive; this.enableKeepAlive = enableKeepAlive;
this.keepAliveTimeNanos = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos); this.keepAliveTimeNanos = keepAliveTimeNanos;
this.keepAliveBackoff = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos);
this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; this.keepAliveTimeoutNanos = keepAliveTimeoutNanos;
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
this.keepAliveWithoutCalls = keepAliveWithoutCalls; this.keepAliveWithoutCalls = keepAliveWithoutCalls;
@ -689,7 +694,7 @@ public final class OkHttpChannelBuilder extends
if (closed) { if (closed) {
throw new IllegalStateException("The transport factory is closed."); throw new IllegalStateException("The transport factory is closed.");
} }
final AtomicBackoff.State keepAliveTimeNanosState = keepAliveTimeNanos.getState(); final AtomicBackoff.State keepAliveTimeNanosState = keepAliveBackoff.getState();
Runnable tooManyPingsRunnable = new Runnable() { Runnable tooManyPingsRunnable = new Runnable() {
@Override @Override
public void run() { public void run() {
@ -727,6 +732,33 @@ public final class OkHttpChannelBuilder extends
return timeoutService; return timeoutService;
} }
@Nullable
@CheckReturnValue
@Override
public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) {
SslSocketFactoryResult result = sslSocketFactoryFrom(channelCreds);
if (result.error != null) {
return null;
}
ClientTransportFactory factory = new OkHttpTransportFactory(
executor,
timeoutService,
socketFactory,
result.factory,
hostnameVerifier,
connectionSpec,
maxMessageSize,
enableKeepAlive,
keepAliveTimeNanos,
keepAliveTimeoutNanos,
flowControlWindow,
keepAliveWithoutCalls,
maxInboundMetadataSize,
transportTracerFactory,
useGetForSafeMethods);
return new SwapChannelCredentialsResult(factory, result.callCredentials);
}
@Override @Override
public void close() { public void close() {
if (closed) { if (closed) {

View File

@ -55,6 +55,6 @@ public final class OkHttpChannelProvider extends ManagedChannelProvider {
return NewChannelBuilderResult.error(result.error); return NewChannelBuilderResult.error(result.error);
} }
return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder( return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder(
target, result.factory, result.callCredentials)); target, creds, result.callCredentials, result.factory));
} }
} }

View File

@ -34,6 +34,7 @@ import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.TlsChannelCredentials; import io.grpc.TlsChannelCredentials;
import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder;
@ -358,6 +359,20 @@ public class OkHttpChannelBuilderTest {
transportFactory.close(); transportFactory.close();
} }
@Test
public void transportFactorySupportsOkHttpChannelCreds() {
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forTarget("foo");
ClientTransportFactory transportFactory = builder.buildTransportFactory();
SwapChannelCredentialsResult result = transportFactory.swapChannelCredentials(
mock(ChannelCredentials.class));
assertThat(result).isNull();
result = transportFactory.swapChannelCredentials(
SslSocketFactoryChannelCredentials.create(mock(SSLSocketFactory.class)));
assertThat(result).isNotNull();
}
private static final class FakeChannelLogger extends ChannelLogger { private static final class FakeChannelLogger extends ChannelLogger {
@Override @Override