core: Add ChannelCredentials

This commit is contained in:
Eric Anderson 2020-07-31 16:53:44 -07:00 committed by Eric Anderson
parent c8a94d1059
commit 5733cd481a
5 changed files with 99 additions and 23 deletions

View File

@ -20,10 +20,11 @@ import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkNotNull;
import io.grpc.Attributes;
import io.grpc.CallCredentials;
import io.grpc.CallCredentials.RequestInfo;
import io.grpc.CallCredentials;
import io.grpc.CallOptions;
import io.grpc.ChannelLogger;
import io.grpc.CompositeCallCredentials;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.SecurityLevel;
@ -34,11 +35,14 @@ import java.util.concurrent.ScheduledExecutorService;
final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory {
private final ClientTransportFactory delegate;
private final CallCredentials channelCallCredentials;
private final Executor appExecutor;
CallCredentialsApplyingTransportFactory(
ClientTransportFactory delegate, Executor appExecutor) {
ClientTransportFactory delegate, CallCredentials channelCallCredentials,
Executor appExecutor) {
this.delegate = checkNotNull(delegate, "delegate");
this.channelCallCredentials = channelCallCredentials;
this.appExecutor = checkNotNull(appExecutor, "appExecutor");
}
@ -78,6 +82,11 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa
public ClientStream newStream(
final MethodDescriptor<?, ?> method, Metadata headers, final CallOptions callOptions) {
CallCredentials creds = callOptions.getCredentials();
if (creds == null) {
creds = channelCallCredentials;
} else if (channelCallCredentials != null) {
creds = new CompositeCallCredentials(channelCallCredentials, creds);
}
if (creds != null) {
MetadataApplierImpl applier = new MetadataApplierImpl(
delegate, method, headers, callOptions);

View File

@ -589,8 +589,8 @@ final class ManagedChannelImpl extends ManagedChannel implements
this.timeProvider = checkNotNull(timeProvider, "timeProvider");
this.executorPool = checkNotNull(builder.executorPool, "executorPool");
this.executor = checkNotNull(executorPool.getObject(), "executor");
this.transportFactory =
new CallCredentialsApplyingTransportFactory(clientTransportFactory, this.executor);
this.transportFactory = new CallCredentialsApplyingTransportFactory(
clientTransportFactory, builder.callCredentials, this.executor);
this.scheduledExecutor =
new RestrictedScheduledExecutor(transportFactory.getScheduledExecutorService());
maxTraceEvents = builder.maxTraceEvents;

View File

@ -23,6 +23,7 @@ import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Attributes;
import io.grpc.BinaryLog;
import io.grpc.CallCredentials;
import io.grpc.ClientInterceptor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
@ -109,6 +110,8 @@ public final class ManagedChannelImplBuilder
private NameResolver.Factory nameResolverFactory = nameResolverRegistry.asFactory();
final String target;
@Nullable
final CallCredentials callCredentials;
@Nullable
private final SocketAddress directServerAddress;
@ -222,7 +225,19 @@ public final class ManagedChannelImplBuilder
public ManagedChannelImplBuilder(String target,
ClientTransportFactoryBuilder clientTransportFactoryBuilder,
@Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) {
this(target, null, clientTransportFactoryBuilder, channelBuilderDefaultPortProvider);
}
/**
* Creates a new managed channel builder with a target string, which can be either a valid {@link
* io.grpc.NameResolver}-compliant URI, or an authority string. Transport implementors must
* provide client transport factory builder, and may set custom channel default port provider.
*/
public ManagedChannelImplBuilder(String target, @Nullable CallCredentials callCreds,
ClientTransportFactoryBuilder clientTransportFactoryBuilder,
@Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) {
this.target = Preconditions.checkNotNull(target, "target");
this.callCredentials = callCreds;
this.clientTransportFactoryBuilder = Preconditions
.checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder");
this.directServerAddress = null;
@ -258,6 +273,7 @@ public final class ManagedChannelImplBuilder
ClientTransportFactoryBuilder clientTransportFactoryBuilder,
@Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) {
this.target = makeTargetStringForDirectAddress(directServerAddress);
this.callCredentials = null;
this.clientTransportFactoryBuilder = Preconditions
.checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder");
this.directServerAddress = directServerAddress;

View File

@ -120,7 +120,7 @@ public class CallCredentials2ApplyingTest {
when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class)))
.thenReturn(mockStream);
ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory(
mockTransportFactory, mockExecutor);
mockTransportFactory, null, mockExecutor);
transport = (ForwardingConnectionClientTransport)
transportFactory.newClientTransport(address, clientTransportOptions, channelLogger);
callOptions = CallOptions.DEFAULT.withCallCredentials(mockCreds);

View File

@ -102,24 +102,23 @@ public class CallCredentialsApplyingTest {
Metadata.Key.of("test-creds", Metadata.ASCII_STRING_MARSHALLER);
private static final String CREDS_VALUE = "some credentials";
private final ClientTransportFactory.ClientTransportOptions clientTransportOptions =
new ClientTransportFactory.ClientTransportOptions()
.setAuthority(AUTHORITY)
.setUserAgent(USER_AGENT);
private final Metadata origHeaders = new Metadata();
private ForwardingConnectionClientTransport transport;
private CallOptions callOptions;
@Before
public void setUp() {
ClientTransportFactory.ClientTransportOptions clientTransportOptions =
new ClientTransportFactory.ClientTransportOptions()
.setAuthority(AUTHORITY)
.setUserAgent(USER_AGENT);
origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE);
when(mockTransportFactory.newClientTransport(address, clientTransportOptions, channelLogger))
.thenReturn(mockTransport);
when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class)))
.thenReturn(mockStream);
ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory(
mockTransportFactory, mockExecutor);
mockTransportFactory, null, mockExecutor);
transport = (ForwardingConnectionClientTransport)
transportFactory.newClientTransport(address, clientTransportOptions, channelLogger);
callOptions = CallOptions.DEFAULT.withCallCredentials(mockCreds);
@ -185,19 +184,8 @@ public class CallCredentialsApplyingTest {
@Test
public void applyMetadata_inline() {
when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
CallCredentials.MetadataApplier applier =
(CallCredentials.MetadataApplier) invocation.getArguments()[2];
Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applier.apply(headers);
return null;
}
}).when(mockCreds).applyRequestMetadata(any(RequestInfo.class),
same(mockExecutor), any(CallCredentials.MetadataApplier.class));
callOptions = callOptions.withCallCredentials(new FakeCallCredentials(CREDS_KEY, CREDS_VALUE));
ClientStream stream = transport.newStream(method, origHeaders, callOptions);
verify(mockTransport).newStream(method, origHeaders, callOptions);
@ -279,4 +267,67 @@ public class CallCredentialsApplyingTest {
assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
}
@Test
public void justCallOptionCreds() {
callOptions = callOptions.withCallCredentials(new FakeCallCredentials(CREDS_KEY, CREDS_VALUE));
ClientStream stream = transport.newStream(method, origHeaders, callOptions);
assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
}
@Test
public void justChannelCreds() {
ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory(
mockTransportFactory, new FakeCallCredentials(CREDS_KEY, CREDS_VALUE), mockExecutor);
transport = (ForwardingConnectionClientTransport)
transportFactory.newClientTransport(address, clientTransportOptions, channelLogger);
callOptions = callOptions.withCallCredentials(null);
ClientStream stream = transport.newStream(method, origHeaders, callOptions);
assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
}
@Test
public void callOptionAndChanelCreds() {
ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory(
mockTransportFactory, new FakeCallCredentials(CREDS_KEY, CREDS_VALUE), mockExecutor);
transport = (ForwardingConnectionClientTransport)
transportFactory.newClientTransport(address, clientTransportOptions, channelLogger);
Metadata.Key<String> creds2Key =
Metadata.Key.of("test-creds2", Metadata.ASCII_STRING_MARSHALLER);
String creds2Value = "some more credentials";
callOptions = callOptions.withCallCredentials(new FakeCallCredentials(creds2Key, creds2Value));
ClientStream stream = transport.newStream(method, origHeaders, callOptions);
assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(creds2Value, origHeaders.get(creds2Key));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
}
private abstract static class BaseCallCredentials extends CallCredentials {
@Override public void thisUsesUnstableApi() {}
}
private static class FakeCallCredentials extends BaseCallCredentials {
private final Metadata headers;
public <T> FakeCallCredentials(Metadata.Key<T> key, T value) {
headers = new Metadata();
headers.put(key, value);
}
@Override public void applyRequestMetadata(
RequestInfo requestInfo, Executor appExecutor, CallCredentials.MetadataApplier applier) {
applier.apply(headers);
}
}
}