xds: replace generic with individual client and server SslContextProviders (#7059)

This commit is contained in:
sanjaypujare 2020-05-27 12:31:54 -07:00 committed by GitHub
parent 7d2d2ec035
commit 62620ccd00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 858 additions and 579 deletions

View File

@ -206,10 +206,10 @@ public final class CdsLoadBalancer extends LoadBalancer {
private static final class EdsLoadBalancingHelper extends ForwardingLoadBalancerHelper { private static final class EdsLoadBalancingHelper extends ForwardingLoadBalancerHelper {
private final Helper delegate; private final Helper delegate;
private final AtomicReference<SslContextProvider<UpstreamTlsContext>> sslContextProvider; private final AtomicReference<SslContextProvider> sslContextProvider;
EdsLoadBalancingHelper(Helper helper, EdsLoadBalancingHelper(Helper helper,
AtomicReference<SslContextProvider<UpstreamTlsContext>> sslContextProvider) { AtomicReference<SslContextProvider> sslContextProvider) {
this.delegate = helper; this.delegate = helper;
this.sslContextProvider = sslContextProvider; this.sslContextProvider = sslContextProvider;
} }
@ -222,7 +222,7 @@ public final class CdsLoadBalancer extends LoadBalancer {
.toBuilder() .toBuilder()
.setAddresses( .setAddresses(
addUpstreamTlsContext(createSubchannelArgs.getAddresses(), addUpstreamTlsContext(createSubchannelArgs.getAddresses(),
sslContextProvider.get().getSource())) sslContextProvider.get().getUpstreamTlsContext()))
.build(); .build();
} }
return delegate.createSubchannel(createSubchannelArgs); return delegate.createSubchannel(createSubchannelArgs);
@ -267,7 +267,7 @@ public final class CdsLoadBalancer extends LoadBalancer {
ClusterWatcherImpl(Helper helper, ResolvedAddresses resolvedAddresses) { ClusterWatcherImpl(Helper helper, ResolvedAddresses resolvedAddresses) {
this.helper = new EdsLoadBalancingHelper(helper, this.helper = new EdsLoadBalancingHelper(helper,
new AtomicReference<SslContextProvider<UpstreamTlsContext>>()); new AtomicReference<SslContextProvider>());
this.resolvedAddresses = resolvedAddresses; this.resolvedAddresses = resolvedAddresses;
} }
@ -303,10 +303,10 @@ public final class CdsLoadBalancer extends LoadBalancer {
/** For new UpstreamTlsContext value, release old SslContextProvider. */ /** For new UpstreamTlsContext value, release old SslContextProvider. */
private void updateSslContextProvider(UpstreamTlsContext newUpstreamTlsContext) { private void updateSslContextProvider(UpstreamTlsContext newUpstreamTlsContext) {
SslContextProvider<UpstreamTlsContext> oldSslContextProvider = SslContextProvider oldSslContextProvider =
helper.sslContextProvider.get(); helper.sslContextProvider.get();
if (oldSslContextProvider != null) { if (oldSslContextProvider != null) {
UpstreamTlsContext oldUpstreamTlsContext = oldSslContextProvider.getSource(); UpstreamTlsContext oldUpstreamTlsContext = oldSslContextProvider.getUpstreamTlsContext();
if (oldUpstreamTlsContext.equals(newUpstreamTlsContext)) { if (oldUpstreamTlsContext.equals(newUpstreamTlsContext)) {
return; return;
@ -314,7 +314,7 @@ public final class CdsLoadBalancer extends LoadBalancer {
tlsContextManager.releaseClientSslContextProvider(oldSslContextProvider); tlsContextManager.releaseClientSslContextProvider(oldSslContextProvider);
} }
if (newUpstreamTlsContext != null) { if (newUpstreamTlsContext != null) {
SslContextProvider<UpstreamTlsContext> newSslContextProvider = SslContextProvider newSslContextProvider =
tlsContextManager.findOrCreateClientSslContextProvider(newUpstreamTlsContext); tlsContextManager.findOrCreateClientSslContextProvider(newUpstreamTlsContext);
helper.sslContextProvider.set(newSslContextProvider); helper.sslContextProvider.set(newSslContextProvider);
} else { } else {

View File

@ -32,18 +32,17 @@ final class ClientSslContextProviderFactory
/** Creates an SslContextProvider from the given UpstreamTlsContext. */ /** Creates an SslContextProvider from the given UpstreamTlsContext. */
@Override @Override
public SslContextProvider<UpstreamTlsContext> createSslContextProvider( public SslContextProvider createSslContextProvider(UpstreamTlsContext upstreamTlsContext) {
UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext"); checkNotNull(upstreamTlsContext, "upstreamTlsContext");
checkArgument( checkArgument(
upstreamTlsContext.hasCommonTlsContext(), upstreamTlsContext.hasCommonTlsContext(),
"upstreamTlsContext should have CommonTlsContext"); "upstreamTlsContext should have CommonTlsContext");
if (CommonTlsContextUtil.hasAllSecretsUsingFilename(upstreamTlsContext.getCommonTlsContext())) { if (CommonTlsContextUtil.hasAllSecretsUsingFilename(upstreamTlsContext.getCommonTlsContext())) {
return SecretVolumeSslContextProvider.getProviderForClient(upstreamTlsContext); return SecretVolumeClientSslContextProvider.getProvider(upstreamTlsContext);
} else if (CommonTlsContextUtil.hasAllSecretsUsingSds( } else if (CommonTlsContextUtil.hasAllSecretsUsingSds(
upstreamTlsContext.getCommonTlsContext())) { upstreamTlsContext.getCommonTlsContext())) {
try { try {
return SdsSslContextProvider.getProviderForClient( return SdsClientSslContextProvider.getProvider(
upstreamTlsContext, upstreamTlsContext,
Bootstrapper.getInstance().readBootstrap().getNode(), Bootstrapper.getInstance().readBootstrap().getNode(),
Executors.newSingleThreadExecutor(new ThreadFactoryBuilder() Executors.newSingleThreadExecutor(new ThreadFactoryBuilder()

View File

@ -16,9 +16,16 @@
package io.grpc.xds.internal.sds; package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.ValidationContextTypeCase;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.core.DataSource.SpecifierCase;
import javax.annotation.Nullable;
/** Class for utility functions for {@link CommonTlsContext}. */ /** Class for utility functions for {@link CommonTlsContext}. */
final class CommonTlsContextUtil { final class CommonTlsContextUtil {
@ -40,4 +47,53 @@ final class CommonTlsContextUtil {
return (commonTlsContext.getTlsCertificatesCount() == 0) return (commonTlsContext.getTlsCertificatesCount() == 0)
&& !commonTlsContext.hasValidationContext(); && !commonTlsContext.hasValidationContext();
} }
@Nullable
static CertificateValidationContext getCertificateValidationContext(
CommonTlsContext commonTlsContext) {
checkNotNull(commonTlsContext, "commonTlsContext");
ValidationContextTypeCase type = commonTlsContext.getValidationContextTypeCase();
checkState(
type == ValidationContextTypeCase.VALIDATION_CONTEXT
|| type == ValidationContextTypeCase.VALIDATIONCONTEXTTYPE_NOT_SET,
"incorrect ValidationContextTypeCase");
return type == ValidationContextTypeCase.VALIDATION_CONTEXT
? commonTlsContext.getValidationContext()
: null;
}
@Nullable
static CertificateValidationContext validateCertificateContext(
@Nullable CertificateValidationContext certContext, boolean optional) {
if (certContext == null || !certContext.hasTrustedCa()) {
checkArgument(optional, "certContext is required");
return null;
}
checkArgument(
certContext.getTrustedCa().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
return certContext;
}
@Nullable
static TlsCertificate validateTlsCertificate(
@Nullable TlsCertificate tlsCertificate, boolean optional) {
if (tlsCertificate == null) {
checkArgument(optional, "tlsCertificate is required");
return null;
}
if (optional
&& (tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.SPECIFIER_NOT_SET)
&& (tlsCertificate.getCertificateChain().getSpecifierCase()
== SpecifierCase.SPECIFIER_NOT_SET)) {
return null;
}
checkArgument(
tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
checkArgument(
tlsCertificate.getCertificateChain().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
return tlsCertificate;
}
} }

View File

@ -0,0 +1,40 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
final class DownstreamTlsContextHolder implements TlsContextHolder {
private final DownstreamTlsContext downstreamTlsContext;
DownstreamTlsContextHolder(DownstreamTlsContext downstreamTlsContext) {
this.downstreamTlsContext = checkNotNull(downstreamTlsContext, "downstreamTlsContext");
}
public DownstreamTlsContext getDownstreamTlsContext() {
return downstreamTlsContext;
}
@Override
public CommonTlsContext getCommonTlsContext() {
return downstreamTlsContext.getCommonTlsContext();
}
}

View File

@ -38,7 +38,7 @@ import javax.annotation.concurrent.ThreadSafe;
@ThreadSafe @ThreadSafe
final class ReferenceCountingSslContextProviderMap<K> { final class ReferenceCountingSslContextProviderMap<K> {
private final Map<K, Instance<K>> instances = new HashMap<>(); private final Map<K, Instance> instances = new HashMap<>();
private final SslContextProviderFactory<K> sslContextProviderFactory; private final SslContextProviderFactory<K> sslContextProviderFactory;
ReferenceCountingSslContextProviderMap(SslContextProviderFactory<K> sslContextProviderFactory) { ReferenceCountingSslContextProviderMap(SslContextProviderFactory<K> sslContextProviderFactory) {
@ -51,7 +51,7 @@ final class ReferenceCountingSslContextProviderMap<K> {
* using the provided {@link SslContextProviderFactory&lt;K&gt;} * using the provided {@link SslContextProviderFactory&lt;K&gt;}
*/ */
@CheckReturnValue @CheckReturnValue
public SslContextProvider<K> get(K key) { public SslContextProvider get(K key) {
checkNotNull(key, "key"); checkNotNull(key, "key");
return getInternal(key); return getInternal(key);
} }
@ -65,19 +65,20 @@ final class ReferenceCountingSslContextProviderMap<K> {
* <p>Caller must not release a reference more than once. It's advised that you clear the * <p>Caller must not release a reference more than once. It's advised that you clear the
* reference to the instance with the null returned by this method. * reference to the instance with the null returned by this method.
* *
* @param key for the instance to be released
* @param value the instance to be released * @param value the instance to be released
* @return a null which the caller can use to clear the reference to that instance. * @return a null which the caller can use to clear the reference to that instance.
*/ */
public SslContextProvider<K> release(final SslContextProvider<K> value) { public SslContextProvider release(K key, SslContextProvider value) {
checkNotNull(key, "key");
checkNotNull(value, "value"); checkNotNull(value, "value");
K key = value.getSource();
return releaseInternal(key, value); return releaseInternal(key, value);
} }
private synchronized SslContextProvider<K> getInternal(K key) { private synchronized SslContextProvider getInternal(K key) {
Instance<K> instance = instances.get(key); Instance instance = instances.get(key);
if (instance == null) { if (instance == null) {
instance = new Instance<>(sslContextProviderFactory.createSslContextProvider(key)); instance = new Instance(sslContextProviderFactory.createSslContextProvider(key));
instances.put(key, instance); instances.put(key, instance);
return instance.sslContextProvider; return instance.sslContextProvider;
} else { } else {
@ -85,9 +86,8 @@ final class ReferenceCountingSslContextProviderMap<K> {
} }
} }
private synchronized SslContextProvider<K> releaseInternal( private synchronized SslContextProvider releaseInternal(K key, SslContextProvider instance) {
final K key, final SslContextProvider<K> instance) { Instance cached = instances.get(key);
final Instance<K> cached = instances.get(key);
checkArgument(cached != null, "No cached instance found for %s", key); checkArgument(cached != null, "No cached instance found for %s", key);
checkArgument(instance == cached.sslContextProvider, "Releasing the wrong instance"); checkArgument(instance == cached.sslContextProvider, "Releasing the wrong instance");
if (cached.release()) { if (cached.release()) {
@ -103,15 +103,15 @@ final class ReferenceCountingSslContextProviderMap<K> {
/** A factory to create an SslContextProvider from the given key. */ /** A factory to create an SslContextProvider from the given key. */
public interface SslContextProviderFactory<K> { public interface SslContextProviderFactory<K> {
SslContextProvider<K> createSslContextProvider(K key); SslContextProvider createSslContextProvider(K key);
} }
private static class Instance<K> { private static class Instance {
final SslContextProvider<K> sslContextProvider; final SslContextProvider sslContextProvider;
private int refCount; private int refCount;
/** Increment refCount and acquire a reference to sslContextProvider. */ /** Increment refCount and acquire a reference to sslContextProvider. */
SslContextProvider<K> acquire() { SslContextProvider acquire() {
refCount++; refCount++;
return sslContextProvider; return sslContextProvider;
} }
@ -122,7 +122,7 @@ final class ReferenceCountingSslContextProviderMap<K> {
return --refCount == 0; return --refCount == 0;
} }
Instance(SslContextProvider<K> sslContextProvider) { Instance(SslContextProvider sslContextProvider) {
this.sslContextProvider = sslContextProvider; this.sslContextProvider = sslContextProvider;
this.refCount = 1; this.refCount = 1;
} }

View File

@ -0,0 +1,107 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.Node;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
/** A client SslContext provider that uses SDS to fetch secrets. */
final class SdsClientSslContextProvider extends SdsSslContextProvider {
private SdsClientSslContextProvider(
Node node,
SdsSecretConfig certSdsConfig,
SdsSecretConfig validationContextSdsConfig,
CertificateValidationContext staticCertValidationContext,
Executor watcherExecutor,
Executor channelExecutor,
UpstreamTlsContext upstreamTlsContext) {
super(node,
certSdsConfig,
validationContextSdsConfig,
staticCertValidationContext,
watcherExecutor,
channelExecutor, new UpstreamTlsContextHolder(upstreamTlsContext));
}
static SdsClientSslContextProvider getProvider(
UpstreamTlsContext upstreamTlsContext,
Node node,
Executor watcherExecutor,
Executor channelExecutor) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
SdsSecretConfig validationContextSdsConfig = null;
CertificateValidationContext staticCertValidationContext = null;
if (commonTlsContext.hasCombinedValidationContext()) {
CombinedCertificateValidationContext combinedValidationContext =
commonTlsContext.getCombinedValidationContext();
if (combinedValidationContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig =
combinedValidationContext.getValidationContextSdsSecretConfig();
}
if (combinedValidationContext.hasDefaultValidationContext()) {
staticCertValidationContext = combinedValidationContext.getDefaultValidationContext();
}
} else if (commonTlsContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig();
} else if (commonTlsContext.hasValidationContext()) {
staticCertValidationContext = commonTlsContext.getValidationContext();
}
SdsSecretConfig certSdsConfig = null;
if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) {
certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0);
}
return new SdsClientSslContextProvider(
node,
certSdsConfig,
validationContextSdsConfig,
staticCertValidationContext,
watcherExecutor,
channelExecutor,
upstreamTlsContext);
}
@Override
SslContextBuilder getSslContextBuilder(
CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException {
SslContextBuilder sslContextBuilder =
GrpcSslContexts.forClient()
.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
if (tlsCertificate != null) {
sslContextBuilder.keyManager(
tlsCertificate.getCertificateChain().getInlineBytes().newInput(),
tlsCertificate.getPrivateKey().getInlineBytes().newInput(),
tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null);
}
return sslContextBuilder;
}
}

View File

@ -192,7 +192,7 @@ public final class SdsProtocolNegotiators {
final BufferReadsHandler bufferReads = new BufferReadsHandler(); final BufferReadsHandler bufferReads = new BufferReadsHandler();
ctx.pipeline().addBefore(ctx.name(), null, bufferReads); ctx.pipeline().addBefore(ctx.name(), null, bufferReads);
final SslContextProvider<UpstreamTlsContext> sslContextProvider = final SslContextProvider sslContextProvider =
TlsContextManagerImpl.getInstance() TlsContextManagerImpl.getInstance()
.findOrCreateClientSslContextProvider(upstreamTlsContext); .findOrCreateClientSslContextProvider(upstreamTlsContext);
@ -349,7 +349,7 @@ public final class SdsProtocolNegotiators {
final BufferReadsHandler bufferReads = new BufferReadsHandler(); final BufferReadsHandler bufferReads = new BufferReadsHandler();
ctx.pipeline().addBefore(ctx.name(), null, bufferReads); ctx.pipeline().addBefore(ctx.name(), null, bufferReads);
final SslContextProvider<DownstreamTlsContext> sslContextProvider = final SslContextProvider sslContextProvider =
TlsContextManagerImpl.getInstance() TlsContextManagerImpl.getInstance()
.findOrCreateServerSslContextProvider(downstreamTlsContext); .findOrCreateServerSslContextProvider(downstreamTlsContext);

View File

@ -0,0 +1,91 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
import io.envoyproxy.envoy.api.v2.core.Node;
import io.grpc.netty.GrpcSslContexts;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
/** A server SslContext provider that uses SDS to fetch secrets. */
final class SdsServerSslContextProvider extends SdsSslContextProvider {
private SdsServerSslContextProvider(
Node node,
SdsSecretConfig certSdsConfig,
SdsSecretConfig validationContextSdsConfig,
Executor watcherExecutor,
Executor channelExecutor,
DownstreamTlsContext downstreamTlsContext) {
super(node,
certSdsConfig,
validationContextSdsConfig,
null,
watcherExecutor,
channelExecutor, new DownstreamTlsContextHolder(downstreamTlsContext));
}
static SdsServerSslContextProvider getProvider(
DownstreamTlsContext downstreamTlsContext,
Node node,
Executor watcherExecutor,
Executor channelExecutor) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext();
SdsSecretConfig certSdsConfig = null;
if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) {
certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0);
}
SdsSecretConfig validationContextSdsConfig = null;
if (commonTlsContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig();
}
return new SdsServerSslContextProvider(
node,
certSdsConfig,
validationContextSdsConfig,
watcherExecutor,
channelExecutor,
downstreamTlsContext);
}
@Override
SslContextBuilder getSslContextBuilder(
CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException {
SslContextBuilder sslContextBuilder =
GrpcSslContexts.forServer(
tlsCertificate.getCertificateChain().getInlineBytes().newInput(),
tlsCertificate.getPrivateKey().getInlineBytes().newInput(),
tlsCertificate.hasPassword()
? tlsCertificate.getPassword().getInlineString()
: null);
setClientAuthValues(sslContextBuilder, localCertValidationContext);
return sslContextBuilder;
}
}

View File

@ -21,16 +21,11 @@ import static com.google.common.base.Preconditions.checkState;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig; import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
import io.envoyproxy.envoy.api.v2.auth.Secret; import io.envoyproxy.envoy.api.v2.auth.Secret;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.Node; import io.envoyproxy.envoy.api.v2.core.Node;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory;
import io.netty.handler.ssl.ApplicationProtocolConfig; import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.SslContextBuilder;
@ -44,12 +39,8 @@ import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
import javax.annotation.Nullable; import javax.annotation.Nullable;
/** /** Base class for SdsClientSslContextProvider and SdsServerSslContextProvider. */
* An SslContext provider that uses SDS to fetch secrets. Used for both server and client abstract class SdsSslContextProvider extends SslContextProvider implements SdsClient.SecretWatcher {
* SslContexts
*/
final class SdsSslContextProvider<K> extends SslContextProvider<K>
implements SdsClient.SecretWatcher {
private static final Logger logger = Logger.getLogger(SdsSslContextProvider.class.getName()); private static final Logger logger = Logger.getLogger(SdsSslContextProvider.class.getName());
@ -59,20 +50,19 @@ final class SdsSslContextProvider<K> extends SslContextProvider<K>
@Nullable private final SdsSecretConfig validationContextSdsConfig; @Nullable private final SdsSecretConfig validationContextSdsConfig;
@Nullable private final CertificateValidationContext staticCertificateValidationContext; @Nullable private final CertificateValidationContext staticCertificateValidationContext;
private final List<CallbackPair> pendingCallbacks = new ArrayList<>(); private final List<CallbackPair> pendingCallbacks = new ArrayList<>();
@Nullable private TlsCertificate tlsCertificate; @Nullable protected TlsCertificate tlsCertificate;
@Nullable private CertificateValidationContext certificateValidationContext; @Nullable private CertificateValidationContext certificateValidationContext;
@Nullable private SslContext sslContext; @Nullable private SslContext sslContext;
private SdsSslContextProvider( SdsSslContextProvider(
Node node, Node node,
SdsSecretConfig certSdsConfig, SdsSecretConfig certSdsConfig,
SdsSecretConfig validationContextSdsConfig, SdsSecretConfig validationContextSdsConfig,
CertificateValidationContext staticCertValidationContext, CertificateValidationContext staticCertValidationContext,
Executor watcherExecutor, Executor watcherExecutor,
Executor channelExecutor, Executor channelExecutor,
boolean server, TlsContextHolder tlsContextHolder) {
K source) { super(tlsContextHolder);
super(source, server);
this.certSdsConfig = certSdsConfig; this.certSdsConfig = certSdsConfig;
this.validationContextSdsConfig = validationContextSdsConfig; this.validationContextSdsConfig = validationContextSdsConfig;
this.staticCertificateValidationContext = staticCertValidationContext; this.staticCertificateValidationContext = staticCertValidationContext;
@ -95,73 +85,6 @@ final class SdsSslContextProvider<K> extends SslContextProvider<K>
} }
} }
static SdsSslContextProvider<UpstreamTlsContext> getProviderForClient(
UpstreamTlsContext upstreamTlsContext,
Node node,
Executor watcherExecutor,
Executor channelExecutor) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
SdsSecretConfig validationContextSdsConfig = null;
CertificateValidationContext staticCertValidationContext = null;
if (commonTlsContext.hasCombinedValidationContext()) {
CombinedCertificateValidationContext combinedValidationContext =
commonTlsContext.getCombinedValidationContext();
if (combinedValidationContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig =
combinedValidationContext.getValidationContextSdsSecretConfig();
}
if (combinedValidationContext.hasDefaultValidationContext()) {
staticCertValidationContext = combinedValidationContext.getDefaultValidationContext();
}
} else if (commonTlsContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig();
} else if (commonTlsContext.hasValidationContext()) {
staticCertValidationContext = commonTlsContext.getValidationContext();
}
SdsSecretConfig certSdsConfig = null;
if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) {
certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0);
}
return new SdsSslContextProvider<>(
node,
certSdsConfig,
validationContextSdsConfig,
staticCertValidationContext,
watcherExecutor,
channelExecutor,
false,
upstreamTlsContext);
}
static SdsSslContextProvider<DownstreamTlsContext> getProviderForServer(
DownstreamTlsContext downstreamTlsContext,
Node node,
Executor watcherExecutor,
Executor channelExecutor) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext();
SdsSecretConfig certSdsConfig = null;
if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) {
certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0);
}
SdsSecretConfig validationContextSdsConfig = null;
if (commonTlsContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig();
}
return new SdsSslContextProvider<>(
node,
certSdsConfig,
validationContextSdsConfig,
null,
watcherExecutor,
channelExecutor,
true,
downstreamTlsContext);
}
@Override @Override
public void addCallback(Callback callback, Executor executor) { public void addCallback(Callback callback, Executor executor) {
checkNotNull(callback, "callback"); checkNotNull(callback, "callback");
@ -219,34 +142,16 @@ final class SdsSslContextProvider<K> extends SslContextProvider<K>
} }
} }
/** Gets a server or client side SslContextBuilder. */
abstract SslContextBuilder getSslContextBuilder(
CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException;
// this gets called only when requested secrets are ready... // this gets called only when requested secrets are ready...
private void updateSslContext() { private void updateSslContext() {
try { try {
SslContextBuilder sslContextBuilder; CertificateValidationContext localCertValidationContext = mergeStaticAndDynamicCertContexts();
CertificateValidationContext localCertValidationContext = SslContextBuilder sslContextBuilder = getSslContextBuilder(localCertValidationContext);
mergeStaticAndDynamicCertContexts();
if (server) {
logger.log(Level.FINEST, "for server");
sslContextBuilder =
GrpcSslContexts.forServer(
tlsCertificate.getCertificateChain().getInlineBytes().newInput(),
tlsCertificate.getPrivateKey().getInlineBytes().newInput(),
tlsCertificate.hasPassword()
? tlsCertificate.getPassword().getInlineString()
: null);
setClientAuthValues(sslContextBuilder, localCertValidationContext);
} else {
logger.log(Level.FINEST, "for client");
sslContextBuilder =
GrpcSslContexts.forClient()
.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
if (tlsCertificate != null) {
sslContextBuilder.keyManager(
tlsCertificate.getCertificateChain().getInlineBytes().newInput(),
tlsCertificate.getPrivateKey().getInlineBytes().newInput(),
tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null);
}
}
CommonTlsContext commonTlsContext = getCommonTlsContext(); CommonTlsContext commonTlsContext = getCommonTlsContext();
if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) { if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) {
List<String> alpnList = commonTlsContext.getAlpnProtocolsList(); List<String> alpnList = commonTlsContext.getAlpnProtocolsList();

View File

@ -0,0 +1,125 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateTlsCertificate;
import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/** A client SslContext provider that uses file-based secrets (secret volume). */
final class SecretVolumeClientSslContextProvider extends SslContextProvider {
@Nullable private final String privateKey;
@Nullable private final String privateKeyPassword;
@Nullable private final String certificateChain;
@Nullable private final CertificateValidationContext certContext;
private SecretVolumeClientSslContextProvider(
@Nullable String privateKey,
@Nullable String privateKeyPassword,
@Nullable String certificateChain,
@Nullable CertificateValidationContext certContext,
UpstreamTlsContext upstreamTlsContext) {
super(new UpstreamTlsContextHolder(upstreamTlsContext));
this.privateKey = privateKey;
this.privateKeyPassword = privateKeyPassword;
this.certificateChain = certificateChain;
this.certContext = certContext;
}
static SecretVolumeClientSslContextProvider getProvider(UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
CertificateValidationContext certificateValidationContext =
getCertificateValidationContext(commonTlsContext);
// first validate
validateCertificateContext(certificateValidationContext, /* optional= */ false);
TlsCertificate tlsCertificate = null;
if (commonTlsContext.getTlsCertificatesCount() > 0) {
tlsCertificate = commonTlsContext.getTlsCertificates(0);
}
// tlsCertificate exists in case of mTLS, else null for a client
if (tlsCertificate != null) {
tlsCertificate = validateTlsCertificate(tlsCertificate, /* optional= */ true);
}
String privateKey = null;
String privateKeyPassword = null;
String certificateChain = null;
if (tlsCertificate != null) {
privateKey = tlsCertificate.getPrivateKey().getFilename();
if (tlsCertificate.hasPassword()) {
privateKeyPassword = tlsCertificate.getPassword().getInlineString();
}
certificateChain = tlsCertificate.getCertificateChain().getFilename();
}
return new SecretVolumeClientSslContextProvider(
privateKey,
privateKeyPassword,
certificateChain,
certificateValidationContext,
upstreamTlsContext);
}
@Override
public void addCallback(final Callback callback, Executor executor) {
checkNotNull(callback, "callback");
checkNotNull(executor, "executor");
// as per the contract we will read the current secrets on disk
// this involves I/O which can potentially block the executor
performCallback(
new SslContextGetter() {
@Override
public SslContext get() throws CertificateException, IOException, CertStoreException {
return buildSslContextFromSecrets();
}
},
callback,
executor);
}
@Override
public void close() {}
@VisibleForTesting
SslContext buildSslContextFromSecrets()
throws IOException, CertificateException, CertStoreException {
SslContextBuilder sslContextBuilder =
GrpcSslContexts.forClient().trustManager(new SdsTrustManagerFactory(certContext));
if (privateKey != null && certificateChain != null) {
sslContextBuilder.keyManager(
new File(certificateChain), new File(privateKey), privateKeyPassword);
}
return sslContextBuilder.build();
}
}

View File

@ -0,0 +1,116 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateTlsCertificate;
import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.grpc.netty.GrpcSslContexts;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/** A server SslContext provider that uses file-based secrets (secret volume). */
final class SecretVolumeServerSslContextProvider extends SslContextProvider {
@Nullable private final String privateKey;
@Nullable private final String privateKeyPassword;
@Nullable private final String certificateChain;
@Nullable private final CertificateValidationContext certContext;
private SecretVolumeServerSslContextProvider(
@Nullable String privateKey,
@Nullable String privateKeyPassword,
@Nullable String certificateChain,
@Nullable CertificateValidationContext certContext,
DownstreamTlsContext downstreamTlsContext) {
super(new DownstreamTlsContextHolder(downstreamTlsContext));
this.privateKey = privateKey;
this.privateKeyPassword = privateKeyPassword;
this.certificateChain = certificateChain;
this.certContext = certContext;
}
static SecretVolumeServerSslContextProvider getProvider(
DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext();
TlsCertificate tlsCertificate = null;
if (commonTlsContext.getTlsCertificatesCount() > 0) {
tlsCertificate = commonTlsContext.getTlsCertificates(0);
}
// first validate
validateTlsCertificate(tlsCertificate, /* optional= */ false);
CertificateValidationContext certificateValidationContext =
getCertificateValidationContext(commonTlsContext);
// certificateValidationContext exists in case of mTLS, else null for a server
if (certificateValidationContext != null) {
certificateValidationContext =
validateCertificateContext(certificateValidationContext, /* optional= */ true);
}
String privateKeyPassword =
tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null;
return new SecretVolumeServerSslContextProvider(
tlsCertificate.getPrivateKey().getFilename(),
privateKeyPassword,
tlsCertificate.getCertificateChain().getFilename(),
certificateValidationContext,
downstreamTlsContext);
}
@Override
public void addCallback(final Callback callback, Executor executor) {
checkNotNull(callback, "callback");
checkNotNull(executor, "executor");
// as per the contract we will read the current secrets on disk
// this involves I/O which can potentially block the executor
performCallback(
new SslContextGetter() {
@Override
public SslContext get() throws CertificateException, IOException, CertStoreException {
return buildSslContextFromSecrets();
}
},
callback,
executor);
}
@Override
public void close() {}
@VisibleForTesting
SslContext buildSslContextFromSecrets()
throws IOException, CertificateException, CertStoreException {
SslContextBuilder sslContextBuilder =
GrpcSslContexts.forServer(
new File(certificateChain), new File(privateKey), privateKeyPassword);
setClientAuthValues(sslContextBuilder, certContext);
return sslContextBuilder.build();
}
}

View File

@ -1,219 +0,0 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.ValidationContextTypeCase;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.DataSource.SpecifierCase;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/**
* An SslContext provider that uses file-based secrets (secret volume). Used for both server and
* client SslContexts
*/
final class SecretVolumeSslContextProvider<K> extends SslContextProvider<K> {
@Nullable private final String privateKey;
@Nullable private final String privateKeyPassword;
@Nullable private final String certificateChain;
@Nullable private final CertificateValidationContext certContext;
private SecretVolumeSslContextProvider(
@Nullable String privateKey,
@Nullable String privateKeyPassword,
@Nullable String certificateChain,
@Nullable CertificateValidationContext certContext,
boolean server,
K source) {
super(source, server);
this.privateKey = privateKey;
this.privateKeyPassword = privateKeyPassword;
this.certificateChain = certificateChain;
this.certContext = certContext;
}
@VisibleForTesting
@Nullable
static CertificateValidationContext validateCertificateContext(
@Nullable CertificateValidationContext certContext, boolean optional) {
if (certContext == null || !certContext.hasTrustedCa()) {
checkArgument(optional, "certContext is required");
return null;
}
checkArgument(
certContext.getTrustedCa().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
return certContext;
}
@VisibleForTesting
@Nullable
static TlsCertificate validateTlsCertificate(
@Nullable TlsCertificate tlsCertificate, boolean optional) {
if (tlsCertificate == null) {
checkArgument(optional, "tlsCertificate is required");
return null;
}
if (optional
&& (tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.SPECIFIER_NOT_SET)
&& (tlsCertificate.getCertificateChain().getSpecifierCase()
== SpecifierCase.SPECIFIER_NOT_SET)) {
return null;
}
checkArgument(
tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
checkArgument(
tlsCertificate.getCertificateChain().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
return tlsCertificate;
}
static SecretVolumeSslContextProvider<DownstreamTlsContext> getProviderForServer(
DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext();
TlsCertificate tlsCertificate = null;
if (commonTlsContext.getTlsCertificatesCount() > 0) {
tlsCertificate = commonTlsContext.getTlsCertificates(0);
}
// first validate
validateTlsCertificate(tlsCertificate, /* optional= */ false);
CertificateValidationContext certificateValidationContext =
getCertificateValidationContext(commonTlsContext);
// certificateValidationContext exists in case of mTLS, else null for a server
if (certificateValidationContext != null) {
certificateValidationContext =
validateCertificateContext(certificateValidationContext, /* optional= */ true);
}
String privateKeyPassword =
tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null;
return new SecretVolumeSslContextProvider<>(
tlsCertificate.getPrivateKey().getFilename(),
privateKeyPassword,
tlsCertificate.getCertificateChain().getFilename(),
certificateValidationContext,
/* server= */ true,
downstreamTlsContext);
}
static SecretVolumeSslContextProvider<UpstreamTlsContext> getProviderForClient(
UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
CertificateValidationContext certificateValidationContext =
getCertificateValidationContext(commonTlsContext);
// first validate
validateCertificateContext(certificateValidationContext, /* optional= */ false);
TlsCertificate tlsCertificate = null;
if (commonTlsContext.getTlsCertificatesCount() > 0) {
tlsCertificate = commonTlsContext.getTlsCertificates(0);
}
// tlsCertificate exists in case of mTLS, else null for a client
if (tlsCertificate != null) {
tlsCertificate = validateTlsCertificate(tlsCertificate, /* optional= */ true);
}
String privateKey = null;
String privateKeyPassword = null;
String certificateChain = null;
if (tlsCertificate != null) {
privateKey = tlsCertificate.getPrivateKey().getFilename();
if (tlsCertificate.hasPassword()) {
privateKeyPassword = tlsCertificate.getPassword().getInlineString();
}
certificateChain = tlsCertificate.getCertificateChain().getFilename();
}
return new SecretVolumeSslContextProvider<>(
privateKey,
privateKeyPassword,
certificateChain,
certificateValidationContext,
/* server= */ false,
upstreamTlsContext);
}
private static CertificateValidationContext getCertificateValidationContext(
CommonTlsContext commonTlsContext) {
checkNotNull(commonTlsContext, "commonTlsContext");
ValidationContextTypeCase type = commonTlsContext.getValidationContextTypeCase();
checkState(
type == ValidationContextTypeCase.VALIDATION_CONTEXT
|| type == ValidationContextTypeCase.VALIDATIONCONTEXTTYPE_NOT_SET,
"incorrect ValidationContextTypeCase");
return type == ValidationContextTypeCase.VALIDATION_CONTEXT
? commonTlsContext.getValidationContext()
: null;
}
@Override
public void addCallback(final Callback callback, Executor executor) {
checkNotNull(callback, "callback");
checkNotNull(executor, "executor");
// as per the contract we will read the current secrets on disk
// this involves I/O which can potentially block the executor
performCallback(
new SslContextGetter() {
@Override
public SslContext get() throws CertificateException, IOException, CertStoreException {
return buildSslContextFromSecrets();
}
},
callback,
executor);
}
@Override
public void close() {}
@VisibleForTesting
SslContext buildSslContextFromSecrets()
throws IOException, CertificateException, CertStoreException {
SslContextBuilder sslContextBuilder;
if (server) {
sslContextBuilder =
GrpcSslContexts.forServer(
new File(certificateChain), new File(privateKey), privateKeyPassword);
setClientAuthValues(sslContextBuilder, certContext);
} else {
sslContextBuilder =
GrpcSslContexts.forClient().trustManager(new SdsTrustManagerFactory(certContext));
if (privateKey != null && certificateChain != null) {
sslContextBuilder.keyManager(
new File(certificateChain), new File(privateKey), privateKeyPassword);
}
}
return sslContextBuilder.build();
}
}

View File

@ -30,9 +30,9 @@ import java.util.concurrent.Executors;
final class ServerSslContextProviderFactory final class ServerSslContextProviderFactory
implements SslContextProviderFactory<DownstreamTlsContext> { implements SslContextProviderFactory<DownstreamTlsContext> {
/** Creates an SslContextProvider from the given DownstreamTlsContext. */ /** Creates a SslContextProvider from the given DownstreamTlsContext. */
@Override @Override
public SslContextProvider<DownstreamTlsContext> createSslContextProvider( public SslContextProvider createSslContextProvider(
DownstreamTlsContext downstreamTlsContext) { DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext"); checkNotNull(downstreamTlsContext, "downstreamTlsContext");
checkArgument( checkArgument(
@ -40,11 +40,11 @@ final class ServerSslContextProviderFactory
"downstreamTlsContext should have CommonTlsContext"); "downstreamTlsContext should have CommonTlsContext");
if (CommonTlsContextUtil.hasAllSecretsUsingFilename( if (CommonTlsContextUtil.hasAllSecretsUsingFilename(
downstreamTlsContext.getCommonTlsContext())) { downstreamTlsContext.getCommonTlsContext())) {
return SecretVolumeSslContextProvider.getProviderForServer(downstreamTlsContext); return SecretVolumeServerSslContextProvider.getProvider(downstreamTlsContext);
} else if (CommonTlsContextUtil.hasAllSecretsUsingSds( } else if (CommonTlsContextUtil.hasAllSecretsUsingSds(
downstreamTlsContext.getCommonTlsContext())) { downstreamTlsContext.getCommonTlsContext())) {
try { try {
return SdsSslContextProvider.getProviderForServer( return SdsServerSslContextProvider.getProvider(
downstreamTlsContext, downstreamTlsContext,
Bootstrapper.getInstance().readBootstrap().getNode(), Bootstrapper.getInstance().readBootstrap().getNode(),
Executors.newSingleThreadExecutor(new ThreadFactoryBuilder() Executors.newSingleThreadExecutor(new ThreadFactoryBuilder()

View File

@ -16,7 +16,6 @@
package io.grpc.xds.internal.sds; package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
@ -41,14 +40,11 @@ import java.util.logging.Logger;
* stream that is receiving the requested secret(s) or it could represent file-system based * stream that is receiving the requested secret(s) or it could represent file-system based
* secret(s) that are dynamic. * secret(s) that are dynamic.
*/ */
// TODO(sanjaypujare): replace generic K with DownstreamTlsContext & UpstreamTlsContext in public abstract class SslContextProvider {
// separate client&server classes
public abstract class SslContextProvider<K> {
private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName()); private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName());
protected final boolean server; protected final TlsContextHolder tlsContextHolder;
private final K source;
public interface Callback { public interface Callback {
/** Informs callee of new/updated SslContext. */ /** Informs callee of new/updated SslContext. */
@ -58,36 +54,20 @@ public abstract class SslContextProvider<K> {
void onException(Throwable throwable); void onException(Throwable throwable);
} }
protected SslContextProvider(K source, boolean server) { SslContextProvider(TlsContextHolder tlsContextHolder) {
if (server) { this.tlsContextHolder = checkNotNull(tlsContextHolder, "tlsContextHolder");
checkArgument(source instanceof DownstreamTlsContext, "expecting DownstreamTlsContext");
} else {
checkArgument(source instanceof UpstreamTlsContext, "expecting UpstreamTlsContext");
}
this.source = source;
this.server = server;
}
public K getSource() {
return source;
} }
CommonTlsContext getCommonTlsContext() { CommonTlsContext getCommonTlsContext() {
if (source instanceof UpstreamTlsContext) { return tlsContextHolder.getCommonTlsContext();
return ((UpstreamTlsContext) source).getCommonTlsContext();
} else if (source instanceof DownstreamTlsContext) {
return ((DownstreamTlsContext) source).getCommonTlsContext();
}
return null;
} }
protected void setClientAuthValues( protected void setClientAuthValues(
SslContextBuilder sslContextBuilder, CertificateValidationContext localCertValidationContext) SslContextBuilder sslContextBuilder, CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException { throws CertificateException, IOException, CertStoreException {
checkState(server, "server side SslContextProvider expected"); DownstreamTlsContext downstreamTlsContext = getDownstreamTlsContext();
if (localCertValidationContext != null) { if (localCertValidationContext != null) {
sslContextBuilder.trustManager(new SdsTrustManagerFactory(localCertValidationContext)); sslContextBuilder.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
DownstreamTlsContext downstreamTlsContext = (DownstreamTlsContext)getSource();
sslContextBuilder.clientAuth( sslContextBuilder.clientAuth(
downstreamTlsContext.hasRequireClientCertificate() downstreamTlsContext.hasRequireClientCertificate()
? ClientAuth.REQUIRE ? ClientAuth.REQUIRE
@ -97,6 +77,20 @@ public abstract class SslContextProvider<K> {
} }
} }
/** Returns the DownstreamTlsContext in this SslContextProvider if this is server side. **/
public DownstreamTlsContext getDownstreamTlsContext() {
checkState(tlsContextHolder instanceof DownstreamTlsContextHolder,
"expected DownstreamTlsContextHolder");
return ((DownstreamTlsContextHolder) tlsContextHolder).getDownstreamTlsContext();
}
/** Returns the UpstreamTlsContext in this SslContextProvider if this is client side. **/
public UpstreamTlsContext getUpstreamTlsContext() {
checkState(tlsContextHolder instanceof UpstreamTlsContextHolder,
"expected UpstreamTlsContextHolder");
return ((UpstreamTlsContextHolder) tlsContextHolder).getUpstreamTlsContext();
}
/** Closes this provider and releases any resources. */ /** Closes this provider and releases any resources. */
void close() {} void close() {}
@ -106,7 +100,7 @@ public abstract class SslContextProvider<K> {
*/ */
public abstract void addCallback(Callback callback, Executor executor); public abstract void addCallback(Callback callback, Executor executor);
protected void performCallback( final void performCallback(
final SslContextGetter sslContextGetter, final Callback callback, Executor executor) { final SslContextGetter sslContextGetter, final Callback callback, Executor executor) {
checkNotNull(sslContextGetter, "sslContextGetter"); checkNotNull(sslContextGetter, "sslContextGetter");
checkNotNull(callback, "callback"); checkNotNull(callback, "callback");

View File

@ -0,0 +1,29 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
/**
* A holder of {@link UpstreamTlsContext} or
* {@link io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext}.
*/
public interface TlsContextHolder {
CommonTlsContext getCommonTlsContext();
}

View File

@ -22,11 +22,11 @@ import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
public interface TlsContextManager { public interface TlsContextManager {
/** Creates a SslContextProvider. Used for retrieving a server-side SslContext. */ /** Creates a SslContextProvider. Used for retrieving a server-side SslContext. */
SslContextProvider<DownstreamTlsContext> findOrCreateServerSslContextProvider( SslContextProvider findOrCreateServerSslContextProvider(
DownstreamTlsContext downstreamTlsContext); DownstreamTlsContext downstreamTlsContext);
/** Creates a SslContextProvider. Used for retrieving a client-side SslContext. */ /** Creates a SslContextProvider. Used for retrieving a client-side SslContext. */
SslContextProvider<UpstreamTlsContext> findOrCreateClientSslContextProvider( SslContextProvider findOrCreateClientSslContextProvider(
UpstreamTlsContext upstreamTlsContext); UpstreamTlsContext upstreamTlsContext);
/** /**
@ -38,8 +38,7 @@ public interface TlsContextManager {
* <p>Caller must not release a reference more than once. It's advised that you clear the * <p>Caller must not release a reference more than once. It's advised that you clear the
* reference to the instance with the null returned by this method. * reference to the instance with the null returned by this method.
*/ */
SslContextProvider<UpstreamTlsContext> releaseClientSslContextProvider( SslContextProvider releaseClientSslContextProvider(SslContextProvider sslContextProvider);
SslContextProvider<UpstreamTlsContext> sslContextProvider);
/** /**
* Releases an instance of the given server-side {@link SslContextProvider}. * Releases an instance of the given server-side {@link SslContextProvider}.
@ -50,6 +49,5 @@ public interface TlsContextManager {
* <p>Caller must not release a reference more than once. It's advised that you clear the * <p>Caller must not release a reference more than once. It's advised that you clear the
* reference to the instance with the null returned by this method. * reference to the instance with the null returned by this method.
*/ */
SslContextProvider<DownstreamTlsContext> releaseServerSslContextProvider( SslContextProvider releaseServerSslContextProvider(SslContextProvider sslContextProvider);
SslContextProvider<DownstreamTlsContext> sslContextProvider);
} }

View File

@ -59,30 +59,32 @@ public final class TlsContextManagerImpl implements TlsContextManager {
} }
@Override @Override
public SslContextProvider<DownstreamTlsContext> findOrCreateServerSslContextProvider( public SslContextProvider findOrCreateServerSslContextProvider(
DownstreamTlsContext downstreamTlsContext) { DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext"); checkNotNull(downstreamTlsContext, "downstreamTlsContext");
return mapForServers.get(downstreamTlsContext); return mapForServers.get(downstreamTlsContext);
} }
@Override @Override
public SslContextProvider<UpstreamTlsContext> findOrCreateClientSslContextProvider( public SslContextProvider findOrCreateClientSslContextProvider(
UpstreamTlsContext upstreamTlsContext) { UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext"); checkNotNull(upstreamTlsContext, "upstreamTlsContext");
return mapForClients.get(upstreamTlsContext); return mapForClients.get(upstreamTlsContext);
} }
@Override @Override
public SslContextProvider<UpstreamTlsContext> releaseClientSslContextProvider( public SslContextProvider releaseClientSslContextProvider(
SslContextProvider<UpstreamTlsContext> sslContextProvider) { SslContextProvider clientSslContextProvider) {
checkNotNull(sslContextProvider, "sslContextProvider"); checkNotNull(clientSslContextProvider, "clientSslContextProvider");
return mapForClients.release(sslContextProvider); return mapForClients.release(clientSslContextProvider.getUpstreamTlsContext(),
clientSslContextProvider);
} }
@Override @Override
public SslContextProvider<DownstreamTlsContext> releaseServerSslContextProvider( public SslContextProvider releaseServerSslContextProvider(
SslContextProvider<DownstreamTlsContext> sslContextProvider) { SslContextProvider serverSslContextProvider) {
checkNotNull(sslContextProvider, "sslContextProvider"); checkNotNull(serverSslContextProvider, "serverSslContextProvider");
return mapForServers.release(sslContextProvider); return mapForServers.release(serverSslContextProvider.getDownstreamTlsContext(),
serverSslContextProvider);
} }
} }

View File

@ -0,0 +1,40 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
final class UpstreamTlsContextHolder implements TlsContextHolder {
private final UpstreamTlsContext upstreamTlsContext;
UpstreamTlsContextHolder(UpstreamTlsContext upstreamTlsContext) {
this.upstreamTlsContext = checkNotNull(upstreamTlsContext, "upstreamTlsContext");
}
public UpstreamTlsContext getUpstreamTlsContext() {
return upstreamTlsContext;
}
@Override
public CommonTlsContext getCommonTlsContext() {
return upstreamTlsContext.getCommonTlsContext();
}
}

View File

@ -356,9 +356,8 @@ public class CdsLoadBalancerTest {
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> mockSslContextProvider = SslContextProvider mockSslContextProvider = mock(SslContextProvider.class);
(SslContextProvider<UpstreamTlsContext>) mock(SslContextProvider.class); doReturn(upstreamTlsContext).when(mockSslContextProvider).getUpstreamTlsContext();
doReturn(upstreamTlsContext).when(mockSslContextProvider).getSource();
doReturn(mockSslContextProvider).when(mockTlsContextManager) doReturn(mockSslContextProvider).when(mockTlsContextManager)
.findOrCreateClientSslContextProvider(same(upstreamTlsContext)); .findOrCreateClientSslContextProvider(same(upstreamTlsContext));
@ -373,8 +372,8 @@ public class CdsLoadBalancerTest {
assertThat(edsLbHelpers).hasSize(1); assertThat(edsLbHelpers).hasSize(1);
assertThat(edsLoadBalancers).hasSize(1); assertThat(edsLoadBalancers).hasSize(1);
verify(mockTlsContextManager, never()).releaseClientSslContextProvider( verify(mockTlsContextManager, never())
(SslContextProvider<UpstreamTlsContext>) any(SslContextProvider.class)); .releaseClientSslContextProvider(any(SslContextProvider.class));
Helper edsLbHelper1 = edsLbHelpers.poll(); Helper edsLbHelper1 = edsLbHelpers.poll();
ArrayList<EquivalentAddressGroup> eagList = new ArrayList<>(); ArrayList<EquivalentAddressGroup> eagList = new ArrayList<>();
@ -403,8 +402,8 @@ public class CdsLoadBalancerTest {
.setUpstreamTlsContext(upstreamTlsContext) .setUpstreamTlsContext(upstreamTlsContext)
.build()); .build());
verify(mockTlsContextManager, never()).releaseClientSslContextProvider( verify(mockTlsContextManager, never())
(SslContextProvider<UpstreamTlsContext>) any(SslContextProvider.class)); .releaseClientSslContextProvider(any(SslContextProvider.class));
verify(mockTlsContextManager, never()).findOrCreateClientSslContextProvider( verify(mockTlsContextManager, never()).findOrCreateClientSslContextProvider(
any(UpstreamTlsContext.class)); any(UpstreamTlsContext.class));
@ -414,9 +413,8 @@ public class CdsLoadBalancerTest {
UpstreamTlsContext upstreamTlsContext1 = UpstreamTlsContext upstreamTlsContext1 =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE); BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> mockSslContextProvider1 = SslContextProvider mockSslContextProvider1 = mock(SslContextProvider.class);
(SslContextProvider<UpstreamTlsContext>) mock(SslContextProvider.class); doReturn(upstreamTlsContext1).when(mockSslContextProvider1).getUpstreamTlsContext();
doReturn(upstreamTlsContext1).when(mockSslContextProvider1).getSource();
doReturn(mockSslContextProvider1).when(mockTlsContextManager) doReturn(mockSslContextProvider1).when(mockTlsContextManager)
.findOrCreateClientSslContextProvider(same(upstreamTlsContext1)); .findOrCreateClientSslContextProvider(same(upstreamTlsContext1));
clusterWatcher1.onClusterChanged( clusterWatcher1.onClusterChanged(

View File

@ -41,7 +41,7 @@ public class ClientSslContextProviderFactoryTest {
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> sslContextProvider = SslContextProvider sslContextProvider =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext);
assertThat(sslContextProvider).isNotNull(); assertThat(sslContextProvider).isNotNull();
} }
@ -55,7 +55,7 @@ public class ClientSslContextProviderFactoryTest {
SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext); SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext);
try { try {
SslContextProvider<UpstreamTlsContext> unused = SslContextProvider unused =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (UnsupportedOperationException expected) {
@ -77,7 +77,7 @@ public class ClientSslContextProviderFactoryTest {
SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext); SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext);
try { try {
SslContextProvider<UpstreamTlsContext> unused = SslContextProvider unused =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (UnsupportedOperationException expected) {

View File

@ -51,60 +51,53 @@ public class ReferenceCountingSslContextProviderMapTest {
@Test @Test
public void referenceCountingMap_getAndRelease_closeCalled() throws InterruptedException { public void referenceCountingMap_getAndRelease_closeCalled() throws InterruptedException {
SslContextProvider<Integer> valueFor3 = getTypedMock(); SslContextProvider valueFor3 = getTypedMock();
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3);
SslContextProvider<Integer> val = map.get(3); SslContextProvider val = map.get(3);
assertThat(val).isSameInstanceAs(valueFor3); assertThat(val).isSameInstanceAs(valueFor3);
verify(valueFor3, never()).close(); verify(valueFor3, never()).close();
val = map.get(3); val = map.get(3);
assertThat(val).isSameInstanceAs(valueFor3); assertThat(val).isSameInstanceAs(valueFor3);
// at this point ref-count is 2 // at this point ref-count is 2
when(valueFor3.getSource()).thenReturn(3); assertThat(map.release(3, val)).isNull();
assertThat(map.release(val)).isNull();
verify(valueFor3, never()).close(); verify(valueFor3, never()).close();
assertThat(map.release(val)).isNull(); // after this ref-count is 0 assertThat(map.release(3, val)).isNull(); // after this ref-count is 0
verify(valueFor3, times(1)).close(); verify(valueFor3, times(1)).close();
} }
@SuppressWarnings("unchecked") private static SslContextProvider getTypedMock() {
private static SslContextProvider<Integer> getTypedMock() {
return mock(SslContextProvider.class); return mock(SslContextProvider.class);
} }
@Test @Test
public void referenceCountingMap_distinctElements() throws InterruptedException { public void referenceCountingMap_distinctElements() throws InterruptedException {
SslContextProvider<Integer> valueFor3 = getTypedMock(); SslContextProvider valueFor3 = getTypedMock();
SslContextProvider<Integer> valueFor4 = getTypedMock(); SslContextProvider valueFor4 = getTypedMock();
when(valueFor3.getSource()).thenReturn(3);
when(valueFor4.getSource()).thenReturn(4);
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3);
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
SslContextProvider<Integer> val3 = map.get(3); SslContextProvider val3 = map.get(3);
assertThat(val3).isSameInstanceAs(valueFor3); assertThat(val3).isSameInstanceAs(valueFor3);
SslContextProvider<Integer> val4 = map.get(4); SslContextProvider val4 = map.get(4);
assertThat(val4).isSameInstanceAs(valueFor4); assertThat(val4).isSameInstanceAs(valueFor4);
assertThat(map.release(val3)).isNull(); assertThat(map.release(3, val3)).isNull();
verify(valueFor3, times(1)).close(); verify(valueFor3, times(1)).close();
verify(valueFor4, never()).close(); verify(valueFor4, never()).close();
assertThat(map.release(val4)).isNull(); assertThat(map.release(4, val4)).isNull();
verify(valueFor4, times(1)).close(); verify(valueFor4, times(1)).close();
} }
@Test @Test
public void referenceCountingMap_releaseWrongElement_expectException() public void referenceCountingMap_releaseWrongElement_expectException()
throws InterruptedException { throws InterruptedException {
SslContextProvider<Integer> valueFor3 = getTypedMock(); SslContextProvider valueFor3 = getTypedMock();
SslContextProvider<Integer> valueFor4 = getTypedMock(); SslContextProvider valueFor4 = getTypedMock();
when(valueFor3.getSource()).thenReturn(3);
when(valueFor4.getSource()).thenReturn(4);
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3);
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
SslContextProvider<Integer> unused = map.get(3); SslContextProvider unused = map.get(3);
SslContextProvider<Integer> val4 = map.get(4); SslContextProvider val4 = map.get(4);
// now provide wrong key (3) and value (val4) combination // now provide wrong key (3) and value (val4) combination
when(valueFor4.getSource()).thenReturn(3);
try { try {
map.release(val4); map.release(3, val4);
fail("exception expected"); fail("exception expected");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().contains("Releasing the wrong instance"); assertThat(expected).hasMessageThat().contains("Releasing the wrong instance");
@ -113,16 +106,15 @@ public class ReferenceCountingSslContextProviderMapTest {
@Test @Test
public void referenceCountingMap_excessRelease_expectException() throws InterruptedException { public void referenceCountingMap_excessRelease_expectException() throws InterruptedException {
SslContextProvider<Integer> valueFor4 = getTypedMock(); SslContextProvider valueFor4 = getTypedMock();
when(valueFor4.getSource()).thenReturn(4);
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
SslContextProvider<Integer> val = map.get(4); SslContextProvider val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4); assertThat(val).isSameInstanceAs(valueFor4);
// at this point ref-count is 1 // at this point ref-count is 1
map.release(val); map.release(4, val);
// at this point ref-count is 0 // at this point ref-count is 0
try { try {
map.release(val); map.release(4, val);
fail("exception expected"); fail("exception expected");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().contains("No cached instance found for 4"); assertThat(expected).hasMessageThat().contains("No cached instance found for 4");
@ -131,16 +123,15 @@ public class ReferenceCountingSslContextProviderMapTest {
@Test @Test
public void referenceCountingMap_releaseAndGet_differentInstance() throws InterruptedException { public void referenceCountingMap_releaseAndGet_differentInstance() throws InterruptedException {
SslContextProvider<Integer> valueFor4 = getTypedMock(); SslContextProvider valueFor4 = getTypedMock();
when(valueFor4.getSource()).thenReturn(4);
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
SslContextProvider<Integer> val = map.get(4); SslContextProvider val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4); assertThat(val).isSameInstanceAs(valueFor4);
// at this point ref-count is 1 // at this point ref-count is 1
map.release(val); map.release(4, val);
// at this point ref-count is 0 and val is removed // at this point ref-count is 0 and val is removed
// should get another instance for 4 // should get another instance for 4
SslContextProvider<Integer> valueFor4a = getTypedMock(); SslContextProvider valueFor4a = getTypedMock();
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4a); when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4a);
val = map.get(4); val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4a); assertThat(val).isSameInstanceAs(valueFor4a);

View File

@ -40,7 +40,7 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
/** Unit tests for {@link SdsSslContextProvider}. */ /** Unit tests for {@link SdsClientSslContextProvider}. */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class SdsSslContextProviderTest { public class SdsSslContextProviderTest {
@ -62,10 +62,13 @@ public class SdsSslContextProviderTest {
server.shutdown(); server.shutdown();
} }
/** Helper method to build SdsSslContextProvider from given names. */ /** Helper method to build SdsClientSslContextProvider from given names. */
private SdsSslContextProvider<?> getSdsSslContextProvider( private SdsClientSslContextProvider getSdsClientSslContextProvider(
boolean server, String certName, String validationContextName, String certName,
Iterable<String> verifySubjectAltNames, Iterable<String> alpnProtocols) throws IOException { String validationContextName,
Iterable<String> verifySubjectAltNames,
Iterable<String> alpnProtocols)
throws IOException {
CommonTlsContext commonTlsContext = CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextWithAdditionalValues( CommonTlsContextTestsUtil.buildCommonTlsContextWithAdditionalValues(
@ -77,18 +80,37 @@ public class SdsSslContextProviderTest {
alpnProtocols, alpnProtocols,
/* channelType= */ "inproc"); /* channelType= */ "inproc");
return server return SdsClientSslContextProvider.getProvider(
? SdsSslContextProvider.getProviderForServer( SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext),
CommonTlsContextTestsUtil.buildDownstreamTlsContext( node,
commonTlsContext, /* requireClientCert= */ false), MoreExecutors.directExecutor(),
node, MoreExecutors.directExecutor());
MoreExecutors.directExecutor(), }
MoreExecutors.directExecutor())
: SdsSslContextProvider.getProviderForClient( /** Helper method to build SdsServerSslContextProvider from given names. */
SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext), private SdsServerSslContextProvider getSdsServerSslContextProvider(
node, String certName,
MoreExecutors.directExecutor(), String validationContextName,
MoreExecutors.directExecutor()); Iterable<String> verifySubjectAltNames,
Iterable<String> alpnProtocols)
throws IOException {
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextWithAdditionalValues(
certName,
/* certTargetUri= */ "inproc",
validationContextName,
/* validationContextTargetUri= */ "inproc",
verifySubjectAltNames,
alpnProtocols,
/* channelType= */ "inproc");
return SdsServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false),
node,
MoreExecutors.directExecutor(),
MoreExecutors.directExecutor());
} }
@Test @Test
@ -98,8 +120,8 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor(/* name= */ "valid1")) when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider = SdsServerSslContextProvider provider =
getSdsSslContextProvider(/* server= */ true, "cert1", "valid1", null, null); getSdsServerSslContextProvider("cert1", "valid1", null, null);
SecretVolumeSslContextProviderTest.TestCallback testCallback = SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider); SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
@ -113,9 +135,8 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor("valid1")) when(serverMock.getSecretFor("valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider = SdsClientSslContextProvider provider =
getSdsSslContextProvider( getSdsClientSslContextProvider(
/* server= */ false,
/* certName= */ "cert1", /* certName= */ "cert1",
/* validationContextName= */ "valid1", /* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null, /* verifySubjectAltNames= */ null,
@ -131,10 +152,12 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor(/* name= */ "cert1")) when(serverMock.getSecretFor(/* name= */ "cert1"))
.thenReturn(getOneTlsCertSecret(/* name= */ "cert1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE)); .thenReturn(getOneTlsCertSecret(/* name= */ "cert1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE));
SdsSslContextProvider<?> provider = SdsServerSslContextProvider provider =
getSdsSslContextProvider( getSdsServerSslContextProvider(
/* server= */ true, /* certName= */ "cert1", /* validationContextName= */ null, /* certName= */ "cert1",
/* verifySubjectAltNames= */ null, /* alpnProtocols= */ null); /* validationContextName= */ null,
/* verifySubjectAltNames= */ null,
/* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback = SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider); SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
@ -146,10 +169,12 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor(/* name= */ "valid1")) when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider = SdsClientSslContextProvider provider =
getSdsSslContextProvider( getSdsClientSslContextProvider(
/* server= */ false, /* certName= */ null, /* validationContextName= */ "valid1", /* certName= */ null,
/* verifySubjectAltNames= */ null, null); /* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null,
null);
SecretVolumeSslContextProviderTest.TestCallback testCallback = SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider); SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
@ -161,10 +186,12 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor(/* name= */ "valid1")) when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider = SdsServerSslContextProvider provider =
getSdsSslContextProvider( getSdsServerSslContextProvider(
/* server= */ true, /* certName= */ null, /* validationContextName= */ "valid1", /* certName= */ null,
/* verifySubjectAltNames= */ null, /* alpnProtocols= */ null); /* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null,
/* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback = SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider); SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
@ -184,13 +211,11 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor("valid1")) when(serverMock.getSecretFor("valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider = SdsClientSslContextProvider provider =
getSdsSslContextProvider( getSdsClientSslContextProvider(
/* server= */ false,
/* certName= */ "cert1", /* certName= */ "cert1",
/* validationContextName= */ "valid1", /* validationContextName= */ "valid1",
Arrays.asList( Arrays.asList("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"),
"spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"),
/* alpnProtocols= */ null); /* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback = SecretVolumeSslContextProviderTest.TestCallback testCallback =
@ -205,9 +230,8 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor("valid1")) when(serverMock.getSecretFor("valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider = SdsClientSslContextProvider provider =
getSdsSslContextProvider( getSdsClientSslContextProvider(
/* server= */ false,
/* certName= */ "cert1", /* certName= */ "cert1",
/* validationContextName= */ "valid1", /* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null, /* verifySubjectAltNames= */ null,
@ -226,9 +250,8 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor(/* name= */ "valid1")) when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider = SdsServerSslContextProvider provider =
getSdsSslContextProvider( getSdsServerSslContextProvider(
/* server= */ true,
/* certName= */ "cert1", /* certName= */ "cert1",
/* validationContextName= */ "valid1", /* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null, /* verifySubjectAltNames= */ null,

View File

@ -41,7 +41,7 @@ import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
/** Unit tests for {@link SecretVolumeSslContextProvider}. */ /** Unit tests for {@link SecretVolumeClientSslContextProvider}. */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class SecretVolumeSslContextProviderTest { public class SecretVolumeSslContextProviderTest {
@ -51,7 +51,7 @@ public class SecretVolumeSslContextProviderTest {
public void validateCertificateContext_nullAndNotOptional_throwsException() { public void validateCertificateContext_nullAndNotOptional_throwsException() {
// expect exception when certContext is null and not optional // expect exception when certContext is null and not optional
try { try {
SecretVolumeSslContextProvider.validateCertificateContext( CommonTlsContextUtil.validateCertificateContext(
/* certContext= */ null, /* optional= */ false); /* certContext= */ null, /* optional= */ false);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
@ -64,8 +64,7 @@ public class SecretVolumeSslContextProviderTest {
// expect exception when certContext has no CA and not optional // expect exception when certContext has no CA and not optional
CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance(); CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance();
try { try {
SecretVolumeSslContextProvider.validateCertificateContext( CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false);
certContext, /* optional= */ false);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("certContext is required"); assertThat(expected).hasMessageThat().isEqualTo("certContext is required");
@ -76,7 +75,7 @@ public class SecretVolumeSslContextProviderTest {
public void validateCertificateContext_nullAndOptional() { public void validateCertificateContext_nullAndOptional() {
// certContext argument can be null when optional // certContext argument can be null when optional
CertificateValidationContext certContext = CertificateValidationContext certContext =
SecretVolumeSslContextProvider.validateCertificateContext( CommonTlsContextUtil.validateCertificateContext(
/* certContext= */ null, /* optional= */ true); /* certContext= */ null, /* optional= */ true);
assertThat(certContext).isNull(); assertThat(certContext).isNull();
} }
@ -85,9 +84,7 @@ public class SecretVolumeSslContextProviderTest {
public void validateCertificateContext_missingTrustCaOptional() { public void validateCertificateContext_missingTrustCaOptional() {
// certContext argument can have missing CA when optional // certContext argument can have missing CA when optional
CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance(); CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance();
assertThat( assertThat(CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ true))
SecretVolumeSslContextProvider.validateCertificateContext(
certContext, /* optional= */ true))
.isNull(); .isNull();
} }
@ -99,8 +96,7 @@ public class SecretVolumeSslContextProviderTest {
.setTrustedCa(DataSource.newBuilder().setInlineString("foo")) .setTrustedCa(DataSource.newBuilder().setInlineString("foo"))
.build(); .build();
try { try {
SecretVolumeSslContextProvider.validateCertificateContext( CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false);
certContext, /* optional= */ false);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected"); assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -114,9 +110,7 @@ public class SecretVolumeSslContextProviderTest {
CertificateValidationContext.newBuilder() CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename("bar")) .setTrustedCa(DataSource.newBuilder().setFilename("bar"))
.build(); .build();
assertThat( assertThat(CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false))
SecretVolumeSslContextProvider.validateCertificateContext(
certContext, /* optional= */ false))
.isSameInstanceAs(certContext); .isSameInstanceAs(certContext);
} }
@ -124,7 +118,7 @@ public class SecretVolumeSslContextProviderTest {
public void validateTlsCertificate_nullAndNotOptional_throwsException() { public void validateTlsCertificate_nullAndNotOptional_throwsException() {
// expect exception when tlsCertificate is null and not optional // expect exception when tlsCertificate is null and not optional
try { try {
SecretVolumeSslContextProvider.validateTlsCertificate( CommonTlsContextUtil.validateTlsCertificate(
/* tlsCertificate= */ null, /* optional= */ false); /* tlsCertificate= */ null, /* optional= */ false);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
@ -135,7 +129,7 @@ public class SecretVolumeSslContextProviderTest {
@Test @Test
public void validateTlsCertificate_nullOptional() { public void validateTlsCertificate_nullOptional() {
assertThat( assertThat(
SecretVolumeSslContextProvider.validateTlsCertificate( CommonTlsContextUtil.validateTlsCertificate(
/* tlsCertificate= */ null, /* optional= */ true)) /* tlsCertificate= */ null, /* optional= */ true))
.isNull(); .isNull();
} }
@ -144,10 +138,7 @@ public class SecretVolumeSslContextProviderTest {
public void validateTlsCertificate_defaultInstance_returnsNull() { public void validateTlsCertificate_defaultInstance_returnsNull() {
// tlsCertificate is not null but has no value (default instance): expect null // tlsCertificate is not null but has no value (default instance): expect null
TlsCertificate tlsCert = TlsCertificate.getDefaultInstance(); TlsCertificate tlsCert = TlsCertificate.getDefaultInstance();
assertThat( assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true)).isNull();
SecretVolumeSslContextProvider.validateTlsCertificate(
tlsCert, /* optional= */ true))
.isNull();
} }
@Test @Test
@ -158,7 +149,7 @@ public class SecretVolumeSslContextProviderTest {
.setPrivateKey(DataSource.newBuilder().setInlineString("foo")) .setPrivateKey(DataSource.newBuilder().setInlineString("foo"))
.build(); .build();
try { try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ false); CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected"); assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -173,7 +164,7 @@ public class SecretVolumeSslContextProviderTest {
.setPrivateKey(DataSource.newBuilder().setInlineString("foo")) .setPrivateKey(DataSource.newBuilder().setInlineString("foo"))
.build(); .build();
try { try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true); CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected"); assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -188,7 +179,7 @@ public class SecretVolumeSslContextProviderTest {
.setCertificateChain(DataSource.newBuilder().setInlineString("foo")) .setCertificateChain(DataSource.newBuilder().setInlineString("foo"))
.build(); .build();
try { try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ false); CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected"); assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -203,7 +194,7 @@ public class SecretVolumeSslContextProviderTest {
.setCertificateChain(DataSource.newBuilder().setInlineString("foo")) .setCertificateChain(DataSource.newBuilder().setInlineString("foo"))
.build(); .build();
try { try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true); CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected"); assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -217,9 +208,7 @@ public class SecretVolumeSslContextProviderTest {
.setCertificateChain(DataSource.newBuilder().setFilename("foo")) .setCertificateChain(DataSource.newBuilder().setFilename("foo"))
.setPrivateKey(DataSource.newBuilder().setFilename("bar")) .setPrivateKey(DataSource.newBuilder().setFilename("bar"))
.build(); .build();
assertThat( assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true))
SecretVolumeSslContextProvider.validateTlsCertificate(
tlsCert, /* optional= */ true))
.isSameInstanceAs(tlsCert); .isSameInstanceAs(tlsCert);
} }
@ -230,9 +219,7 @@ public class SecretVolumeSslContextProviderTest {
.setCertificateChain(DataSource.newBuilder().setFilename("foo")) .setCertificateChain(DataSource.newBuilder().setFilename("foo"))
.setPrivateKey(DataSource.newBuilder().setFilename("bar")) .setPrivateKey(DataSource.newBuilder().setFilename("bar"))
.build(); .build();
assertThat( assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false))
SecretVolumeSslContextProvider.validateTlsCertificate(
tlsCert, /* optional= */ false))
.isSameInstanceAs(tlsCert); .isSameInstanceAs(tlsCert);
} }
@ -245,7 +232,7 @@ public class SecretVolumeSslContextProviderTest {
.setPrivateKey(DataSource.newBuilder().setFilename("bar")) .setPrivateKey(DataSource.newBuilder().setFilename("bar"))
.build(); .build();
try { try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true); CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected"); assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -261,7 +248,7 @@ public class SecretVolumeSslContextProviderTest {
.setCertificateChain(DataSource.newBuilder().setFilename("bar")) .setCertificateChain(DataSource.newBuilder().setFilename("bar"))
.build(); .build();
try { try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true); CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected"); assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -272,7 +259,7 @@ public class SecretVolumeSslContextProviderTest {
public void getProviderForServer_defaultTlsCertificate_throwsException() { public void getProviderForServer_defaultTlsCertificate_throwsException() {
TlsCertificate tlsCert = TlsCertificate.getDefaultInstance(); TlsCertificate tlsCert = TlsCertificate.getDefaultInstance();
try { try {
SecretVolumeSslContextProvider.getProviderForServer( SecretVolumeServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildDownstreamTlsContext( CommonTlsContextTestsUtil.buildDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null), CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null),
/* requireClientCert= */ false)); /* requireClientCert= */ false));
@ -294,7 +281,7 @@ public class SecretVolumeSslContextProviderTest {
.setTrustedCa(DataSource.newBuilder().setInlineString("foo")) .setTrustedCa(DataSource.newBuilder().setInlineString("foo"))
.build(); .build();
try { try {
SecretVolumeSslContextProvider.getProviderForServer( SecretVolumeServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildDownstreamTlsContext( CommonTlsContextTestsUtil.buildDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext), CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext),
/* requireClientCert= */ false)); /* requireClientCert= */ false));
@ -308,7 +295,7 @@ public class SecretVolumeSslContextProviderTest {
public void getProviderForClient_defaultCertContext_throwsException() { public void getProviderForClient_defaultCertContext_throwsException() {
CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance(); CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance();
try { try {
SecretVolumeSslContextProvider.getProviderForClient( SecretVolumeClientSslContextProvider.getProvider(
buildUpstreamTlsContext( buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext( CommonTlsContextTestsUtil.getCommonTlsContext(
/* tlsCertificate= */ null, certContext))); /* tlsCertificate= */ null, certContext)));
@ -330,7 +317,7 @@ public class SecretVolumeSslContextProviderTest {
.setTrustedCa(DataSource.newBuilder().setFilename("foo")) .setTrustedCa(DataSource.newBuilder().setFilename("foo"))
.build(); .build();
try { try {
SecretVolumeSslContextProvider.getProviderForClient( SecretVolumeClientSslContextProvider.getProvider(
buildUpstreamTlsContext( buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext))); CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext)));
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
@ -351,7 +338,7 @@ public class SecretVolumeSslContextProviderTest {
.setTrustedCa(DataSource.newBuilder().setFilename("foo")) .setTrustedCa(DataSource.newBuilder().setFilename("foo"))
.build(); .build();
try { try {
SecretVolumeSslContextProvider.getProviderForClient( SecretVolumeClientSslContextProvider.getProvider(
buildUpstreamTlsContext( buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext))); CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext)));
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
@ -360,22 +347,6 @@ public class SecretVolumeSslContextProviderTest {
} }
} }
/** Helper method to build SecretVolumeSslContextProvider from given files. */
private static SecretVolumeSslContextProvider<?> getSslContextSecretVolumeSecretProvider(
boolean server,
String certChainFilename,
String privateKeyFilename,
String trustedCaFilename) {
return server
? SecretVolumeSslContextProvider.getProviderForServer(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
privateKeyFilename, certChainFilename, trustedCaFilename))
: SecretVolumeSslContextProvider.getProviderForClient(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
privateKeyFilename, certChainFilename, trustedCaFilename));
}
/** /**
* Helper method to build SecretVolumeSslContextProvider, call buildSslContext on it and * Helper method to build SecretVolumeSslContextProvider, call buildSslContext on it and
* check returned SslContext. * check returned SslContext.
@ -383,10 +354,22 @@ public class SecretVolumeSslContextProviderTest {
private static void sslContextForEitherWithBothCertAndTrust( private static void sslContextForEitherWithBothCertAndTrust(
boolean server, String pemFile, String keyFile, String caFile) boolean server, String pemFile, String keyFile, String caFile)
throws IOException, CertificateException, CertStoreException { throws IOException, CertificateException, CertStoreException {
SecretVolumeSslContextProvider<?> provider = SslContext sslContext = null;
getSslContextSecretVolumeSecretProvider(server, pemFile, keyFile, caFile); if (server) {
SecretVolumeServerSslContextProvider provider =
SecretVolumeServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
keyFile, pemFile, caFile));
SslContext sslContext = provider.buildSslContextFromSecrets(); sslContext = provider.buildSslContextFromSecrets();
} else {
SecretVolumeClientSslContextProvider provider =
SecretVolumeClientSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
keyFile, pemFile, caFile));
sslContext = provider.buildSslContextFromSecrets();
}
doChecksOnSslContext(server, sslContext, /* expectedApnProtos= */ null); doChecksOnSslContext(server, sslContext, /* expectedApnProtos= */ null);
} }
@ -469,7 +452,7 @@ public class SecretVolumeSslContextProviderTest {
* Helper method to get the value thru directExecutor callback. Because of directExecutor this is * Helper method to get the value thru directExecutor callback. Because of directExecutor this is
* a synchronous callback - so need to provide a listener. * a synchronous callback - so need to provide a listener.
*/ */
static TestCallback getValueThruCallback(SslContextProvider<?> provider) { static TestCallback getValueThruCallback(SslContextProvider provider) {
TestCallback testCallback = new TestCallback(); TestCallback testCallback = new TestCallback();
provider.addCallback(testCallback, MoreExecutors.directExecutor()); provider.addCallback(testCallback, MoreExecutors.directExecutor());
return testCallback; return testCallback;
@ -477,9 +460,10 @@ public class SecretVolumeSslContextProviderTest {
@Test @Test
public void getProviderForServer_both_callsback() throws IOException { public void getProviderForServer_both_callsback() throws IOException {
SecretVolumeSslContextProvider<?> provider = SecretVolumeServerSslContextProvider provider =
getSslContextSecretVolumeSecretProvider( SecretVolumeServerSslContextProvider.getProvider(
true, SERVER_1_PEM_FILE, SERVER_1_KEY_FILE, CA_PEM_FILE); CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE));
TestCallback testCallback = getValueThruCallback(provider); TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null); doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
@ -487,9 +471,10 @@ public class SecretVolumeSslContextProviderTest {
@Test @Test
public void getProviderForClient_both_callsback() throws IOException { public void getProviderForClient_both_callsback() throws IOException {
SecretVolumeSslContextProvider<?> provider = SecretVolumeClientSslContextProvider provider =
getSslContextSecretVolumeSecretProvider( SecretVolumeClientSslContextProvider.getProvider(
false, CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE); CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE));
TestCallback testCallback = getValueThruCallback(provider); TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
@ -498,9 +483,10 @@ public class SecretVolumeSslContextProviderTest {
// note this test generates stack-trace but can be safely ignored // note this test generates stack-trace but can be safely ignored
@Test @Test
public void getProviderForClient_both_callsback_setException() throws IOException { public void getProviderForClient_both_callsback_setException() throws IOException {
SecretVolumeSslContextProvider<?> provider = SecretVolumeClientSslContextProvider provider =
getSslContextSecretVolumeSecretProvider( SecretVolumeClientSslContextProvider.getProvider(
false, CLIENT_PEM_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_PEM_FILE, CLIENT_PEM_FILE, CA_PEM_FILE));
TestCallback testCallback = getValueThruCallback(provider); TestCallback testCallback = getValueThruCallback(provider);
assertThat(testCallback.updatedSslContext).isNull(); assertThat(testCallback.updatedSslContext).isNull();
assertThat(testCallback.updatedThrowable).isInstanceOf(IllegalArgumentException.class); assertThat(testCallback.updatedThrowable).isInstanceOf(IllegalArgumentException.class);

View File

@ -41,7 +41,7 @@ public class ServerSslContextProviderFactoryTest {
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
SslContextProvider<DownstreamTlsContext> sslContextProvider = SslContextProvider sslContextProvider =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext);
assertThat(sslContextProvider).isNotNull(); assertThat(sslContextProvider).isNotNull();
} }
@ -56,7 +56,7 @@ public class ServerSslContextProviderFactoryTest {
commonTlsContext, /* requireClientCert= */ false); commonTlsContext, /* requireClientCert= */ false);
try { try {
SslContextProvider<DownstreamTlsContext> unused = SslContextProvider unused =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (UnsupportedOperationException expected) {
@ -76,7 +76,7 @@ public class ServerSslContextProviderFactoryTest {
commonTlsContext, /* requireClientCert= */ false); commonTlsContext, /* requireClientCert= */ false);
try { try {
SslContextProvider<DownstreamTlsContext> unused = SslContextProvider unused =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (UnsupportedOperationException expected) {

View File

@ -49,11 +49,9 @@ public class TlsContextManagerTest {
@Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule();
@Mock @Mock SslContextProviderFactory<UpstreamTlsContext> mockClientFactory;
SslContextProviderFactory<UpstreamTlsContext> mockClientFactory;
@Mock @Mock SslContextProviderFactory<DownstreamTlsContext> mockServerFactory;
SslContextProviderFactory<DownstreamTlsContext> mockServerFactory;
@Before @Before
public void clearInstance() throws NoSuchFieldException, IllegalAccessException { public void clearInstance() throws NoSuchFieldException, IllegalAccessException {
@ -69,11 +67,11 @@ public class TlsContextManagerTest {
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
SslContextProvider<DownstreamTlsContext> serverSecretProvider = SslContextProvider serverSecretProvider =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider).isNotNull(); assertThat(serverSecretProvider).isNotNull();
SslContextProvider<DownstreamTlsContext> serverSecretProvider1 = SslContextProvider serverSecretProvider1 =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider1).isSameInstanceAs(serverSecretProvider); assertThat(serverSecretProvider1).isSameInstanceAs(serverSecretProvider);
} }
@ -85,11 +83,11 @@ public class TlsContextManagerTest {
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
SslContextProvider<UpstreamTlsContext> clientSecretProvider = SslContextProvider clientSecretProvider =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider).isNotNull(); assertThat(clientSecretProvider).isNotNull();
SslContextProvider<UpstreamTlsContext> clientSecretProvider1 = SslContextProvider clientSecretProvider1 =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider1).isSameInstanceAs(clientSecretProvider); assertThat(clientSecretProvider1).isSameInstanceAs(clientSecretProvider);
} }
@ -101,14 +99,14 @@ public class TlsContextManagerTest {
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
SslContextProvider<DownstreamTlsContext> serverSecretProvider = SslContextProvider serverSecretProvider =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider).isNotNull(); assertThat(serverSecretProvider).isNotNull();
DownstreamTlsContext downstreamTlsContext1 = DownstreamTlsContext downstreamTlsContext1 =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_0_KEY_FILE, SERVER_0_PEM_FILE, CA_PEM_FILE); SERVER_0_KEY_FILE, SERVER_0_PEM_FILE, CA_PEM_FILE);
SslContextProvider<DownstreamTlsContext> serverSecretProvider1 = SslContextProvider serverSecretProvider1 =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext1); tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext1);
assertThat(serverSecretProvider1).isNotNull(); assertThat(serverSecretProvider1).isNotNull();
assertThat(serverSecretProvider1).isNotSameInstanceAs(serverSecretProvider); assertThat(serverSecretProvider1).isNotSameInstanceAs(serverSecretProvider);
@ -121,7 +119,7 @@ public class TlsContextManagerTest {
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
SslContextProvider<UpstreamTlsContext> clientSecretProvider = SslContextProvider clientSecretProvider =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider).isNotNull(); assertThat(clientSecretProvider).isNotNull();
@ -129,7 +127,7 @@ public class TlsContextManagerTest {
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> clientSecretProvider1 = SslContextProvider clientSecretProvider1 =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1); tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1);
assertThat(clientSecretProvider1).isNotSameInstanceAs(clientSecretProvider); assertThat(clientSecretProvider1).isNotSameInstanceAs(clientSecretProvider);
} }
@ -143,13 +141,13 @@ public class TlsContextManagerTest {
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory); new TlsContextManagerImpl(mockClientFactory, mockServerFactory);
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
SslContextProvider<DownstreamTlsContext> mockProvider = mock(SslContextProvider.class); SslContextProvider mockProvider = mock(SslContextProvider.class);
when(mockServerFactory.createSslContextProvider(downstreamTlsContext)).thenReturn(mockProvider); when(mockServerFactory.createSslContextProvider(downstreamTlsContext)).thenReturn(mockProvider);
SslContextProvider<DownstreamTlsContext> serverSecretProvider = SslContextProvider serverSecretProvider =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider).isSameInstanceAs(mockProvider); assertThat(serverSecretProvider).isSameInstanceAs(mockProvider);
verify(mockProvider, never()).close(); verify(mockProvider, never()).close();
when(mockProvider.getSource()).thenReturn(downstreamTlsContext); when(mockProvider.getDownstreamTlsContext()).thenReturn(downstreamTlsContext);
tlsContextManagerImpl.releaseServerSslContextProvider(mockProvider); tlsContextManagerImpl.releaseServerSslContextProvider(mockProvider);
verify(mockProvider, times(1)).close(); verify(mockProvider, times(1)).close();
} }
@ -163,13 +161,13 @@ public class TlsContextManagerTest {
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory); new TlsContextManagerImpl(mockClientFactory, mockServerFactory);
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
SslContextProvider<UpstreamTlsContext> mockProvider = mock(SslContextProvider.class); SslContextProvider mockProvider = mock(SslContextProvider.class);
when(mockClientFactory.createSslContextProvider(upstreamTlsContext)).thenReturn(mockProvider); when(mockClientFactory.createSslContextProvider(upstreamTlsContext)).thenReturn(mockProvider);
SslContextProvider<UpstreamTlsContext> clientSecretProvider = SslContextProvider clientSecretProvider =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider).isSameInstanceAs(mockProvider); assertThat(clientSecretProvider).isSameInstanceAs(mockProvider);
verify(mockProvider, never()).close(); verify(mockProvider, never()).close();
when(mockProvider.getSource()).thenReturn(upstreamTlsContext); when(mockProvider.getUpstreamTlsContext()).thenReturn(upstreamTlsContext);
tlsContextManagerImpl.releaseClientSslContextProvider(mockProvider); tlsContextManagerImpl.releaseClientSslContextProvider(mockProvider);
verify(mockProvider, times(1)).close(); verify(mockProvider, times(1)).close();
} }