From 5733cd481a3df2cbf59d501af5cefa58c15ec2a3 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Fri, 31 Jul 2020 16:53:44 -0700 Subject: [PATCH] core: Add ChannelCredentials --- ...llCredentialsApplyingTransportFactory.java | 13 ++- .../io/grpc/internal/ManagedChannelImpl.java | 4 +- .../internal/ManagedChannelImplBuilder.java | 16 ++++ .../CallCredentials2ApplyingTest.java | 2 +- .../internal/CallCredentialsApplyingTest.java | 87 +++++++++++++++---- 5 files changed, 99 insertions(+), 23 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index 2997464416..ab2a36d949 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -20,10 +20,11 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkNotNull; import io.grpc.Attributes; -import io.grpc.CallCredentials; import io.grpc.CallCredentials.RequestInfo; +import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.CompositeCallCredentials; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; @@ -34,11 +35,14 @@ import java.util.concurrent.ScheduledExecutorService; final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory { private final ClientTransportFactory delegate; + private final CallCredentials channelCallCredentials; private final Executor appExecutor; CallCredentialsApplyingTransportFactory( - ClientTransportFactory delegate, Executor appExecutor) { + ClientTransportFactory delegate, CallCredentials channelCallCredentials, + Executor appExecutor) { this.delegate = checkNotNull(delegate, "delegate"); + this.channelCallCredentials = channelCallCredentials; this.appExecutor = checkNotNull(appExecutor, "appExecutor"); } @@ -78,6 +82,11 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa public ClientStream newStream( final MethodDescriptor method, Metadata headers, final CallOptions callOptions) { CallCredentials creds = callOptions.getCredentials(); + if (creds == null) { + creds = channelCallCredentials; + } else if (channelCallCredentials != null) { + creds = new CompositeCallCredentials(channelCallCredentials, creds); + } if (creds != null) { MetadataApplierImpl applier = new MetadataApplierImpl( delegate, method, headers, callOptions); diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 117e02ad53..1646940bf7 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -589,8 +589,8 @@ final class ManagedChannelImpl extends ManagedChannel implements this.timeProvider = checkNotNull(timeProvider, "timeProvider"); this.executorPool = checkNotNull(builder.executorPool, "executorPool"); this.executor = checkNotNull(executorPool.getObject(), "executor"); - this.transportFactory = - new CallCredentialsApplyingTransportFactory(clientTransportFactory, this.executor); + this.transportFactory = new CallCredentialsApplyingTransportFactory( + clientTransportFactory, builder.callCredentials, this.executor); this.scheduledExecutor = new RestrictedScheduledExecutor(transportFactory.getScheduledExecutorService()); maxTraceEvents = builder.maxTraceEvents; diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index 202055f971..e7806f56ec 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -23,6 +23,7 @@ import com.google.common.base.Preconditions; import com.google.common.util.concurrent.MoreExecutors; import io.grpc.Attributes; import io.grpc.BinaryLog; +import io.grpc.CallCredentials; import io.grpc.ClientInterceptor; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; @@ -109,6 +110,8 @@ public final class ManagedChannelImplBuilder private NameResolver.Factory nameResolverFactory = nameResolverRegistry.asFactory(); final String target; + @Nullable + final CallCredentials callCredentials; @Nullable private final SocketAddress directServerAddress; @@ -222,7 +225,19 @@ public final class ManagedChannelImplBuilder public ManagedChannelImplBuilder(String target, ClientTransportFactoryBuilder clientTransportFactoryBuilder, @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { + this(target, null, clientTransportFactoryBuilder, channelBuilderDefaultPortProvider); + } + + /** + * Creates a new managed channel builder with a target string, which can be either a valid {@link + * io.grpc.NameResolver}-compliant URI, or an authority string. Transport implementors must + * provide client transport factory builder, and may set custom channel default port provider. + */ + public ManagedChannelImplBuilder(String target, @Nullable CallCredentials callCreds, + ClientTransportFactoryBuilder clientTransportFactoryBuilder, + @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { this.target = Preconditions.checkNotNull(target, "target"); + this.callCredentials = callCreds; this.clientTransportFactoryBuilder = Preconditions .checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder"); this.directServerAddress = null; @@ -258,6 +273,7 @@ public final class ManagedChannelImplBuilder ClientTransportFactoryBuilder clientTransportFactoryBuilder, @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { this.target = makeTargetStringForDirectAddress(directServerAddress); + this.callCredentials = null; this.clientTransportFactoryBuilder = Preconditions .checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder"); this.directServerAddress = directServerAddress; diff --git a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java index 76882b829a..c26944c16b 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java @@ -120,7 +120,7 @@ public class CallCredentials2ApplyingTest { when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) .thenReturn(mockStream); ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( - mockTransportFactory, mockExecutor); + mockTransportFactory, null, mockExecutor); transport = (ForwardingConnectionClientTransport) transportFactory.newClientTransport(address, clientTransportOptions, channelLogger); callOptions = CallOptions.DEFAULT.withCallCredentials(mockCreds); diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index 81e45df1c5..52ebb8b8a2 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -102,24 +102,23 @@ public class CallCredentialsApplyingTest { Metadata.Key.of("test-creds", Metadata.ASCII_STRING_MARSHALLER); private static final String CREDS_VALUE = "some credentials"; + private final ClientTransportFactory.ClientTransportOptions clientTransportOptions = + new ClientTransportFactory.ClientTransportOptions() + .setAuthority(AUTHORITY) + .setUserAgent(USER_AGENT); private final Metadata origHeaders = new Metadata(); private ForwardingConnectionClientTransport transport; private CallOptions callOptions; @Before public void setUp() { - ClientTransportFactory.ClientTransportOptions clientTransportOptions = - new ClientTransportFactory.ClientTransportOptions() - .setAuthority(AUTHORITY) - .setUserAgent(USER_AGENT); - origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE); when(mockTransportFactory.newClientTransport(address, clientTransportOptions, channelLogger)) .thenReturn(mockTransport); when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) .thenReturn(mockStream); ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( - mockTransportFactory, mockExecutor); + mockTransportFactory, null, mockExecutor); transport = (ForwardingConnectionClientTransport) transportFactory.newClientTransport(address, clientTransportOptions, channelLogger); callOptions = CallOptions.DEFAULT.withCallCredentials(mockCreds); @@ -185,19 +184,8 @@ public class CallCredentialsApplyingTest { @Test public void applyMetadata_inline() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - CallCredentials.MetadataApplier applier = - (CallCredentials.MetadataApplier) invocation.getArguments()[2]; - Metadata headers = new Metadata(); - headers.put(CREDS_KEY, CREDS_VALUE); - applier.apply(headers); - return null; - } - }).when(mockCreds).applyRequestMetadata(any(RequestInfo.class), - same(mockExecutor), any(CallCredentials.MetadataApplier.class)); + callOptions = callOptions.withCallCredentials(new FakeCallCredentials(CREDS_KEY, CREDS_VALUE)); ClientStream stream = transport.newStream(method, origHeaders, callOptions); verify(mockTransport).newStream(method, origHeaders, callOptions); @@ -279,4 +267,67 @@ public class CallCredentialsApplyingTest { assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); } + + @Test + public void justCallOptionCreds() { + callOptions = callOptions.withCallCredentials(new FakeCallCredentials(CREDS_KEY, CREDS_VALUE)); + + ClientStream stream = transport.newStream(method, origHeaders, callOptions); + + assertSame(mockStream, stream); + assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); + assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + } + + @Test + public void justChannelCreds() { + ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( + mockTransportFactory, new FakeCallCredentials(CREDS_KEY, CREDS_VALUE), mockExecutor); + transport = (ForwardingConnectionClientTransport) + transportFactory.newClientTransport(address, clientTransportOptions, channelLogger); + callOptions = callOptions.withCallCredentials(null); + + ClientStream stream = transport.newStream(method, origHeaders, callOptions); + + assertSame(mockStream, stream); + assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); + assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + } + + @Test + public void callOptionAndChanelCreds() { + ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( + mockTransportFactory, new FakeCallCredentials(CREDS_KEY, CREDS_VALUE), mockExecutor); + transport = (ForwardingConnectionClientTransport) + transportFactory.newClientTransport(address, clientTransportOptions, channelLogger); + Metadata.Key creds2Key = + Metadata.Key.of("test-creds2", Metadata.ASCII_STRING_MARSHALLER); + String creds2Value = "some more credentials"; + callOptions = callOptions.withCallCredentials(new FakeCallCredentials(creds2Key, creds2Value)); + + ClientStream stream = transport.newStream(method, origHeaders, callOptions); + + assertSame(mockStream, stream); + assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); + assertEquals(creds2Value, origHeaders.get(creds2Key)); + assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + } + + private abstract static class BaseCallCredentials extends CallCredentials { + @Override public void thisUsesUnstableApi() {} + } + + private static class FakeCallCredentials extends BaseCallCredentials { + private final Metadata headers; + + public FakeCallCredentials(Metadata.Key key, T value) { + headers = new Metadata(); + headers.put(key, value); + } + + @Override public void applyRequestMetadata( + RequestInfo requestInfo, Executor appExecutor, CallCredentials.MetadataApplier applier) { + applier.apply(headers); + } + } }