From 62620ccd00b4bc772ba6622a7c1b192dea45747c Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Wed, 27 May 2020 12:31:54 -0700 Subject: [PATCH] xds: replace generic with individual client and server SslContextProviders (#7059) --- .../java/io/grpc/xds/CdsLoadBalancer.java | 14 +- .../sds/ClientSslContextProviderFactory.java | 7 +- .../internal/sds/CommonTlsContextUtil.java | 56 +++++ .../sds/DownstreamTlsContextHolder.java | 40 ++++ ...eferenceCountingSslContextProviderMap.java | 30 +-- .../sds/SdsClientSslContextProvider.java | 107 +++++++++ .../internal/sds/SdsProtocolNegotiators.java | 4 +- .../sds/SdsServerSslContextProvider.java | 91 ++++++++ .../internal/sds/SdsSslContextProvider.java | 121 ++-------- .../SecretVolumeClientSslContextProvider.java | 125 ++++++++++ .../SecretVolumeServerSslContextProvider.java | 116 ++++++++++ .../sds/SecretVolumeSslContextProvider.java | 219 ------------------ .../sds/ServerSslContextProviderFactory.java | 8 +- .../xds/internal/sds/SslContextProvider.java | 48 ++-- .../xds/internal/sds/TlsContextHolder.java | 29 +++ .../xds/internal/sds/TlsContextManager.java | 10 +- .../internal/sds/TlsContextManagerImpl.java | 22 +- .../sds/UpstreamTlsContextHolder.java | 40 ++++ .../java/io/grpc/xds/CdsLoadBalancerTest.java | 18 +- .../ClientSslContextProviderFactoryTest.java | 6 +- ...enceCountingSslContextProviderMapTest.java | 57 ++--- .../sds/SdsSslContextProviderTest.java | 113 +++++---- .../SecretVolumeSslContextProviderTest.java | 116 ++++------ .../ServerSslContextProviderFactoryTest.java | 6 +- .../internal/sds/TlsContextManagerTest.java | 34 ++- 25 files changed, 858 insertions(+), 579 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/internal/sds/DownstreamTlsContextHolder.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/sds/SdsClientSslContextProvider.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/sds/SdsServerSslContextProvider.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java delete mode 100644 xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProvider.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/sds/TlsContextHolder.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/sds/UpstreamTlsContextHolder.java diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java index 0acea3cb04..e3dc2bad36 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java @@ -206,10 +206,10 @@ public final class CdsLoadBalancer extends LoadBalancer { private static final class EdsLoadBalancingHelper extends ForwardingLoadBalancerHelper { private final Helper delegate; - private final AtomicReference> sslContextProvider; + private final AtomicReference sslContextProvider; EdsLoadBalancingHelper(Helper helper, - AtomicReference> sslContextProvider) { + AtomicReference sslContextProvider) { this.delegate = helper; this.sslContextProvider = sslContextProvider; } @@ -222,7 +222,7 @@ public final class CdsLoadBalancer extends LoadBalancer { .toBuilder() .setAddresses( addUpstreamTlsContext(createSubchannelArgs.getAddresses(), - sslContextProvider.get().getSource())) + sslContextProvider.get().getUpstreamTlsContext())) .build(); } return delegate.createSubchannel(createSubchannelArgs); @@ -267,7 +267,7 @@ public final class CdsLoadBalancer extends LoadBalancer { ClusterWatcherImpl(Helper helper, ResolvedAddresses resolvedAddresses) { this.helper = new EdsLoadBalancingHelper(helper, - new AtomicReference>()); + new AtomicReference()); this.resolvedAddresses = resolvedAddresses; } @@ -303,10 +303,10 @@ public final class CdsLoadBalancer extends LoadBalancer { /** For new UpstreamTlsContext value, release old SslContextProvider. */ private void updateSslContextProvider(UpstreamTlsContext newUpstreamTlsContext) { - SslContextProvider oldSslContextProvider = + SslContextProvider oldSslContextProvider = helper.sslContextProvider.get(); if (oldSslContextProvider != null) { - UpstreamTlsContext oldUpstreamTlsContext = oldSslContextProvider.getSource(); + UpstreamTlsContext oldUpstreamTlsContext = oldSslContextProvider.getUpstreamTlsContext(); if (oldUpstreamTlsContext.equals(newUpstreamTlsContext)) { return; @@ -314,7 +314,7 @@ public final class CdsLoadBalancer extends LoadBalancer { tlsContextManager.releaseClientSslContextProvider(oldSslContextProvider); } if (newUpstreamTlsContext != null) { - SslContextProvider newSslContextProvider = + SslContextProvider newSslContextProvider = tlsContextManager.findOrCreateClientSslContextProvider(newUpstreamTlsContext); helper.sslContextProvider.set(newSslContextProvider); } else { diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java index bd1e663570..a77d3ff194 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java @@ -32,18 +32,17 @@ final class ClientSslContextProviderFactory /** Creates an SslContextProvider from the given UpstreamTlsContext. */ @Override - public SslContextProvider createSslContextProvider( - UpstreamTlsContext upstreamTlsContext) { + public SslContextProvider createSslContextProvider(UpstreamTlsContext upstreamTlsContext) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); checkArgument( upstreamTlsContext.hasCommonTlsContext(), "upstreamTlsContext should have CommonTlsContext"); if (CommonTlsContextUtil.hasAllSecretsUsingFilename(upstreamTlsContext.getCommonTlsContext())) { - return SecretVolumeSslContextProvider.getProviderForClient(upstreamTlsContext); + return SecretVolumeClientSslContextProvider.getProvider(upstreamTlsContext); } else if (CommonTlsContextUtil.hasAllSecretsUsingSds( upstreamTlsContext.getCommonTlsContext())) { try { - return SdsSslContextProvider.getProviderForClient( + return SdsClientSslContextProvider.getProvider( upstreamTlsContext, Bootstrapper.getInstance().readBootstrap().getNode(), Executors.newSingleThreadExecutor(new ThreadFactoryBuilder() diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java index 00c0777fea..daf375d2b3 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java @@ -16,9 +16,16 @@ package io.grpc.xds.internal.sds; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; +import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.ValidationContextTypeCase; +import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; +import io.envoyproxy.envoy.api.v2.core.DataSource.SpecifierCase; +import javax.annotation.Nullable; /** Class for utility functions for {@link CommonTlsContext}. */ final class CommonTlsContextUtil { @@ -40,4 +47,53 @@ final class CommonTlsContextUtil { return (commonTlsContext.getTlsCertificatesCount() == 0) && !commonTlsContext.hasValidationContext(); } + + @Nullable + static CertificateValidationContext getCertificateValidationContext( + CommonTlsContext commonTlsContext) { + checkNotNull(commonTlsContext, "commonTlsContext"); + ValidationContextTypeCase type = commonTlsContext.getValidationContextTypeCase(); + checkState( + type == ValidationContextTypeCase.VALIDATION_CONTEXT + || type == ValidationContextTypeCase.VALIDATIONCONTEXTTYPE_NOT_SET, + "incorrect ValidationContextTypeCase"); + return type == ValidationContextTypeCase.VALIDATION_CONTEXT + ? commonTlsContext.getValidationContext() + : null; + } + + @Nullable + static CertificateValidationContext validateCertificateContext( + @Nullable CertificateValidationContext certContext, boolean optional) { + if (certContext == null || !certContext.hasTrustedCa()) { + checkArgument(optional, "certContext is required"); + return null; + } + checkArgument( + certContext.getTrustedCa().getSpecifierCase() == SpecifierCase.FILENAME, + "filename expected"); + return certContext; + } + + @Nullable + static TlsCertificate validateTlsCertificate( + @Nullable TlsCertificate tlsCertificate, boolean optional) { + if (tlsCertificate == null) { + checkArgument(optional, "tlsCertificate is required"); + return null; + } + if (optional + && (tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.SPECIFIER_NOT_SET) + && (tlsCertificate.getCertificateChain().getSpecifierCase() + == SpecifierCase.SPECIFIER_NOT_SET)) { + return null; + } + checkArgument( + tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.FILENAME, + "filename expected"); + checkArgument( + tlsCertificate.getCertificateChain().getSpecifierCase() == SpecifierCase.FILENAME, + "filename expected"); + return tlsCertificate; + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/DownstreamTlsContextHolder.java b/xds/src/main/java/io/grpc/xds/internal/sds/DownstreamTlsContextHolder.java new file mode 100644 index 0000000000..5163b4b244 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/sds/DownstreamTlsContextHolder.java @@ -0,0 +1,40 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.sds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; +import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; + +final class DownstreamTlsContextHolder implements TlsContextHolder { + + private final DownstreamTlsContext downstreamTlsContext; + + DownstreamTlsContextHolder(DownstreamTlsContext downstreamTlsContext) { + this.downstreamTlsContext = checkNotNull(downstreamTlsContext, "downstreamTlsContext"); + } + + public DownstreamTlsContext getDownstreamTlsContext() { + return downstreamTlsContext; + } + + @Override + public CommonTlsContext getCommonTlsContext() { + return downstreamTlsContext.getCommonTlsContext(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMap.java b/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMap.java index 7b7963f3f7..49921f9c8f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMap.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMap.java @@ -38,7 +38,7 @@ import javax.annotation.concurrent.ThreadSafe; @ThreadSafe final class ReferenceCountingSslContextProviderMap { - private final Map> instances = new HashMap<>(); + private final Map instances = new HashMap<>(); private final SslContextProviderFactory sslContextProviderFactory; ReferenceCountingSslContextProviderMap(SslContextProviderFactory sslContextProviderFactory) { @@ -51,7 +51,7 @@ final class ReferenceCountingSslContextProviderMap { * using the provided {@link SslContextProviderFactory<K>} */ @CheckReturnValue - public SslContextProvider get(K key) { + public SslContextProvider get(K key) { checkNotNull(key, "key"); return getInternal(key); } @@ -65,19 +65,20 @@ final class ReferenceCountingSslContextProviderMap { *

Caller must not release a reference more than once. It's advised that you clear the * reference to the instance with the null returned by this method. * + * @param key for the instance to be released * @param value the instance to be released * @return a null which the caller can use to clear the reference to that instance. */ - public SslContextProvider release(final SslContextProvider value) { + public SslContextProvider release(K key, SslContextProvider value) { + checkNotNull(key, "key"); checkNotNull(value, "value"); - K key = value.getSource(); return releaseInternal(key, value); } - private synchronized SslContextProvider getInternal(K key) { - Instance instance = instances.get(key); + private synchronized SslContextProvider getInternal(K key) { + Instance instance = instances.get(key); if (instance == null) { - instance = new Instance<>(sslContextProviderFactory.createSslContextProvider(key)); + instance = new Instance(sslContextProviderFactory.createSslContextProvider(key)); instances.put(key, instance); return instance.sslContextProvider; } else { @@ -85,9 +86,8 @@ final class ReferenceCountingSslContextProviderMap { } } - private synchronized SslContextProvider releaseInternal( - final K key, final SslContextProvider instance) { - final Instance cached = instances.get(key); + private synchronized SslContextProvider releaseInternal(K key, SslContextProvider instance) { + Instance cached = instances.get(key); checkArgument(cached != null, "No cached instance found for %s", key); checkArgument(instance == cached.sslContextProvider, "Releasing the wrong instance"); if (cached.release()) { @@ -103,15 +103,15 @@ final class ReferenceCountingSslContextProviderMap { /** A factory to create an SslContextProvider from the given key. */ public interface SslContextProviderFactory { - SslContextProvider createSslContextProvider(K key); + SslContextProvider createSslContextProvider(K key); } - private static class Instance { - final SslContextProvider sslContextProvider; + private static class Instance { + final SslContextProvider sslContextProvider; private int refCount; /** Increment refCount and acquire a reference to sslContextProvider. */ - SslContextProvider acquire() { + SslContextProvider acquire() { refCount++; return sslContextProvider; } @@ -122,7 +122,7 @@ final class ReferenceCountingSslContextProviderMap { return --refCount == 0; } - Instance(SslContextProvider sslContextProvider) { + Instance(SslContextProvider sslContextProvider) { this.sslContextProvider = sslContextProvider; this.refCount = 1; } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsClientSslContextProvider.java new file mode 100644 index 0000000000..e8586c60ab --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsClientSslContextProvider.java @@ -0,0 +1,107 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.sds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; +import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; +import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext; +import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig; +import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; +import io.envoyproxy.envoy.api.v2.core.Node; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.IOException; +import java.security.cert.CertStoreException; +import java.security.cert.CertificateException; +import java.util.concurrent.Executor; + +/** A client SslContext provider that uses SDS to fetch secrets. */ +final class SdsClientSslContextProvider extends SdsSslContextProvider { + + private SdsClientSslContextProvider( + Node node, + SdsSecretConfig certSdsConfig, + SdsSecretConfig validationContextSdsConfig, + CertificateValidationContext staticCertValidationContext, + Executor watcherExecutor, + Executor channelExecutor, + UpstreamTlsContext upstreamTlsContext) { + super(node, + certSdsConfig, + validationContextSdsConfig, + staticCertValidationContext, + watcherExecutor, + channelExecutor, new UpstreamTlsContextHolder(upstreamTlsContext)); + } + + static SdsClientSslContextProvider getProvider( + UpstreamTlsContext upstreamTlsContext, + Node node, + Executor watcherExecutor, + Executor channelExecutor) { + checkNotNull(upstreamTlsContext, "upstreamTlsContext"); + CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); + SdsSecretConfig validationContextSdsConfig = null; + CertificateValidationContext staticCertValidationContext = null; + if (commonTlsContext.hasCombinedValidationContext()) { + CombinedCertificateValidationContext combinedValidationContext = + commonTlsContext.getCombinedValidationContext(); + if (combinedValidationContext.hasValidationContextSdsSecretConfig()) { + validationContextSdsConfig = + combinedValidationContext.getValidationContextSdsSecretConfig(); + } + if (combinedValidationContext.hasDefaultValidationContext()) { + staticCertValidationContext = combinedValidationContext.getDefaultValidationContext(); + } + } else if (commonTlsContext.hasValidationContextSdsSecretConfig()) { + validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig(); + } else if (commonTlsContext.hasValidationContext()) { + staticCertValidationContext = commonTlsContext.getValidationContext(); + } + SdsSecretConfig certSdsConfig = null; + if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) { + certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0); + } + return new SdsClientSslContextProvider( + node, + certSdsConfig, + validationContextSdsConfig, + staticCertValidationContext, + watcherExecutor, + channelExecutor, + upstreamTlsContext); + } + + @Override + SslContextBuilder getSslContextBuilder( + CertificateValidationContext localCertValidationContext) + throws CertificateException, IOException, CertStoreException { + SslContextBuilder sslContextBuilder = + GrpcSslContexts.forClient() + .trustManager(new SdsTrustManagerFactory(localCertValidationContext)); + if (tlsCertificate != null) { + sslContextBuilder.keyManager( + tlsCertificate.getCertificateChain().getInlineBytes().newInput(), + tlsCertificate.getPrivateKey().getInlineBytes().newInput(), + tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null); + } + return sslContextBuilder; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java index 15aedef94d..6bf33ead9c 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java @@ -192,7 +192,7 @@ public final class SdsProtocolNegotiators { final BufferReadsHandler bufferReads = new BufferReadsHandler(); ctx.pipeline().addBefore(ctx.name(), null, bufferReads); - final SslContextProvider sslContextProvider = + final SslContextProvider sslContextProvider = TlsContextManagerImpl.getInstance() .findOrCreateClientSslContextProvider(upstreamTlsContext); @@ -349,7 +349,7 @@ public final class SdsProtocolNegotiators { final BufferReadsHandler bufferReads = new BufferReadsHandler(); ctx.pipeline().addBefore(ctx.name(), null, bufferReads); - final SslContextProvider sslContextProvider = + final SslContextProvider sslContextProvider = TlsContextManagerImpl.getInstance() .findOrCreateServerSslContextProvider(downstreamTlsContext); diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsServerSslContextProvider.java new file mode 100644 index 0000000000..d0baefbc7a --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsServerSslContextProvider.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.sds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; +import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; +import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; +import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig; +import io.envoyproxy.envoy.api.v2.core.Node; +import io.grpc.netty.GrpcSslContexts; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.IOException; +import java.security.cert.CertStoreException; +import java.security.cert.CertificateException; +import java.util.concurrent.Executor; + +/** A server SslContext provider that uses SDS to fetch secrets. */ +final class SdsServerSslContextProvider extends SdsSslContextProvider { + + private SdsServerSslContextProvider( + Node node, + SdsSecretConfig certSdsConfig, + SdsSecretConfig validationContextSdsConfig, + Executor watcherExecutor, + Executor channelExecutor, + DownstreamTlsContext downstreamTlsContext) { + super(node, + certSdsConfig, + validationContextSdsConfig, + null, + watcherExecutor, + channelExecutor, new DownstreamTlsContextHolder(downstreamTlsContext)); + } + + static SdsServerSslContextProvider getProvider( + DownstreamTlsContext downstreamTlsContext, + Node node, + Executor watcherExecutor, + Executor channelExecutor) { + checkNotNull(downstreamTlsContext, "downstreamTlsContext"); + CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); + + SdsSecretConfig certSdsConfig = null; + if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) { + certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0); + } + + SdsSecretConfig validationContextSdsConfig = null; + if (commonTlsContext.hasValidationContextSdsSecretConfig()) { + validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig(); + } + return new SdsServerSslContextProvider( + node, + certSdsConfig, + validationContextSdsConfig, + watcherExecutor, + channelExecutor, + downstreamTlsContext); + } + + @Override + SslContextBuilder getSslContextBuilder( + CertificateValidationContext localCertValidationContext) + throws CertificateException, IOException, CertStoreException { + SslContextBuilder sslContextBuilder = + GrpcSslContexts.forServer( + tlsCertificate.getCertificateChain().getInlineBytes().newInput(), + tlsCertificate.getPrivateKey().getInlineBytes().newInput(), + tlsCertificate.hasPassword() + ? tlsCertificate.getPassword().getInlineString() + : null); + setClientAuthValues(sslContextBuilder, localCertValidationContext); + return sslContextBuilder; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java index 4af26f6ee8..d15067fe2e 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java @@ -21,16 +21,11 @@ import static com.google.common.base.Preconditions.checkState; import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; -import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext; -import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig; import io.envoyproxy.envoy.api.v2.auth.Secret; import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; -import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; import io.envoyproxy.envoy.api.v2.core.Node; import io.grpc.Status; -import io.grpc.netty.GrpcSslContexts; -import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory; import io.netty.handler.ssl.ApplicationProtocolConfig; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; @@ -44,12 +39,8 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -/** - * An SslContext provider that uses SDS to fetch secrets. Used for both server and client - * SslContexts - */ -final class SdsSslContextProvider extends SslContextProvider - implements SdsClient.SecretWatcher { +/** Base class for SdsClientSslContextProvider and SdsServerSslContextProvider. */ +abstract class SdsSslContextProvider extends SslContextProvider implements SdsClient.SecretWatcher { private static final Logger logger = Logger.getLogger(SdsSslContextProvider.class.getName()); @@ -59,20 +50,19 @@ final class SdsSslContextProvider extends SslContextProvider @Nullable private final SdsSecretConfig validationContextSdsConfig; @Nullable private final CertificateValidationContext staticCertificateValidationContext; private final List pendingCallbacks = new ArrayList<>(); - @Nullable private TlsCertificate tlsCertificate; + @Nullable protected TlsCertificate tlsCertificate; @Nullable private CertificateValidationContext certificateValidationContext; @Nullable private SslContext sslContext; - private SdsSslContextProvider( + SdsSslContextProvider( Node node, SdsSecretConfig certSdsConfig, SdsSecretConfig validationContextSdsConfig, CertificateValidationContext staticCertValidationContext, Executor watcherExecutor, Executor channelExecutor, - boolean server, - K source) { - super(source, server); + TlsContextHolder tlsContextHolder) { + super(tlsContextHolder); this.certSdsConfig = certSdsConfig; this.validationContextSdsConfig = validationContextSdsConfig; this.staticCertificateValidationContext = staticCertValidationContext; @@ -95,73 +85,6 @@ final class SdsSslContextProvider extends SslContextProvider } } - static SdsSslContextProvider getProviderForClient( - UpstreamTlsContext upstreamTlsContext, - Node node, - Executor watcherExecutor, - Executor channelExecutor) { - checkNotNull(upstreamTlsContext, "upstreamTlsContext"); - CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); - SdsSecretConfig validationContextSdsConfig = null; - CertificateValidationContext staticCertValidationContext = null; - if (commonTlsContext.hasCombinedValidationContext()) { - CombinedCertificateValidationContext combinedValidationContext = - commonTlsContext.getCombinedValidationContext(); - if (combinedValidationContext.hasValidationContextSdsSecretConfig()) { - validationContextSdsConfig = - combinedValidationContext.getValidationContextSdsSecretConfig(); - } - if (combinedValidationContext.hasDefaultValidationContext()) { - staticCertValidationContext = combinedValidationContext.getDefaultValidationContext(); - } - } else if (commonTlsContext.hasValidationContextSdsSecretConfig()) { - validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig(); - } else if (commonTlsContext.hasValidationContext()) { - staticCertValidationContext = commonTlsContext.getValidationContext(); - } - SdsSecretConfig certSdsConfig = null; - if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) { - certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0); - } - return new SdsSslContextProvider<>( - node, - certSdsConfig, - validationContextSdsConfig, - staticCertValidationContext, - watcherExecutor, - channelExecutor, - false, - upstreamTlsContext); - } - - static SdsSslContextProvider getProviderForServer( - DownstreamTlsContext downstreamTlsContext, - Node node, - Executor watcherExecutor, - Executor channelExecutor) { - checkNotNull(downstreamTlsContext, "downstreamTlsContext"); - CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); - - SdsSecretConfig certSdsConfig = null; - if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) { - certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0); - } - - SdsSecretConfig validationContextSdsConfig = null; - if (commonTlsContext.hasValidationContextSdsSecretConfig()) { - validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig(); - } - return new SdsSslContextProvider<>( - node, - certSdsConfig, - validationContextSdsConfig, - null, - watcherExecutor, - channelExecutor, - true, - downstreamTlsContext); - } - @Override public void addCallback(Callback callback, Executor executor) { checkNotNull(callback, "callback"); @@ -219,34 +142,16 @@ final class SdsSslContextProvider extends SslContextProvider } } + /** Gets a server or client side SslContextBuilder. */ + abstract SslContextBuilder getSslContextBuilder( + CertificateValidationContext localCertValidationContext) + throws CertificateException, IOException, CertStoreException; + // this gets called only when requested secrets are ready... private void updateSslContext() { try { - SslContextBuilder sslContextBuilder; - CertificateValidationContext localCertValidationContext = - mergeStaticAndDynamicCertContexts(); - if (server) { - logger.log(Level.FINEST, "for server"); - sslContextBuilder = - GrpcSslContexts.forServer( - tlsCertificate.getCertificateChain().getInlineBytes().newInput(), - tlsCertificate.getPrivateKey().getInlineBytes().newInput(), - tlsCertificate.hasPassword() - ? tlsCertificate.getPassword().getInlineString() - : null); - setClientAuthValues(sslContextBuilder, localCertValidationContext); - } else { - logger.log(Level.FINEST, "for client"); - sslContextBuilder = - GrpcSslContexts.forClient() - .trustManager(new SdsTrustManagerFactory(localCertValidationContext)); - if (tlsCertificate != null) { - sslContextBuilder.keyManager( - tlsCertificate.getCertificateChain().getInlineBytes().newInput(), - tlsCertificate.getPrivateKey().getInlineBytes().newInput(), - tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null); - } - } + CertificateValidationContext localCertValidationContext = mergeStaticAndDynamicCertContexts(); + SslContextBuilder sslContextBuilder = getSslContextBuilder(localCertValidationContext); CommonTlsContext commonTlsContext = getCommonTlsContext(); if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) { List alpnList = commonTlsContext.getAlpnProtocolsList(); diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java new file mode 100644 index 0000000000..590b79a48b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java @@ -0,0 +1,125 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.sds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext; +import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext; +import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateTlsCertificate; + +import com.google.common.annotations.VisibleForTesting; +import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; +import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; +import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; +import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.File; +import java.io.IOException; +import java.security.cert.CertStoreException; +import java.security.cert.CertificateException; +import java.util.concurrent.Executor; +import javax.annotation.Nullable; + +/** A client SslContext provider that uses file-based secrets (secret volume). */ +final class SecretVolumeClientSslContextProvider extends SslContextProvider { + + @Nullable private final String privateKey; + @Nullable private final String privateKeyPassword; + @Nullable private final String certificateChain; + @Nullable private final CertificateValidationContext certContext; + + private SecretVolumeClientSslContextProvider( + @Nullable String privateKey, + @Nullable String privateKeyPassword, + @Nullable String certificateChain, + @Nullable CertificateValidationContext certContext, + UpstreamTlsContext upstreamTlsContext) { + super(new UpstreamTlsContextHolder(upstreamTlsContext)); + this.privateKey = privateKey; + this.privateKeyPassword = privateKeyPassword; + this.certificateChain = certificateChain; + this.certContext = certContext; + } + + static SecretVolumeClientSslContextProvider getProvider(UpstreamTlsContext upstreamTlsContext) { + checkNotNull(upstreamTlsContext, "upstreamTlsContext"); + CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); + CertificateValidationContext certificateValidationContext = + getCertificateValidationContext(commonTlsContext); + // first validate + validateCertificateContext(certificateValidationContext, /* optional= */ false); + TlsCertificate tlsCertificate = null; + if (commonTlsContext.getTlsCertificatesCount() > 0) { + tlsCertificate = commonTlsContext.getTlsCertificates(0); + } + // tlsCertificate exists in case of mTLS, else null for a client + if (tlsCertificate != null) { + tlsCertificate = validateTlsCertificate(tlsCertificate, /* optional= */ true); + } + String privateKey = null; + String privateKeyPassword = null; + String certificateChain = null; + if (tlsCertificate != null) { + privateKey = tlsCertificate.getPrivateKey().getFilename(); + if (tlsCertificate.hasPassword()) { + privateKeyPassword = tlsCertificate.getPassword().getInlineString(); + } + certificateChain = tlsCertificate.getCertificateChain().getFilename(); + } + return new SecretVolumeClientSslContextProvider( + privateKey, + privateKeyPassword, + certificateChain, + certificateValidationContext, + upstreamTlsContext); + } + + @Override + public void addCallback(final Callback callback, Executor executor) { + checkNotNull(callback, "callback"); + checkNotNull(executor, "executor"); + // as per the contract we will read the current secrets on disk + // this involves I/O which can potentially block the executor + performCallback( + new SslContextGetter() { + @Override + public SslContext get() throws CertificateException, IOException, CertStoreException { + return buildSslContextFromSecrets(); + } + }, + callback, + executor); + } + + @Override + public void close() {} + + @VisibleForTesting + SslContext buildSslContextFromSecrets() + throws IOException, CertificateException, CertStoreException { + SslContextBuilder sslContextBuilder = + GrpcSslContexts.forClient().trustManager(new SdsTrustManagerFactory(certContext)); + if (privateKey != null && certificateChain != null) { + sslContextBuilder.keyManager( + new File(certificateChain), new File(privateKey), privateKeyPassword); + } + return sslContextBuilder.build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java new file mode 100644 index 0000000000..393c82055d --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java @@ -0,0 +1,116 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.sds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext; +import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext; +import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateTlsCertificate; + +import com.google.common.annotations.VisibleForTesting; +import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; +import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; +import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; +import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; +import io.grpc.netty.GrpcSslContexts; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.File; +import java.io.IOException; +import java.security.cert.CertStoreException; +import java.security.cert.CertificateException; +import java.util.concurrent.Executor; +import javax.annotation.Nullable; + +/** A server SslContext provider that uses file-based secrets (secret volume). */ +final class SecretVolumeServerSslContextProvider extends SslContextProvider { + + @Nullable private final String privateKey; + @Nullable private final String privateKeyPassword; + @Nullable private final String certificateChain; + @Nullable private final CertificateValidationContext certContext; + + private SecretVolumeServerSslContextProvider( + @Nullable String privateKey, + @Nullable String privateKeyPassword, + @Nullable String certificateChain, + @Nullable CertificateValidationContext certContext, + DownstreamTlsContext downstreamTlsContext) { + super(new DownstreamTlsContextHolder(downstreamTlsContext)); + this.privateKey = privateKey; + this.privateKeyPassword = privateKeyPassword; + this.certificateChain = certificateChain; + this.certContext = certContext; + } + + static SecretVolumeServerSslContextProvider getProvider( + DownstreamTlsContext downstreamTlsContext) { + checkNotNull(downstreamTlsContext, "downstreamTlsContext"); + CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); + TlsCertificate tlsCertificate = null; + if (commonTlsContext.getTlsCertificatesCount() > 0) { + tlsCertificate = commonTlsContext.getTlsCertificates(0); + } + // first validate + validateTlsCertificate(tlsCertificate, /* optional= */ false); + CertificateValidationContext certificateValidationContext = + getCertificateValidationContext(commonTlsContext); + // certificateValidationContext exists in case of mTLS, else null for a server + if (certificateValidationContext != null) { + certificateValidationContext = + validateCertificateContext(certificateValidationContext, /* optional= */ true); + } + String privateKeyPassword = + tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null; + return new SecretVolumeServerSslContextProvider( + tlsCertificate.getPrivateKey().getFilename(), + privateKeyPassword, + tlsCertificate.getCertificateChain().getFilename(), + certificateValidationContext, + downstreamTlsContext); + } + + @Override + public void addCallback(final Callback callback, Executor executor) { + checkNotNull(callback, "callback"); + checkNotNull(executor, "executor"); + // as per the contract we will read the current secrets on disk + // this involves I/O which can potentially block the executor + performCallback( + new SslContextGetter() { + @Override + public SslContext get() throws CertificateException, IOException, CertStoreException { + return buildSslContextFromSecrets(); + } + }, + callback, + executor); + } + + @Override + public void close() {} + + @VisibleForTesting + SslContext buildSslContextFromSecrets() + throws IOException, CertificateException, CertStoreException { + SslContextBuilder sslContextBuilder = + GrpcSslContexts.forServer( + new File(certificateChain), new File(privateKey), privateKeyPassword); + setClientAuthValues(sslContextBuilder, certContext); + return sslContextBuilder.build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProvider.java deleted file mode 100644 index c545c81979..0000000000 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProvider.java +++ /dev/null @@ -1,219 +0,0 @@ -/* - * Copyright 2019 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.sds; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; - -import com.google.common.annotations.VisibleForTesting; -import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; -import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; -import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.ValidationContextTypeCase; -import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; -import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; -import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; -import io.envoyproxy.envoy.api.v2.core.DataSource.SpecifierCase; -import io.grpc.netty.GrpcSslContexts; -import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory; -import io.netty.handler.ssl.SslContext; -import io.netty.handler.ssl.SslContextBuilder; -import java.io.File; -import java.io.IOException; -import java.security.cert.CertStoreException; -import java.security.cert.CertificateException; -import java.util.concurrent.Executor; -import javax.annotation.Nullable; - -/** - * An SslContext provider that uses file-based secrets (secret volume). Used for both server and - * client SslContexts - */ -final class SecretVolumeSslContextProvider extends SslContextProvider { - - @Nullable private final String privateKey; - @Nullable private final String privateKeyPassword; - @Nullable private final String certificateChain; - @Nullable private final CertificateValidationContext certContext; - - private SecretVolumeSslContextProvider( - @Nullable String privateKey, - @Nullable String privateKeyPassword, - @Nullable String certificateChain, - @Nullable CertificateValidationContext certContext, - boolean server, - K source) { - super(source, server); - this.privateKey = privateKey; - this.privateKeyPassword = privateKeyPassword; - this.certificateChain = certificateChain; - this.certContext = certContext; - } - - @VisibleForTesting - @Nullable - static CertificateValidationContext validateCertificateContext( - @Nullable CertificateValidationContext certContext, boolean optional) { - if (certContext == null || !certContext.hasTrustedCa()) { - checkArgument(optional, "certContext is required"); - return null; - } - checkArgument( - certContext.getTrustedCa().getSpecifierCase() == SpecifierCase.FILENAME, - "filename expected"); - return certContext; - } - - @VisibleForTesting - @Nullable - static TlsCertificate validateTlsCertificate( - @Nullable TlsCertificate tlsCertificate, boolean optional) { - if (tlsCertificate == null) { - checkArgument(optional, "tlsCertificate is required"); - return null; - } - if (optional - && (tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.SPECIFIER_NOT_SET) - && (tlsCertificate.getCertificateChain().getSpecifierCase() - == SpecifierCase.SPECIFIER_NOT_SET)) { - return null; - } - checkArgument( - tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.FILENAME, - "filename expected"); - checkArgument( - tlsCertificate.getCertificateChain().getSpecifierCase() == SpecifierCase.FILENAME, - "filename expected"); - return tlsCertificate; - } - - static SecretVolumeSslContextProvider getProviderForServer( - DownstreamTlsContext downstreamTlsContext) { - checkNotNull(downstreamTlsContext, "downstreamTlsContext"); - CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); - TlsCertificate tlsCertificate = null; - if (commonTlsContext.getTlsCertificatesCount() > 0) { - tlsCertificate = commonTlsContext.getTlsCertificates(0); - } - // first validate - validateTlsCertificate(tlsCertificate, /* optional= */ false); - CertificateValidationContext certificateValidationContext = - getCertificateValidationContext(commonTlsContext); - // certificateValidationContext exists in case of mTLS, else null for a server - if (certificateValidationContext != null) { - certificateValidationContext = - validateCertificateContext(certificateValidationContext, /* optional= */ true); - } - String privateKeyPassword = - tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null; - return new SecretVolumeSslContextProvider<>( - tlsCertificate.getPrivateKey().getFilename(), - privateKeyPassword, - tlsCertificate.getCertificateChain().getFilename(), - certificateValidationContext, - /* server= */ true, - downstreamTlsContext); - } - - static SecretVolumeSslContextProvider getProviderForClient( - UpstreamTlsContext upstreamTlsContext) { - checkNotNull(upstreamTlsContext, "upstreamTlsContext"); - CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); - CertificateValidationContext certificateValidationContext = - getCertificateValidationContext(commonTlsContext); - // first validate - validateCertificateContext(certificateValidationContext, /* optional= */ false); - TlsCertificate tlsCertificate = null; - if (commonTlsContext.getTlsCertificatesCount() > 0) { - tlsCertificate = commonTlsContext.getTlsCertificates(0); - } - // tlsCertificate exists in case of mTLS, else null for a client - if (tlsCertificate != null) { - tlsCertificate = validateTlsCertificate(tlsCertificate, /* optional= */ true); - } - String privateKey = null; - String privateKeyPassword = null; - String certificateChain = null; - if (tlsCertificate != null) { - privateKey = tlsCertificate.getPrivateKey().getFilename(); - if (tlsCertificate.hasPassword()) { - privateKeyPassword = tlsCertificate.getPassword().getInlineString(); - } - certificateChain = tlsCertificate.getCertificateChain().getFilename(); - } - return new SecretVolumeSslContextProvider<>( - privateKey, - privateKeyPassword, - certificateChain, - certificateValidationContext, - /* server= */ false, - upstreamTlsContext); - } - - private static CertificateValidationContext getCertificateValidationContext( - CommonTlsContext commonTlsContext) { - checkNotNull(commonTlsContext, "commonTlsContext"); - ValidationContextTypeCase type = commonTlsContext.getValidationContextTypeCase(); - checkState( - type == ValidationContextTypeCase.VALIDATION_CONTEXT - || type == ValidationContextTypeCase.VALIDATIONCONTEXTTYPE_NOT_SET, - "incorrect ValidationContextTypeCase"); - return type == ValidationContextTypeCase.VALIDATION_CONTEXT - ? commonTlsContext.getValidationContext() - : null; - } - - @Override - public void addCallback(final Callback callback, Executor executor) { - checkNotNull(callback, "callback"); - checkNotNull(executor, "executor"); - // as per the contract we will read the current secrets on disk - // this involves I/O which can potentially block the executor - performCallback( - new SslContextGetter() { - @Override - public SslContext get() throws CertificateException, IOException, CertStoreException { - return buildSslContextFromSecrets(); - } - }, - callback, - executor); - } - - @Override - public void close() {} - - @VisibleForTesting - SslContext buildSslContextFromSecrets() - throws IOException, CertificateException, CertStoreException { - SslContextBuilder sslContextBuilder; - if (server) { - sslContextBuilder = - GrpcSslContexts.forServer( - new File(certificateChain), new File(privateKey), privateKeyPassword); - setClientAuthValues(sslContextBuilder, certContext); - } else { - sslContextBuilder = - GrpcSslContexts.forClient().trustManager(new SdsTrustManagerFactory(certContext)); - if (privateKey != null && certificateChain != null) { - sslContextBuilder.keyManager( - new File(certificateChain), new File(privateKey), privateKeyPassword); - } - } - return sslContextBuilder.build(); - } -} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java index 24289b5cbf..36f3254c55 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java @@ -30,9 +30,9 @@ import java.util.concurrent.Executors; final class ServerSslContextProviderFactory implements SslContextProviderFactory { - /** Creates an SslContextProvider from the given DownstreamTlsContext. */ + /** Creates a SslContextProvider from the given DownstreamTlsContext. */ @Override - public SslContextProvider createSslContextProvider( + public SslContextProvider createSslContextProvider( DownstreamTlsContext downstreamTlsContext) { checkNotNull(downstreamTlsContext, "downstreamTlsContext"); checkArgument( @@ -40,11 +40,11 @@ final class ServerSslContextProviderFactory "downstreamTlsContext should have CommonTlsContext"); if (CommonTlsContextUtil.hasAllSecretsUsingFilename( downstreamTlsContext.getCommonTlsContext())) { - return SecretVolumeSslContextProvider.getProviderForServer(downstreamTlsContext); + return SecretVolumeServerSslContextProvider.getProvider(downstreamTlsContext); } else if (CommonTlsContextUtil.hasAllSecretsUsingSds( downstreamTlsContext.getCommonTlsContext())) { try { - return SdsSslContextProvider.getProviderForServer( + return SdsServerSslContextProvider.getProvider( downstreamTlsContext, Bootstrapper.getInstance().readBootstrap().getNode(), Executors.newSingleThreadExecutor(new ThreadFactoryBuilder() diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java index 4ec4e6cefa..d7f60a4c1e 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java @@ -16,7 +16,6 @@ package io.grpc.xds.internal.sds; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; @@ -41,14 +40,11 @@ import java.util.logging.Logger; * stream that is receiving the requested secret(s) or it could represent file-system based * secret(s) that are dynamic. */ -// TODO(sanjaypujare): replace generic K with DownstreamTlsContext & UpstreamTlsContext in -// separate client&server classes -public abstract class SslContextProvider { +public abstract class SslContextProvider { private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName()); - protected final boolean server; - private final K source; + protected final TlsContextHolder tlsContextHolder; public interface Callback { /** Informs callee of new/updated SslContext. */ @@ -58,36 +54,20 @@ public abstract class SslContextProvider { void onException(Throwable throwable); } - protected SslContextProvider(K source, boolean server) { - if (server) { - checkArgument(source instanceof DownstreamTlsContext, "expecting DownstreamTlsContext"); - } else { - checkArgument(source instanceof UpstreamTlsContext, "expecting UpstreamTlsContext"); - } - this.source = source; - this.server = server; - } - - public K getSource() { - return source; + SslContextProvider(TlsContextHolder tlsContextHolder) { + this.tlsContextHolder = checkNotNull(tlsContextHolder, "tlsContextHolder"); } CommonTlsContext getCommonTlsContext() { - if (source instanceof UpstreamTlsContext) { - return ((UpstreamTlsContext) source).getCommonTlsContext(); - } else if (source instanceof DownstreamTlsContext) { - return ((DownstreamTlsContext) source).getCommonTlsContext(); - } - return null; + return tlsContextHolder.getCommonTlsContext(); } protected void setClientAuthValues( SslContextBuilder sslContextBuilder, CertificateValidationContext localCertValidationContext) throws CertificateException, IOException, CertStoreException { - checkState(server, "server side SslContextProvider expected"); + DownstreamTlsContext downstreamTlsContext = getDownstreamTlsContext(); if (localCertValidationContext != null) { sslContextBuilder.trustManager(new SdsTrustManagerFactory(localCertValidationContext)); - DownstreamTlsContext downstreamTlsContext = (DownstreamTlsContext)getSource(); sslContextBuilder.clientAuth( downstreamTlsContext.hasRequireClientCertificate() ? ClientAuth.REQUIRE @@ -97,6 +77,20 @@ public abstract class SslContextProvider { } } + /** Returns the DownstreamTlsContext in this SslContextProvider if this is server side. **/ + public DownstreamTlsContext getDownstreamTlsContext() { + checkState(tlsContextHolder instanceof DownstreamTlsContextHolder, + "expected DownstreamTlsContextHolder"); + return ((DownstreamTlsContextHolder) tlsContextHolder).getDownstreamTlsContext(); + } + + /** Returns the UpstreamTlsContext in this SslContextProvider if this is client side. **/ + public UpstreamTlsContext getUpstreamTlsContext() { + checkState(tlsContextHolder instanceof UpstreamTlsContextHolder, + "expected UpstreamTlsContextHolder"); + return ((UpstreamTlsContextHolder) tlsContextHolder).getUpstreamTlsContext(); + } + /** Closes this provider and releases any resources. */ void close() {} @@ -106,7 +100,7 @@ public abstract class SslContextProvider { */ public abstract void addCallback(Callback callback, Executor executor); - protected void performCallback( + final void performCallback( final SslContextGetter sslContextGetter, final Callback callback, Executor executor) { checkNotNull(sslContextGetter, "sslContextGetter"); checkNotNull(callback, "callback"); diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextHolder.java b/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextHolder.java new file mode 100644 index 0000000000..06f69ea5e0 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextHolder.java @@ -0,0 +1,29 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.sds; + +import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; +import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; + +/** + * A holder of {@link UpstreamTlsContext} or + * {@link io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext}. + */ +public interface TlsContextHolder { + + CommonTlsContext getCommonTlsContext(); +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManager.java b/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManager.java index 6dee14353e..6819336cb9 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManager.java @@ -22,11 +22,11 @@ import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; public interface TlsContextManager { /** Creates a SslContextProvider. Used for retrieving a server-side SslContext. */ - SslContextProvider findOrCreateServerSslContextProvider( + SslContextProvider findOrCreateServerSslContextProvider( DownstreamTlsContext downstreamTlsContext); /** Creates a SslContextProvider. Used for retrieving a client-side SslContext. */ - SslContextProvider findOrCreateClientSslContextProvider( + SslContextProvider findOrCreateClientSslContextProvider( UpstreamTlsContext upstreamTlsContext); /** @@ -38,8 +38,7 @@ public interface TlsContextManager { *

Caller must not release a reference more than once. It's advised that you clear the * reference to the instance with the null returned by this method. */ - SslContextProvider releaseClientSslContextProvider( - SslContextProvider sslContextProvider); + SslContextProvider releaseClientSslContextProvider(SslContextProvider sslContextProvider); /** * Releases an instance of the given server-side {@link SslContextProvider}. @@ -50,6 +49,5 @@ public interface TlsContextManager { *

Caller must not release a reference more than once. It's advised that you clear the * reference to the instance with the null returned by this method. */ - SslContextProvider releaseServerSslContextProvider( - SslContextProvider sslContextProvider); + SslContextProvider releaseServerSslContextProvider(SslContextProvider sslContextProvider); } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java b/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java index 84335b0cd5..42d19c4f11 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java @@ -59,30 +59,32 @@ public final class TlsContextManagerImpl implements TlsContextManager { } @Override - public SslContextProvider findOrCreateServerSslContextProvider( + public SslContextProvider findOrCreateServerSslContextProvider( DownstreamTlsContext downstreamTlsContext) { checkNotNull(downstreamTlsContext, "downstreamTlsContext"); return mapForServers.get(downstreamTlsContext); } @Override - public SslContextProvider findOrCreateClientSslContextProvider( + public SslContextProvider findOrCreateClientSslContextProvider( UpstreamTlsContext upstreamTlsContext) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); return mapForClients.get(upstreamTlsContext); } @Override - public SslContextProvider releaseClientSslContextProvider( - SslContextProvider sslContextProvider) { - checkNotNull(sslContextProvider, "sslContextProvider"); - return mapForClients.release(sslContextProvider); + public SslContextProvider releaseClientSslContextProvider( + SslContextProvider clientSslContextProvider) { + checkNotNull(clientSslContextProvider, "clientSslContextProvider"); + return mapForClients.release(clientSslContextProvider.getUpstreamTlsContext(), + clientSslContextProvider); } @Override - public SslContextProvider releaseServerSslContextProvider( - SslContextProvider sslContextProvider) { - checkNotNull(sslContextProvider, "sslContextProvider"); - return mapForServers.release(sslContextProvider); + public SslContextProvider releaseServerSslContextProvider( + SslContextProvider serverSslContextProvider) { + checkNotNull(serverSslContextProvider, "serverSslContextProvider"); + return mapForServers.release(serverSslContextProvider.getDownstreamTlsContext(), + serverSslContextProvider); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/UpstreamTlsContextHolder.java b/xds/src/main/java/io/grpc/xds/internal/sds/UpstreamTlsContextHolder.java new file mode 100644 index 0000000000..3b4ade3642 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/sds/UpstreamTlsContextHolder.java @@ -0,0 +1,40 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.sds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; +import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; + +final class UpstreamTlsContextHolder implements TlsContextHolder { + + private final UpstreamTlsContext upstreamTlsContext; + + UpstreamTlsContextHolder(UpstreamTlsContext upstreamTlsContext) { + this.upstreamTlsContext = checkNotNull(upstreamTlsContext, "upstreamTlsContext"); + } + + public UpstreamTlsContext getUpstreamTlsContext() { + return upstreamTlsContext; + } + + @Override + public CommonTlsContext getCommonTlsContext() { + return upstreamTlsContext.getCommonTlsContext(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java index a137bb365d..7c48b0d838 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java @@ -356,9 +356,8 @@ public class CdsLoadBalancerTest { CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); - SslContextProvider mockSslContextProvider = - (SslContextProvider) mock(SslContextProvider.class); - doReturn(upstreamTlsContext).when(mockSslContextProvider).getSource(); + SslContextProvider mockSslContextProvider = mock(SslContextProvider.class); + doReturn(upstreamTlsContext).when(mockSslContextProvider).getUpstreamTlsContext(); doReturn(mockSslContextProvider).when(mockTlsContextManager) .findOrCreateClientSslContextProvider(same(upstreamTlsContext)); @@ -373,8 +372,8 @@ public class CdsLoadBalancerTest { assertThat(edsLbHelpers).hasSize(1); assertThat(edsLoadBalancers).hasSize(1); - verify(mockTlsContextManager, never()).releaseClientSslContextProvider( - (SslContextProvider) any(SslContextProvider.class)); + verify(mockTlsContextManager, never()) + .releaseClientSslContextProvider(any(SslContextProvider.class)); Helper edsLbHelper1 = edsLbHelpers.poll(); ArrayList eagList = new ArrayList<>(); @@ -403,8 +402,8 @@ public class CdsLoadBalancerTest { .setUpstreamTlsContext(upstreamTlsContext) .build()); - verify(mockTlsContextManager, never()).releaseClientSslContextProvider( - (SslContextProvider) any(SslContextProvider.class)); + verify(mockTlsContextManager, never()) + .releaseClientSslContextProvider(any(SslContextProvider.class)); verify(mockTlsContextManager, never()).findOrCreateClientSslContextProvider( any(UpstreamTlsContext.class)); @@ -414,9 +413,8 @@ public class CdsLoadBalancerTest { UpstreamTlsContext upstreamTlsContext1 = CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE); - SslContextProvider mockSslContextProvider1 = - (SslContextProvider) mock(SslContextProvider.class); - doReturn(upstreamTlsContext1).when(mockSslContextProvider1).getSource(); + SslContextProvider mockSslContextProvider1 = mock(SslContextProvider.class); + doReturn(upstreamTlsContext1).when(mockSslContextProvider1).getUpstreamTlsContext(); doReturn(mockSslContextProvider1).when(mockTlsContextManager) .findOrCreateClientSslContextProvider(same(upstreamTlsContext1)); clusterWatcher1.onClusterChanged( diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java index be86950560..0b592a6860 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java @@ -41,7 +41,7 @@ public class ClientSslContextProviderFactoryTest { CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); - SslContextProvider sslContextProvider = + SslContextProvider sslContextProvider = clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); assertThat(sslContextProvider).isNotNull(); } @@ -55,7 +55,7 @@ public class ClientSslContextProviderFactoryTest { SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext); try { - SslContextProvider unused = + SslContextProvider unused = clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); Assert.fail("no exception thrown"); } catch (UnsupportedOperationException expected) { @@ -77,7 +77,7 @@ public class ClientSslContextProviderFactoryTest { SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext); try { - SslContextProvider unused = + SslContextProvider unused = clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); Assert.fail("no exception thrown"); } catch (UnsupportedOperationException expected) { diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMapTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMapTest.java index 7cab5060d3..e05e2ee50d 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMapTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMapTest.java @@ -51,60 +51,53 @@ public class ReferenceCountingSslContextProviderMapTest { @Test public void referenceCountingMap_getAndRelease_closeCalled() throws InterruptedException { - SslContextProvider valueFor3 = getTypedMock(); + SslContextProvider valueFor3 = getTypedMock(); when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); - SslContextProvider val = map.get(3); + SslContextProvider val = map.get(3); assertThat(val).isSameInstanceAs(valueFor3); verify(valueFor3, never()).close(); val = map.get(3); assertThat(val).isSameInstanceAs(valueFor3); // at this point ref-count is 2 - when(valueFor3.getSource()).thenReturn(3); - assertThat(map.release(val)).isNull(); + assertThat(map.release(3, val)).isNull(); verify(valueFor3, never()).close(); - assertThat(map.release(val)).isNull(); // after this ref-count is 0 + assertThat(map.release(3, val)).isNull(); // after this ref-count is 0 verify(valueFor3, times(1)).close(); } - @SuppressWarnings("unchecked") - private static SslContextProvider getTypedMock() { + private static SslContextProvider getTypedMock() { return mock(SslContextProvider.class); } @Test public void referenceCountingMap_distinctElements() throws InterruptedException { - SslContextProvider valueFor3 = getTypedMock(); - SslContextProvider valueFor4 = getTypedMock(); - when(valueFor3.getSource()).thenReturn(3); - when(valueFor4.getSource()).thenReturn(4); + SslContextProvider valueFor3 = getTypedMock(); + SslContextProvider valueFor4 = getTypedMock(); when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); - SslContextProvider val3 = map.get(3); + SslContextProvider val3 = map.get(3); assertThat(val3).isSameInstanceAs(valueFor3); - SslContextProvider val4 = map.get(4); + SslContextProvider val4 = map.get(4); assertThat(val4).isSameInstanceAs(valueFor4); - assertThat(map.release(val3)).isNull(); + assertThat(map.release(3, val3)).isNull(); verify(valueFor3, times(1)).close(); verify(valueFor4, never()).close(); - assertThat(map.release(val4)).isNull(); + assertThat(map.release(4, val4)).isNull(); verify(valueFor4, times(1)).close(); } @Test public void referenceCountingMap_releaseWrongElement_expectException() throws InterruptedException { - SslContextProvider valueFor3 = getTypedMock(); - SslContextProvider valueFor4 = getTypedMock(); - when(valueFor3.getSource()).thenReturn(3); - when(valueFor4.getSource()).thenReturn(4); + SslContextProvider valueFor3 = getTypedMock(); + SslContextProvider valueFor4 = getTypedMock(); when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); - SslContextProvider unused = map.get(3); - SslContextProvider val4 = map.get(4); + SslContextProvider unused = map.get(3); + SslContextProvider val4 = map.get(4); // now provide wrong key (3) and value (val4) combination - when(valueFor4.getSource()).thenReturn(3); try { - map.release(val4); + map.release(3, val4); fail("exception expected"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().contains("Releasing the wrong instance"); @@ -113,16 +106,15 @@ public class ReferenceCountingSslContextProviderMapTest { @Test public void referenceCountingMap_excessRelease_expectException() throws InterruptedException { - SslContextProvider valueFor4 = getTypedMock(); - when(valueFor4.getSource()).thenReturn(4); + SslContextProvider valueFor4 = getTypedMock(); when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); - SslContextProvider val = map.get(4); + SslContextProvider val = map.get(4); assertThat(val).isSameInstanceAs(valueFor4); // at this point ref-count is 1 - map.release(val); + map.release(4, val); // at this point ref-count is 0 try { - map.release(val); + map.release(4, val); fail("exception expected"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().contains("No cached instance found for 4"); @@ -131,16 +123,15 @@ public class ReferenceCountingSslContextProviderMapTest { @Test public void referenceCountingMap_releaseAndGet_differentInstance() throws InterruptedException { - SslContextProvider valueFor4 = getTypedMock(); - when(valueFor4.getSource()).thenReturn(4); + SslContextProvider valueFor4 = getTypedMock(); when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); - SslContextProvider val = map.get(4); + SslContextProvider val = map.get(4); assertThat(val).isSameInstanceAs(valueFor4); // at this point ref-count is 1 - map.release(val); + map.release(4, val); // at this point ref-count is 0 and val is removed // should get another instance for 4 - SslContextProvider valueFor4a = getTypedMock(); + SslContextProvider valueFor4a = getTypedMock(); when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4a); val = map.get(4); assertThat(val).isSameInstanceAs(valueFor4a); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java index 18759eb282..56054aa4f6 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java @@ -40,7 +40,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for {@link SdsSslContextProvider}. */ +/** Unit tests for {@link SdsClientSslContextProvider}. */ @RunWith(JUnit4.class) public class SdsSslContextProviderTest { @@ -62,10 +62,13 @@ public class SdsSslContextProviderTest { server.shutdown(); } - /** Helper method to build SdsSslContextProvider from given names. */ - private SdsSslContextProvider getSdsSslContextProvider( - boolean server, String certName, String validationContextName, - Iterable verifySubjectAltNames, Iterable alpnProtocols) throws IOException { + /** Helper method to build SdsClientSslContextProvider from given names. */ + private SdsClientSslContextProvider getSdsClientSslContextProvider( + String certName, + String validationContextName, + Iterable verifySubjectAltNames, + Iterable alpnProtocols) + throws IOException { CommonTlsContext commonTlsContext = CommonTlsContextTestsUtil.buildCommonTlsContextWithAdditionalValues( @@ -77,18 +80,37 @@ public class SdsSslContextProviderTest { alpnProtocols, /* channelType= */ "inproc"); - return server - ? SdsSslContextProvider.getProviderForServer( - CommonTlsContextTestsUtil.buildDownstreamTlsContext( - commonTlsContext, /* requireClientCert= */ false), - node, - MoreExecutors.directExecutor(), - MoreExecutors.directExecutor()) - : SdsSslContextProvider.getProviderForClient( - SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext), - node, - MoreExecutors.directExecutor(), - MoreExecutors.directExecutor()); + return SdsClientSslContextProvider.getProvider( + SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext), + node, + MoreExecutors.directExecutor(), + MoreExecutors.directExecutor()); + } + + /** Helper method to build SdsServerSslContextProvider from given names. */ + private SdsServerSslContextProvider getSdsServerSslContextProvider( + String certName, + String validationContextName, + Iterable verifySubjectAltNames, + Iterable alpnProtocols) + throws IOException { + + CommonTlsContext commonTlsContext = + CommonTlsContextTestsUtil.buildCommonTlsContextWithAdditionalValues( + certName, + /* certTargetUri= */ "inproc", + validationContextName, + /* validationContextTargetUri= */ "inproc", + verifySubjectAltNames, + alpnProtocols, + /* channelType= */ "inproc"); + + return SdsServerSslContextProvider.getProvider( + CommonTlsContextTestsUtil.buildDownstreamTlsContext( + commonTlsContext, /* requireClientCert= */ false), + node, + MoreExecutors.directExecutor(), + MoreExecutors.directExecutor()); } @Test @@ -98,8 +120,8 @@ public class SdsSslContextProviderTest { when(serverMock.getSecretFor(/* name= */ "valid1")) .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); - SdsSslContextProvider provider = - getSdsSslContextProvider(/* server= */ true, "cert1", "valid1", null, null); + SdsServerSslContextProvider provider = + getSdsServerSslContextProvider("cert1", "valid1", null, null); SecretVolumeSslContextProviderTest.TestCallback testCallback = SecretVolumeSslContextProviderTest.getValueThruCallback(provider); @@ -113,9 +135,8 @@ public class SdsSslContextProviderTest { when(serverMock.getSecretFor("valid1")) .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); - SdsSslContextProvider provider = - getSdsSslContextProvider( - /* server= */ false, + SdsClientSslContextProvider provider = + getSdsClientSslContextProvider( /* certName= */ "cert1", /* validationContextName= */ "valid1", /* verifySubjectAltNames= */ null, @@ -131,10 +152,12 @@ public class SdsSslContextProviderTest { when(serverMock.getSecretFor(/* name= */ "cert1")) .thenReturn(getOneTlsCertSecret(/* name= */ "cert1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE)); - SdsSslContextProvider provider = - getSdsSslContextProvider( - /* server= */ true, /* certName= */ "cert1", /* validationContextName= */ null, - /* verifySubjectAltNames= */ null, /* alpnProtocols= */ null); + SdsServerSslContextProvider provider = + getSdsServerSslContextProvider( + /* certName= */ "cert1", + /* validationContextName= */ null, + /* verifySubjectAltNames= */ null, + /* alpnProtocols= */ null); SecretVolumeSslContextProviderTest.TestCallback testCallback = SecretVolumeSslContextProviderTest.getValueThruCallback(provider); @@ -146,10 +169,12 @@ public class SdsSslContextProviderTest { when(serverMock.getSecretFor(/* name= */ "valid1")) .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); - SdsSslContextProvider provider = - getSdsSslContextProvider( - /* server= */ false, /* certName= */ null, /* validationContextName= */ "valid1", - /* verifySubjectAltNames= */ null, null); + SdsClientSslContextProvider provider = + getSdsClientSslContextProvider( + /* certName= */ null, + /* validationContextName= */ "valid1", + /* verifySubjectAltNames= */ null, + null); SecretVolumeSslContextProviderTest.TestCallback testCallback = SecretVolumeSslContextProviderTest.getValueThruCallback(provider); @@ -161,10 +186,12 @@ public class SdsSslContextProviderTest { when(serverMock.getSecretFor(/* name= */ "valid1")) .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); - SdsSslContextProvider provider = - getSdsSslContextProvider( - /* server= */ true, /* certName= */ null, /* validationContextName= */ "valid1", - /* verifySubjectAltNames= */ null, /* alpnProtocols= */ null); + SdsServerSslContextProvider provider = + getSdsServerSslContextProvider( + /* certName= */ null, + /* validationContextName= */ "valid1", + /* verifySubjectAltNames= */ null, + /* alpnProtocols= */ null); SecretVolumeSslContextProviderTest.TestCallback testCallback = SecretVolumeSslContextProviderTest.getValueThruCallback(provider); @@ -184,13 +211,11 @@ public class SdsSslContextProviderTest { when(serverMock.getSecretFor("valid1")) .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); - SdsSslContextProvider provider = - getSdsSslContextProvider( - /* server= */ false, + SdsClientSslContextProvider provider = + getSdsClientSslContextProvider( /* certName= */ "cert1", /* validationContextName= */ "valid1", - Arrays.asList( - "spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"), + Arrays.asList("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"), /* alpnProtocols= */ null); SecretVolumeSslContextProviderTest.TestCallback testCallback = @@ -205,9 +230,8 @@ public class SdsSslContextProviderTest { when(serverMock.getSecretFor("valid1")) .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); - SdsSslContextProvider provider = - getSdsSslContextProvider( - /* server= */ false, + SdsClientSslContextProvider provider = + getSdsClientSslContextProvider( /* certName= */ "cert1", /* validationContextName= */ "valid1", /* verifySubjectAltNames= */ null, @@ -226,9 +250,8 @@ public class SdsSslContextProviderTest { when(serverMock.getSecretFor(/* name= */ "valid1")) .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); - SdsSslContextProvider provider = - getSdsSslContextProvider( - /* server= */ true, + SdsServerSslContextProvider provider = + getSdsServerSslContextProvider( /* certName= */ "cert1", /* validationContextName= */ "valid1", /* verifySubjectAltNames= */ null, diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java index 86c4b816e5..e59affabec 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java @@ -41,7 +41,7 @@ import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for {@link SecretVolumeSslContextProvider}. */ +/** Unit tests for {@link SecretVolumeClientSslContextProvider}. */ @RunWith(JUnit4.class) public class SecretVolumeSslContextProviderTest { @@ -51,7 +51,7 @@ public class SecretVolumeSslContextProviderTest { public void validateCertificateContext_nullAndNotOptional_throwsException() { // expect exception when certContext is null and not optional try { - SecretVolumeSslContextProvider.validateCertificateContext( + CommonTlsContextUtil.validateCertificateContext( /* certContext= */ null, /* optional= */ false); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { @@ -64,8 +64,7 @@ public class SecretVolumeSslContextProviderTest { // expect exception when certContext has no CA and not optional CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance(); try { - SecretVolumeSslContextProvider.validateCertificateContext( - certContext, /* optional= */ false); + CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("certContext is required"); @@ -76,7 +75,7 @@ public class SecretVolumeSslContextProviderTest { public void validateCertificateContext_nullAndOptional() { // certContext argument can be null when optional CertificateValidationContext certContext = - SecretVolumeSslContextProvider.validateCertificateContext( + CommonTlsContextUtil.validateCertificateContext( /* certContext= */ null, /* optional= */ true); assertThat(certContext).isNull(); } @@ -85,9 +84,7 @@ public class SecretVolumeSslContextProviderTest { public void validateCertificateContext_missingTrustCaOptional() { // certContext argument can have missing CA when optional CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance(); - assertThat( - SecretVolumeSslContextProvider.validateCertificateContext( - certContext, /* optional= */ true)) + assertThat(CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ true)) .isNull(); } @@ -99,8 +96,7 @@ public class SecretVolumeSslContextProviderTest { .setTrustedCa(DataSource.newBuilder().setInlineString("foo")) .build(); try { - SecretVolumeSslContextProvider.validateCertificateContext( - certContext, /* optional= */ false); + CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("filename expected"); @@ -114,9 +110,7 @@ public class SecretVolumeSslContextProviderTest { CertificateValidationContext.newBuilder() .setTrustedCa(DataSource.newBuilder().setFilename("bar")) .build(); - assertThat( - SecretVolumeSslContextProvider.validateCertificateContext( - certContext, /* optional= */ false)) + assertThat(CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false)) .isSameInstanceAs(certContext); } @@ -124,7 +118,7 @@ public class SecretVolumeSslContextProviderTest { public void validateTlsCertificate_nullAndNotOptional_throwsException() { // expect exception when tlsCertificate is null and not optional try { - SecretVolumeSslContextProvider.validateTlsCertificate( + CommonTlsContextUtil.validateTlsCertificate( /* tlsCertificate= */ null, /* optional= */ false); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { @@ -135,7 +129,7 @@ public class SecretVolumeSslContextProviderTest { @Test public void validateTlsCertificate_nullOptional() { assertThat( - SecretVolumeSslContextProvider.validateTlsCertificate( + CommonTlsContextUtil.validateTlsCertificate( /* tlsCertificate= */ null, /* optional= */ true)) .isNull(); } @@ -144,10 +138,7 @@ public class SecretVolumeSslContextProviderTest { public void validateTlsCertificate_defaultInstance_returnsNull() { // tlsCertificate is not null but has no value (default instance): expect null TlsCertificate tlsCert = TlsCertificate.getDefaultInstance(); - assertThat( - SecretVolumeSslContextProvider.validateTlsCertificate( - tlsCert, /* optional= */ true)) - .isNull(); + assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true)).isNull(); } @Test @@ -158,7 +149,7 @@ public class SecretVolumeSslContextProviderTest { .setPrivateKey(DataSource.newBuilder().setInlineString("foo")) .build(); try { - SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ false); + CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("filename expected"); @@ -173,7 +164,7 @@ public class SecretVolumeSslContextProviderTest { .setPrivateKey(DataSource.newBuilder().setInlineString("foo")) .build(); try { - SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true); + CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("filename expected"); @@ -188,7 +179,7 @@ public class SecretVolumeSslContextProviderTest { .setCertificateChain(DataSource.newBuilder().setInlineString("foo")) .build(); try { - SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ false); + CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("filename expected"); @@ -203,7 +194,7 @@ public class SecretVolumeSslContextProviderTest { .setCertificateChain(DataSource.newBuilder().setInlineString("foo")) .build(); try { - SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true); + CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("filename expected"); @@ -217,9 +208,7 @@ public class SecretVolumeSslContextProviderTest { .setCertificateChain(DataSource.newBuilder().setFilename("foo")) .setPrivateKey(DataSource.newBuilder().setFilename("bar")) .build(); - assertThat( - SecretVolumeSslContextProvider.validateTlsCertificate( - tlsCert, /* optional= */ true)) + assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true)) .isSameInstanceAs(tlsCert); } @@ -230,9 +219,7 @@ public class SecretVolumeSslContextProviderTest { .setCertificateChain(DataSource.newBuilder().setFilename("foo")) .setPrivateKey(DataSource.newBuilder().setFilename("bar")) .build(); - assertThat( - SecretVolumeSslContextProvider.validateTlsCertificate( - tlsCert, /* optional= */ false)) + assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false)) .isSameInstanceAs(tlsCert); } @@ -245,7 +232,7 @@ public class SecretVolumeSslContextProviderTest { .setPrivateKey(DataSource.newBuilder().setFilename("bar")) .build(); try { - SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true); + CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("filename expected"); @@ -261,7 +248,7 @@ public class SecretVolumeSslContextProviderTest { .setCertificateChain(DataSource.newBuilder().setFilename("bar")) .build(); try { - SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true); + CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("filename expected"); @@ -272,7 +259,7 @@ public class SecretVolumeSslContextProviderTest { public void getProviderForServer_defaultTlsCertificate_throwsException() { TlsCertificate tlsCert = TlsCertificate.getDefaultInstance(); try { - SecretVolumeSslContextProvider.getProviderForServer( + SecretVolumeServerSslContextProvider.getProvider( CommonTlsContextTestsUtil.buildDownstreamTlsContext( CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null), /* requireClientCert= */ false)); @@ -294,7 +281,7 @@ public class SecretVolumeSslContextProviderTest { .setTrustedCa(DataSource.newBuilder().setInlineString("foo")) .build(); try { - SecretVolumeSslContextProvider.getProviderForServer( + SecretVolumeServerSslContextProvider.getProvider( CommonTlsContextTestsUtil.buildDownstreamTlsContext( CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext), /* requireClientCert= */ false)); @@ -308,7 +295,7 @@ public class SecretVolumeSslContextProviderTest { public void getProviderForClient_defaultCertContext_throwsException() { CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance(); try { - SecretVolumeSslContextProvider.getProviderForClient( + SecretVolumeClientSslContextProvider.getProvider( buildUpstreamTlsContext( CommonTlsContextTestsUtil.getCommonTlsContext( /* tlsCertificate= */ null, certContext))); @@ -330,7 +317,7 @@ public class SecretVolumeSslContextProviderTest { .setTrustedCa(DataSource.newBuilder().setFilename("foo")) .build(); try { - SecretVolumeSslContextProvider.getProviderForClient( + SecretVolumeClientSslContextProvider.getProvider( buildUpstreamTlsContext( CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext))); Assert.fail("no exception thrown"); @@ -351,7 +338,7 @@ public class SecretVolumeSslContextProviderTest { .setTrustedCa(DataSource.newBuilder().setFilename("foo")) .build(); try { - SecretVolumeSslContextProvider.getProviderForClient( + SecretVolumeClientSslContextProvider.getProvider( buildUpstreamTlsContext( CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext))); Assert.fail("no exception thrown"); @@ -360,22 +347,6 @@ public class SecretVolumeSslContextProviderTest { } } - /** Helper method to build SecretVolumeSslContextProvider from given files. */ - private static SecretVolumeSslContextProvider getSslContextSecretVolumeSecretProvider( - boolean server, - String certChainFilename, - String privateKeyFilename, - String trustedCaFilename) { - - return server - ? SecretVolumeSslContextProvider.getProviderForServer( - CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( - privateKeyFilename, certChainFilename, trustedCaFilename)) - : SecretVolumeSslContextProvider.getProviderForClient( - CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( - privateKeyFilename, certChainFilename, trustedCaFilename)); - } - /** * Helper method to build SecretVolumeSslContextProvider, call buildSslContext on it and * check returned SslContext. @@ -383,10 +354,22 @@ public class SecretVolumeSslContextProviderTest { private static void sslContextForEitherWithBothCertAndTrust( boolean server, String pemFile, String keyFile, String caFile) throws IOException, CertificateException, CertStoreException { - SecretVolumeSslContextProvider provider = - getSslContextSecretVolumeSecretProvider(server, pemFile, keyFile, caFile); + SslContext sslContext = null; + if (server) { + SecretVolumeServerSslContextProvider provider = + SecretVolumeServerSslContextProvider.getProvider( + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( + keyFile, pemFile, caFile)); - SslContext sslContext = provider.buildSslContextFromSecrets(); + sslContext = provider.buildSslContextFromSecrets(); + } else { + SecretVolumeClientSslContextProvider provider = + SecretVolumeClientSslContextProvider.getProvider( + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + keyFile, pemFile, caFile)); + + sslContext = provider.buildSslContextFromSecrets(); + } doChecksOnSslContext(server, sslContext, /* expectedApnProtos= */ null); } @@ -469,7 +452,7 @@ public class SecretVolumeSslContextProviderTest { * Helper method to get the value thru directExecutor callback. Because of directExecutor this is * a synchronous callback - so need to provide a listener. */ - static TestCallback getValueThruCallback(SslContextProvider provider) { + static TestCallback getValueThruCallback(SslContextProvider provider) { TestCallback testCallback = new TestCallback(); provider.addCallback(testCallback, MoreExecutors.directExecutor()); return testCallback; @@ -477,9 +460,10 @@ public class SecretVolumeSslContextProviderTest { @Test public void getProviderForServer_both_callsback() throws IOException { - SecretVolumeSslContextProvider provider = - getSslContextSecretVolumeSecretProvider( - true, SERVER_1_PEM_FILE, SERVER_1_KEY_FILE, CA_PEM_FILE); + SecretVolumeServerSslContextProvider provider = + SecretVolumeServerSslContextProvider.getProvider( + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( + SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE)); TestCallback testCallback = getValueThruCallback(provider); doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null); @@ -487,9 +471,10 @@ public class SecretVolumeSslContextProviderTest { @Test public void getProviderForClient_both_callsback() throws IOException { - SecretVolumeSslContextProvider provider = - getSslContextSecretVolumeSecretProvider( - false, CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE); + SecretVolumeClientSslContextProvider provider = + SecretVolumeClientSslContextProvider.getProvider( + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE)); TestCallback testCallback = getValueThruCallback(provider); doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); @@ -498,9 +483,10 @@ public class SecretVolumeSslContextProviderTest { // note this test generates stack-trace but can be safely ignored @Test public void getProviderForClient_both_callsback_setException() throws IOException { - SecretVolumeSslContextProvider provider = - getSslContextSecretVolumeSecretProvider( - false, CLIENT_PEM_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); + SecretVolumeClientSslContextProvider provider = + SecretVolumeClientSslContextProvider.getProvider( + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + CLIENT_PEM_FILE, CLIENT_PEM_FILE, CA_PEM_FILE)); TestCallback testCallback = getValueThruCallback(provider); assertThat(testCallback.updatedSslContext).isNull(); assertThat(testCallback.updatedThrowable).isInstanceOf(IllegalArgumentException.class); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java index ad8542f9d7..bc9bb47f42 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java @@ -41,7 +41,7 @@ public class ServerSslContextProviderFactoryTest { CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); - SslContextProvider sslContextProvider = + SslContextProvider sslContextProvider = serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); assertThat(sslContextProvider).isNotNull(); } @@ -56,7 +56,7 @@ public class ServerSslContextProviderFactoryTest { commonTlsContext, /* requireClientCert= */ false); try { - SslContextProvider unused = + SslContextProvider unused = serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); Assert.fail("no exception thrown"); } catch (UnsupportedOperationException expected) { @@ -76,7 +76,7 @@ public class ServerSslContextProviderFactoryTest { commonTlsContext, /* requireClientCert= */ false); try { - SslContextProvider unused = + SslContextProvider unused = serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); Assert.fail("no exception thrown"); } catch (UnsupportedOperationException expected) { diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java index 0c3617f11e..85ff77d167 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java @@ -49,11 +49,9 @@ public class TlsContextManagerTest { @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); - @Mock - SslContextProviderFactory mockClientFactory; + @Mock SslContextProviderFactory mockClientFactory; - @Mock - SslContextProviderFactory mockServerFactory; + @Mock SslContextProviderFactory mockServerFactory; @Before public void clearInstance() throws NoSuchFieldException, IllegalAccessException { @@ -69,11 +67,11 @@ public class TlsContextManagerTest { SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); - SslContextProvider serverSecretProvider = + SslContextProvider serverSecretProvider = tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); assertThat(serverSecretProvider).isNotNull(); - SslContextProvider serverSecretProvider1 = + SslContextProvider serverSecretProvider1 = tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); assertThat(serverSecretProvider1).isSameInstanceAs(serverSecretProvider); } @@ -85,11 +83,11 @@ public class TlsContextManagerTest { /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); - SslContextProvider clientSecretProvider = + SslContextProvider clientSecretProvider = tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); assertThat(clientSecretProvider).isNotNull(); - SslContextProvider clientSecretProvider1 = + SslContextProvider clientSecretProvider1 = tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); assertThat(clientSecretProvider1).isSameInstanceAs(clientSecretProvider); } @@ -101,14 +99,14 @@ public class TlsContextManagerTest { SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); - SslContextProvider serverSecretProvider = + SslContextProvider serverSecretProvider = tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); assertThat(serverSecretProvider).isNotNull(); DownstreamTlsContext downstreamTlsContext1 = CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( SERVER_0_KEY_FILE, SERVER_0_PEM_FILE, CA_PEM_FILE); - SslContextProvider serverSecretProvider1 = + SslContextProvider serverSecretProvider1 = tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext1); assertThat(serverSecretProvider1).isNotNull(); assertThat(serverSecretProvider1).isNotSameInstanceAs(serverSecretProvider); @@ -121,7 +119,7 @@ public class TlsContextManagerTest { /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); - SslContextProvider clientSecretProvider = + SslContextProvider clientSecretProvider = tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); assertThat(clientSecretProvider).isNotNull(); @@ -129,7 +127,7 @@ public class TlsContextManagerTest { CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); - SslContextProvider clientSecretProvider1 = + SslContextProvider clientSecretProvider1 = tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1); assertThat(clientSecretProvider1).isNotSameInstanceAs(clientSecretProvider); } @@ -143,13 +141,13 @@ public class TlsContextManagerTest { TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(mockClientFactory, mockServerFactory); @SuppressWarnings("unchecked") - SslContextProvider mockProvider = mock(SslContextProvider.class); + SslContextProvider mockProvider = mock(SslContextProvider.class); when(mockServerFactory.createSslContextProvider(downstreamTlsContext)).thenReturn(mockProvider); - SslContextProvider serverSecretProvider = + SslContextProvider serverSecretProvider = tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); assertThat(serverSecretProvider).isSameInstanceAs(mockProvider); verify(mockProvider, never()).close(); - when(mockProvider.getSource()).thenReturn(downstreamTlsContext); + when(mockProvider.getDownstreamTlsContext()).thenReturn(downstreamTlsContext); tlsContextManagerImpl.releaseServerSslContextProvider(mockProvider); verify(mockProvider, times(1)).close(); } @@ -163,13 +161,13 @@ public class TlsContextManagerTest { TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(mockClientFactory, mockServerFactory); @SuppressWarnings("unchecked") - SslContextProvider mockProvider = mock(SslContextProvider.class); + SslContextProvider mockProvider = mock(SslContextProvider.class); when(mockClientFactory.createSslContextProvider(upstreamTlsContext)).thenReturn(mockProvider); - SslContextProvider clientSecretProvider = + SslContextProvider clientSecretProvider = tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); assertThat(clientSecretProvider).isSameInstanceAs(mockProvider); verify(mockProvider, never()).close(); - when(mockProvider.getSource()).thenReturn(upstreamTlsContext); + when(mockProvider.getUpstreamTlsContext()).thenReturn(upstreamTlsContext); tlsContextManagerImpl.releaseClientSslContextProvider(mockProvider); verify(mockProvider, times(1)).close(); }