xds: replace generic with individual client and server SslContextProviders (#7059)

This commit is contained in:
sanjaypujare 2020-05-27 12:31:54 -07:00 committed by GitHub
parent 7d2d2ec035
commit 62620ccd00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 858 additions and 579 deletions

View File

@ -206,10 +206,10 @@ public final class CdsLoadBalancer extends LoadBalancer {
private static final class EdsLoadBalancingHelper extends ForwardingLoadBalancerHelper {
private final Helper delegate;
private final AtomicReference<SslContextProvider<UpstreamTlsContext>> sslContextProvider;
private final AtomicReference<SslContextProvider> sslContextProvider;
EdsLoadBalancingHelper(Helper helper,
AtomicReference<SslContextProvider<UpstreamTlsContext>> sslContextProvider) {
AtomicReference<SslContextProvider> sslContextProvider) {
this.delegate = helper;
this.sslContextProvider = sslContextProvider;
}
@ -222,7 +222,7 @@ public final class CdsLoadBalancer extends LoadBalancer {
.toBuilder()
.setAddresses(
addUpstreamTlsContext(createSubchannelArgs.getAddresses(),
sslContextProvider.get().getSource()))
sslContextProvider.get().getUpstreamTlsContext()))
.build();
}
return delegate.createSubchannel(createSubchannelArgs);
@ -267,7 +267,7 @@ public final class CdsLoadBalancer extends LoadBalancer {
ClusterWatcherImpl(Helper helper, ResolvedAddresses resolvedAddresses) {
this.helper = new EdsLoadBalancingHelper(helper,
new AtomicReference<SslContextProvider<UpstreamTlsContext>>());
new AtomicReference<SslContextProvider>());
this.resolvedAddresses = resolvedAddresses;
}
@ -303,10 +303,10 @@ public final class CdsLoadBalancer extends LoadBalancer {
/** For new UpstreamTlsContext value, release old SslContextProvider. */
private void updateSslContextProvider(UpstreamTlsContext newUpstreamTlsContext) {
SslContextProvider<UpstreamTlsContext> oldSslContextProvider =
SslContextProvider oldSslContextProvider =
helper.sslContextProvider.get();
if (oldSslContextProvider != null) {
UpstreamTlsContext oldUpstreamTlsContext = oldSslContextProvider.getSource();
UpstreamTlsContext oldUpstreamTlsContext = oldSslContextProvider.getUpstreamTlsContext();
if (oldUpstreamTlsContext.equals(newUpstreamTlsContext)) {
return;
@ -314,7 +314,7 @@ public final class CdsLoadBalancer extends LoadBalancer {
tlsContextManager.releaseClientSslContextProvider(oldSslContextProvider);
}
if (newUpstreamTlsContext != null) {
SslContextProvider<UpstreamTlsContext> newSslContextProvider =
SslContextProvider newSslContextProvider =
tlsContextManager.findOrCreateClientSslContextProvider(newUpstreamTlsContext);
helper.sslContextProvider.set(newSslContextProvider);
} else {

View File

@ -32,18 +32,17 @@ final class ClientSslContextProviderFactory
/** Creates an SslContextProvider from the given UpstreamTlsContext. */
@Override
public SslContextProvider<UpstreamTlsContext> createSslContextProvider(
UpstreamTlsContext upstreamTlsContext) {
public SslContextProvider createSslContextProvider(UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
checkArgument(
upstreamTlsContext.hasCommonTlsContext(),
"upstreamTlsContext should have CommonTlsContext");
if (CommonTlsContextUtil.hasAllSecretsUsingFilename(upstreamTlsContext.getCommonTlsContext())) {
return SecretVolumeSslContextProvider.getProviderForClient(upstreamTlsContext);
return SecretVolumeClientSslContextProvider.getProvider(upstreamTlsContext);
} else if (CommonTlsContextUtil.hasAllSecretsUsingSds(
upstreamTlsContext.getCommonTlsContext())) {
try {
return SdsSslContextProvider.getProviderForClient(
return SdsClientSslContextProvider.getProvider(
upstreamTlsContext,
Bootstrapper.getInstance().readBootstrap().getNode(),
Executors.newSingleThreadExecutor(new ThreadFactoryBuilder()

View File

@ -16,9 +16,16 @@
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.ValidationContextTypeCase;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.core.DataSource.SpecifierCase;
import javax.annotation.Nullable;
/** Class for utility functions for {@link CommonTlsContext}. */
final class CommonTlsContextUtil {
@ -40,4 +47,53 @@ final class CommonTlsContextUtil {
return (commonTlsContext.getTlsCertificatesCount() == 0)
&& !commonTlsContext.hasValidationContext();
}
@Nullable
static CertificateValidationContext getCertificateValidationContext(
CommonTlsContext commonTlsContext) {
checkNotNull(commonTlsContext, "commonTlsContext");
ValidationContextTypeCase type = commonTlsContext.getValidationContextTypeCase();
checkState(
type == ValidationContextTypeCase.VALIDATION_CONTEXT
|| type == ValidationContextTypeCase.VALIDATIONCONTEXTTYPE_NOT_SET,
"incorrect ValidationContextTypeCase");
return type == ValidationContextTypeCase.VALIDATION_CONTEXT
? commonTlsContext.getValidationContext()
: null;
}
@Nullable
static CertificateValidationContext validateCertificateContext(
@Nullable CertificateValidationContext certContext, boolean optional) {
if (certContext == null || !certContext.hasTrustedCa()) {
checkArgument(optional, "certContext is required");
return null;
}
checkArgument(
certContext.getTrustedCa().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
return certContext;
}
@Nullable
static TlsCertificate validateTlsCertificate(
@Nullable TlsCertificate tlsCertificate, boolean optional) {
if (tlsCertificate == null) {
checkArgument(optional, "tlsCertificate is required");
return null;
}
if (optional
&& (tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.SPECIFIER_NOT_SET)
&& (tlsCertificate.getCertificateChain().getSpecifierCase()
== SpecifierCase.SPECIFIER_NOT_SET)) {
return null;
}
checkArgument(
tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
checkArgument(
tlsCertificate.getCertificateChain().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
return tlsCertificate;
}
}

View File

@ -0,0 +1,40 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
final class DownstreamTlsContextHolder implements TlsContextHolder {
private final DownstreamTlsContext downstreamTlsContext;
DownstreamTlsContextHolder(DownstreamTlsContext downstreamTlsContext) {
this.downstreamTlsContext = checkNotNull(downstreamTlsContext, "downstreamTlsContext");
}
public DownstreamTlsContext getDownstreamTlsContext() {
return downstreamTlsContext;
}
@Override
public CommonTlsContext getCommonTlsContext() {
return downstreamTlsContext.getCommonTlsContext();
}
}

View File

@ -38,7 +38,7 @@ import javax.annotation.concurrent.ThreadSafe;
@ThreadSafe
final class ReferenceCountingSslContextProviderMap<K> {
private final Map<K, Instance<K>> instances = new HashMap<>();
private final Map<K, Instance> instances = new HashMap<>();
private final SslContextProviderFactory<K> sslContextProviderFactory;
ReferenceCountingSslContextProviderMap(SslContextProviderFactory<K> sslContextProviderFactory) {
@ -51,7 +51,7 @@ final class ReferenceCountingSslContextProviderMap<K> {
* using the provided {@link SslContextProviderFactory&lt;K&gt;}
*/
@CheckReturnValue
public SslContextProvider<K> get(K key) {
public SslContextProvider get(K key) {
checkNotNull(key, "key");
return getInternal(key);
}
@ -65,19 +65,20 @@ final class ReferenceCountingSslContextProviderMap<K> {
* <p>Caller must not release a reference more than once. It's advised that you clear the
* reference to the instance with the null returned by this method.
*
* @param key for the instance to be released
* @param value the instance to be released
* @return a null which the caller can use to clear the reference to that instance.
*/
public SslContextProvider<K> release(final SslContextProvider<K> value) {
public SslContextProvider release(K key, SslContextProvider value) {
checkNotNull(key, "key");
checkNotNull(value, "value");
K key = value.getSource();
return releaseInternal(key, value);
}
private synchronized SslContextProvider<K> getInternal(K key) {
Instance<K> instance = instances.get(key);
private synchronized SslContextProvider getInternal(K key) {
Instance instance = instances.get(key);
if (instance == null) {
instance = new Instance<>(sslContextProviderFactory.createSslContextProvider(key));
instance = new Instance(sslContextProviderFactory.createSslContextProvider(key));
instances.put(key, instance);
return instance.sslContextProvider;
} else {
@ -85,9 +86,8 @@ final class ReferenceCountingSslContextProviderMap<K> {
}
}
private synchronized SslContextProvider<K> releaseInternal(
final K key, final SslContextProvider<K> instance) {
final Instance<K> cached = instances.get(key);
private synchronized SslContextProvider releaseInternal(K key, SslContextProvider instance) {
Instance cached = instances.get(key);
checkArgument(cached != null, "No cached instance found for %s", key);
checkArgument(instance == cached.sslContextProvider, "Releasing the wrong instance");
if (cached.release()) {
@ -103,15 +103,15 @@ final class ReferenceCountingSslContextProviderMap<K> {
/** A factory to create an SslContextProvider from the given key. */
public interface SslContextProviderFactory<K> {
SslContextProvider<K> createSslContextProvider(K key);
SslContextProvider createSslContextProvider(K key);
}
private static class Instance<K> {
final SslContextProvider<K> sslContextProvider;
private static class Instance {
final SslContextProvider sslContextProvider;
private int refCount;
/** Increment refCount and acquire a reference to sslContextProvider. */
SslContextProvider<K> acquire() {
SslContextProvider acquire() {
refCount++;
return sslContextProvider;
}
@ -122,7 +122,7 @@ final class ReferenceCountingSslContextProviderMap<K> {
return --refCount == 0;
}
Instance(SslContextProvider<K> sslContextProvider) {
Instance(SslContextProvider sslContextProvider) {
this.sslContextProvider = sslContextProvider;
this.refCount = 1;
}

View File

@ -0,0 +1,107 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.Node;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
/** A client SslContext provider that uses SDS to fetch secrets. */
final class SdsClientSslContextProvider extends SdsSslContextProvider {
private SdsClientSslContextProvider(
Node node,
SdsSecretConfig certSdsConfig,
SdsSecretConfig validationContextSdsConfig,
CertificateValidationContext staticCertValidationContext,
Executor watcherExecutor,
Executor channelExecutor,
UpstreamTlsContext upstreamTlsContext) {
super(node,
certSdsConfig,
validationContextSdsConfig,
staticCertValidationContext,
watcherExecutor,
channelExecutor, new UpstreamTlsContextHolder(upstreamTlsContext));
}
static SdsClientSslContextProvider getProvider(
UpstreamTlsContext upstreamTlsContext,
Node node,
Executor watcherExecutor,
Executor channelExecutor) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
SdsSecretConfig validationContextSdsConfig = null;
CertificateValidationContext staticCertValidationContext = null;
if (commonTlsContext.hasCombinedValidationContext()) {
CombinedCertificateValidationContext combinedValidationContext =
commonTlsContext.getCombinedValidationContext();
if (combinedValidationContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig =
combinedValidationContext.getValidationContextSdsSecretConfig();
}
if (combinedValidationContext.hasDefaultValidationContext()) {
staticCertValidationContext = combinedValidationContext.getDefaultValidationContext();
}
} else if (commonTlsContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig();
} else if (commonTlsContext.hasValidationContext()) {
staticCertValidationContext = commonTlsContext.getValidationContext();
}
SdsSecretConfig certSdsConfig = null;
if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) {
certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0);
}
return new SdsClientSslContextProvider(
node,
certSdsConfig,
validationContextSdsConfig,
staticCertValidationContext,
watcherExecutor,
channelExecutor,
upstreamTlsContext);
}
@Override
SslContextBuilder getSslContextBuilder(
CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException {
SslContextBuilder sslContextBuilder =
GrpcSslContexts.forClient()
.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
if (tlsCertificate != null) {
sslContextBuilder.keyManager(
tlsCertificate.getCertificateChain().getInlineBytes().newInput(),
tlsCertificate.getPrivateKey().getInlineBytes().newInput(),
tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null);
}
return sslContextBuilder;
}
}

View File

@ -192,7 +192,7 @@ public final class SdsProtocolNegotiators {
final BufferReadsHandler bufferReads = new BufferReadsHandler();
ctx.pipeline().addBefore(ctx.name(), null, bufferReads);
final SslContextProvider<UpstreamTlsContext> sslContextProvider =
final SslContextProvider sslContextProvider =
TlsContextManagerImpl.getInstance()
.findOrCreateClientSslContextProvider(upstreamTlsContext);
@ -349,7 +349,7 @@ public final class SdsProtocolNegotiators {
final BufferReadsHandler bufferReads = new BufferReadsHandler();
ctx.pipeline().addBefore(ctx.name(), null, bufferReads);
final SslContextProvider<DownstreamTlsContext> sslContextProvider =
final SslContextProvider sslContextProvider =
TlsContextManagerImpl.getInstance()
.findOrCreateServerSslContextProvider(downstreamTlsContext);

View File

@ -0,0 +1,91 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
import io.envoyproxy.envoy.api.v2.core.Node;
import io.grpc.netty.GrpcSslContexts;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
/** A server SslContext provider that uses SDS to fetch secrets. */
final class SdsServerSslContextProvider extends SdsSslContextProvider {
private SdsServerSslContextProvider(
Node node,
SdsSecretConfig certSdsConfig,
SdsSecretConfig validationContextSdsConfig,
Executor watcherExecutor,
Executor channelExecutor,
DownstreamTlsContext downstreamTlsContext) {
super(node,
certSdsConfig,
validationContextSdsConfig,
null,
watcherExecutor,
channelExecutor, new DownstreamTlsContextHolder(downstreamTlsContext));
}
static SdsServerSslContextProvider getProvider(
DownstreamTlsContext downstreamTlsContext,
Node node,
Executor watcherExecutor,
Executor channelExecutor) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext();
SdsSecretConfig certSdsConfig = null;
if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) {
certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0);
}
SdsSecretConfig validationContextSdsConfig = null;
if (commonTlsContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig();
}
return new SdsServerSslContextProvider(
node,
certSdsConfig,
validationContextSdsConfig,
watcherExecutor,
channelExecutor,
downstreamTlsContext);
}
@Override
SslContextBuilder getSslContextBuilder(
CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException {
SslContextBuilder sslContextBuilder =
GrpcSslContexts.forServer(
tlsCertificate.getCertificateChain().getInlineBytes().newInput(),
tlsCertificate.getPrivateKey().getInlineBytes().newInput(),
tlsCertificate.hasPassword()
? tlsCertificate.getPassword().getInlineString()
: null);
setClientAuthValues(sslContextBuilder, localCertValidationContext);
return sslContextBuilder;
}
}

View File

@ -21,16 +21,11 @@ import static com.google.common.base.Preconditions.checkState;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
import io.envoyproxy.envoy.api.v2.auth.Secret;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.Node;
import io.grpc.Status;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory;
import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
@ -44,12 +39,8 @@ import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
/**
* An SslContext provider that uses SDS to fetch secrets. Used for both server and client
* SslContexts
*/
final class SdsSslContextProvider<K> extends SslContextProvider<K>
implements SdsClient.SecretWatcher {
/** Base class for SdsClientSslContextProvider and SdsServerSslContextProvider. */
abstract class SdsSslContextProvider extends SslContextProvider implements SdsClient.SecretWatcher {
private static final Logger logger = Logger.getLogger(SdsSslContextProvider.class.getName());
@ -59,20 +50,19 @@ final class SdsSslContextProvider<K> extends SslContextProvider<K>
@Nullable private final SdsSecretConfig validationContextSdsConfig;
@Nullable private final CertificateValidationContext staticCertificateValidationContext;
private final List<CallbackPair> pendingCallbacks = new ArrayList<>();
@Nullable private TlsCertificate tlsCertificate;
@Nullable protected TlsCertificate tlsCertificate;
@Nullable private CertificateValidationContext certificateValidationContext;
@Nullable private SslContext sslContext;
private SdsSslContextProvider(
SdsSslContextProvider(
Node node,
SdsSecretConfig certSdsConfig,
SdsSecretConfig validationContextSdsConfig,
CertificateValidationContext staticCertValidationContext,
Executor watcherExecutor,
Executor channelExecutor,
boolean server,
K source) {
super(source, server);
TlsContextHolder tlsContextHolder) {
super(tlsContextHolder);
this.certSdsConfig = certSdsConfig;
this.validationContextSdsConfig = validationContextSdsConfig;
this.staticCertificateValidationContext = staticCertValidationContext;
@ -95,73 +85,6 @@ final class SdsSslContextProvider<K> extends SslContextProvider<K>
}
}
static SdsSslContextProvider<UpstreamTlsContext> getProviderForClient(
UpstreamTlsContext upstreamTlsContext,
Node node,
Executor watcherExecutor,
Executor channelExecutor) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
SdsSecretConfig validationContextSdsConfig = null;
CertificateValidationContext staticCertValidationContext = null;
if (commonTlsContext.hasCombinedValidationContext()) {
CombinedCertificateValidationContext combinedValidationContext =
commonTlsContext.getCombinedValidationContext();
if (combinedValidationContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig =
combinedValidationContext.getValidationContextSdsSecretConfig();
}
if (combinedValidationContext.hasDefaultValidationContext()) {
staticCertValidationContext = combinedValidationContext.getDefaultValidationContext();
}
} else if (commonTlsContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig();
} else if (commonTlsContext.hasValidationContext()) {
staticCertValidationContext = commonTlsContext.getValidationContext();
}
SdsSecretConfig certSdsConfig = null;
if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) {
certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0);
}
return new SdsSslContextProvider<>(
node,
certSdsConfig,
validationContextSdsConfig,
staticCertValidationContext,
watcherExecutor,
channelExecutor,
false,
upstreamTlsContext);
}
static SdsSslContextProvider<DownstreamTlsContext> getProviderForServer(
DownstreamTlsContext downstreamTlsContext,
Node node,
Executor watcherExecutor,
Executor channelExecutor) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext();
SdsSecretConfig certSdsConfig = null;
if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) {
certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0);
}
SdsSecretConfig validationContextSdsConfig = null;
if (commonTlsContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig();
}
return new SdsSslContextProvider<>(
node,
certSdsConfig,
validationContextSdsConfig,
null,
watcherExecutor,
channelExecutor,
true,
downstreamTlsContext);
}
@Override
public void addCallback(Callback callback, Executor executor) {
checkNotNull(callback, "callback");
@ -219,34 +142,16 @@ final class SdsSslContextProvider<K> extends SslContextProvider<K>
}
}
/** Gets a server or client side SslContextBuilder. */
abstract SslContextBuilder getSslContextBuilder(
CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException;
// this gets called only when requested secrets are ready...
private void updateSslContext() {
try {
SslContextBuilder sslContextBuilder;
CertificateValidationContext localCertValidationContext =
mergeStaticAndDynamicCertContexts();
if (server) {
logger.log(Level.FINEST, "for server");
sslContextBuilder =
GrpcSslContexts.forServer(
tlsCertificate.getCertificateChain().getInlineBytes().newInput(),
tlsCertificate.getPrivateKey().getInlineBytes().newInput(),
tlsCertificate.hasPassword()
? tlsCertificate.getPassword().getInlineString()
: null);
setClientAuthValues(sslContextBuilder, localCertValidationContext);
} else {
logger.log(Level.FINEST, "for client");
sslContextBuilder =
GrpcSslContexts.forClient()
.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
if (tlsCertificate != null) {
sslContextBuilder.keyManager(
tlsCertificate.getCertificateChain().getInlineBytes().newInput(),
tlsCertificate.getPrivateKey().getInlineBytes().newInput(),
tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null);
}
}
CertificateValidationContext localCertValidationContext = mergeStaticAndDynamicCertContexts();
SslContextBuilder sslContextBuilder = getSslContextBuilder(localCertValidationContext);
CommonTlsContext commonTlsContext = getCommonTlsContext();
if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) {
List<String> alpnList = commonTlsContext.getAlpnProtocolsList();

View File

@ -0,0 +1,125 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateTlsCertificate;
import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/** A client SslContext provider that uses file-based secrets (secret volume). */
final class SecretVolumeClientSslContextProvider extends SslContextProvider {
@Nullable private final String privateKey;
@Nullable private final String privateKeyPassword;
@Nullable private final String certificateChain;
@Nullable private final CertificateValidationContext certContext;
private SecretVolumeClientSslContextProvider(
@Nullable String privateKey,
@Nullable String privateKeyPassword,
@Nullable String certificateChain,
@Nullable CertificateValidationContext certContext,
UpstreamTlsContext upstreamTlsContext) {
super(new UpstreamTlsContextHolder(upstreamTlsContext));
this.privateKey = privateKey;
this.privateKeyPassword = privateKeyPassword;
this.certificateChain = certificateChain;
this.certContext = certContext;
}
static SecretVolumeClientSslContextProvider getProvider(UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
CertificateValidationContext certificateValidationContext =
getCertificateValidationContext(commonTlsContext);
// first validate
validateCertificateContext(certificateValidationContext, /* optional= */ false);
TlsCertificate tlsCertificate = null;
if (commonTlsContext.getTlsCertificatesCount() > 0) {
tlsCertificate = commonTlsContext.getTlsCertificates(0);
}
// tlsCertificate exists in case of mTLS, else null for a client
if (tlsCertificate != null) {
tlsCertificate = validateTlsCertificate(tlsCertificate, /* optional= */ true);
}
String privateKey = null;
String privateKeyPassword = null;
String certificateChain = null;
if (tlsCertificate != null) {
privateKey = tlsCertificate.getPrivateKey().getFilename();
if (tlsCertificate.hasPassword()) {
privateKeyPassword = tlsCertificate.getPassword().getInlineString();
}
certificateChain = tlsCertificate.getCertificateChain().getFilename();
}
return new SecretVolumeClientSslContextProvider(
privateKey,
privateKeyPassword,
certificateChain,
certificateValidationContext,
upstreamTlsContext);
}
@Override
public void addCallback(final Callback callback, Executor executor) {
checkNotNull(callback, "callback");
checkNotNull(executor, "executor");
// as per the contract we will read the current secrets on disk
// this involves I/O which can potentially block the executor
performCallback(
new SslContextGetter() {
@Override
public SslContext get() throws CertificateException, IOException, CertStoreException {
return buildSslContextFromSecrets();
}
},
callback,
executor);
}
@Override
public void close() {}
@VisibleForTesting
SslContext buildSslContextFromSecrets()
throws IOException, CertificateException, CertStoreException {
SslContextBuilder sslContextBuilder =
GrpcSslContexts.forClient().trustManager(new SdsTrustManagerFactory(certContext));
if (privateKey != null && certificateChain != null) {
sslContextBuilder.keyManager(
new File(certificateChain), new File(privateKey), privateKeyPassword);
}
return sslContextBuilder.build();
}
}

View File

@ -0,0 +1,116 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateTlsCertificate;
import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.grpc.netty.GrpcSslContexts;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/** A server SslContext provider that uses file-based secrets (secret volume). */
final class SecretVolumeServerSslContextProvider extends SslContextProvider {
@Nullable private final String privateKey;
@Nullable private final String privateKeyPassword;
@Nullable private final String certificateChain;
@Nullable private final CertificateValidationContext certContext;
private SecretVolumeServerSslContextProvider(
@Nullable String privateKey,
@Nullable String privateKeyPassword,
@Nullable String certificateChain,
@Nullable CertificateValidationContext certContext,
DownstreamTlsContext downstreamTlsContext) {
super(new DownstreamTlsContextHolder(downstreamTlsContext));
this.privateKey = privateKey;
this.privateKeyPassword = privateKeyPassword;
this.certificateChain = certificateChain;
this.certContext = certContext;
}
static SecretVolumeServerSslContextProvider getProvider(
DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext();
TlsCertificate tlsCertificate = null;
if (commonTlsContext.getTlsCertificatesCount() > 0) {
tlsCertificate = commonTlsContext.getTlsCertificates(0);
}
// first validate
validateTlsCertificate(tlsCertificate, /* optional= */ false);
CertificateValidationContext certificateValidationContext =
getCertificateValidationContext(commonTlsContext);
// certificateValidationContext exists in case of mTLS, else null for a server
if (certificateValidationContext != null) {
certificateValidationContext =
validateCertificateContext(certificateValidationContext, /* optional= */ true);
}
String privateKeyPassword =
tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null;
return new SecretVolumeServerSslContextProvider(
tlsCertificate.getPrivateKey().getFilename(),
privateKeyPassword,
tlsCertificate.getCertificateChain().getFilename(),
certificateValidationContext,
downstreamTlsContext);
}
@Override
public void addCallback(final Callback callback, Executor executor) {
checkNotNull(callback, "callback");
checkNotNull(executor, "executor");
// as per the contract we will read the current secrets on disk
// this involves I/O which can potentially block the executor
performCallback(
new SslContextGetter() {
@Override
public SslContext get() throws CertificateException, IOException, CertStoreException {
return buildSslContextFromSecrets();
}
},
callback,
executor);
}
@Override
public void close() {}
@VisibleForTesting
SslContext buildSslContextFromSecrets()
throws IOException, CertificateException, CertStoreException {
SslContextBuilder sslContextBuilder =
GrpcSslContexts.forServer(
new File(certificateChain), new File(privateKey), privateKeyPassword);
setClientAuthValues(sslContextBuilder, certContext);
return sslContextBuilder.build();
}
}

View File

@ -1,219 +0,0 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.ValidationContextTypeCase;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.DataSource.SpecifierCase;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/**
* An SslContext provider that uses file-based secrets (secret volume). Used for both server and
* client SslContexts
*/
final class SecretVolumeSslContextProvider<K> extends SslContextProvider<K> {
@Nullable private final String privateKey;
@Nullable private final String privateKeyPassword;
@Nullable private final String certificateChain;
@Nullable private final CertificateValidationContext certContext;
private SecretVolumeSslContextProvider(
@Nullable String privateKey,
@Nullable String privateKeyPassword,
@Nullable String certificateChain,
@Nullable CertificateValidationContext certContext,
boolean server,
K source) {
super(source, server);
this.privateKey = privateKey;
this.privateKeyPassword = privateKeyPassword;
this.certificateChain = certificateChain;
this.certContext = certContext;
}
@VisibleForTesting
@Nullable
static CertificateValidationContext validateCertificateContext(
@Nullable CertificateValidationContext certContext, boolean optional) {
if (certContext == null || !certContext.hasTrustedCa()) {
checkArgument(optional, "certContext is required");
return null;
}
checkArgument(
certContext.getTrustedCa().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
return certContext;
}
@VisibleForTesting
@Nullable
static TlsCertificate validateTlsCertificate(
@Nullable TlsCertificate tlsCertificate, boolean optional) {
if (tlsCertificate == null) {
checkArgument(optional, "tlsCertificate is required");
return null;
}
if (optional
&& (tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.SPECIFIER_NOT_SET)
&& (tlsCertificate.getCertificateChain().getSpecifierCase()
== SpecifierCase.SPECIFIER_NOT_SET)) {
return null;
}
checkArgument(
tlsCertificate.getPrivateKey().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
checkArgument(
tlsCertificate.getCertificateChain().getSpecifierCase() == SpecifierCase.FILENAME,
"filename expected");
return tlsCertificate;
}
static SecretVolumeSslContextProvider<DownstreamTlsContext> getProviderForServer(
DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext();
TlsCertificate tlsCertificate = null;
if (commonTlsContext.getTlsCertificatesCount() > 0) {
tlsCertificate = commonTlsContext.getTlsCertificates(0);
}
// first validate
validateTlsCertificate(tlsCertificate, /* optional= */ false);
CertificateValidationContext certificateValidationContext =
getCertificateValidationContext(commonTlsContext);
// certificateValidationContext exists in case of mTLS, else null for a server
if (certificateValidationContext != null) {
certificateValidationContext =
validateCertificateContext(certificateValidationContext, /* optional= */ true);
}
String privateKeyPassword =
tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null;
return new SecretVolumeSslContextProvider<>(
tlsCertificate.getPrivateKey().getFilename(),
privateKeyPassword,
tlsCertificate.getCertificateChain().getFilename(),
certificateValidationContext,
/* server= */ true,
downstreamTlsContext);
}
static SecretVolumeSslContextProvider<UpstreamTlsContext> getProviderForClient(
UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
CertificateValidationContext certificateValidationContext =
getCertificateValidationContext(commonTlsContext);
// first validate
validateCertificateContext(certificateValidationContext, /* optional= */ false);
TlsCertificate tlsCertificate = null;
if (commonTlsContext.getTlsCertificatesCount() > 0) {
tlsCertificate = commonTlsContext.getTlsCertificates(0);
}
// tlsCertificate exists in case of mTLS, else null for a client
if (tlsCertificate != null) {
tlsCertificate = validateTlsCertificate(tlsCertificate, /* optional= */ true);
}
String privateKey = null;
String privateKeyPassword = null;
String certificateChain = null;
if (tlsCertificate != null) {
privateKey = tlsCertificate.getPrivateKey().getFilename();
if (tlsCertificate.hasPassword()) {
privateKeyPassword = tlsCertificate.getPassword().getInlineString();
}
certificateChain = tlsCertificate.getCertificateChain().getFilename();
}
return new SecretVolumeSslContextProvider<>(
privateKey,
privateKeyPassword,
certificateChain,
certificateValidationContext,
/* server= */ false,
upstreamTlsContext);
}
private static CertificateValidationContext getCertificateValidationContext(
CommonTlsContext commonTlsContext) {
checkNotNull(commonTlsContext, "commonTlsContext");
ValidationContextTypeCase type = commonTlsContext.getValidationContextTypeCase();
checkState(
type == ValidationContextTypeCase.VALIDATION_CONTEXT
|| type == ValidationContextTypeCase.VALIDATIONCONTEXTTYPE_NOT_SET,
"incorrect ValidationContextTypeCase");
return type == ValidationContextTypeCase.VALIDATION_CONTEXT
? commonTlsContext.getValidationContext()
: null;
}
@Override
public void addCallback(final Callback callback, Executor executor) {
checkNotNull(callback, "callback");
checkNotNull(executor, "executor");
// as per the contract we will read the current secrets on disk
// this involves I/O which can potentially block the executor
performCallback(
new SslContextGetter() {
@Override
public SslContext get() throws CertificateException, IOException, CertStoreException {
return buildSslContextFromSecrets();
}
},
callback,
executor);
}
@Override
public void close() {}
@VisibleForTesting
SslContext buildSslContextFromSecrets()
throws IOException, CertificateException, CertStoreException {
SslContextBuilder sslContextBuilder;
if (server) {
sslContextBuilder =
GrpcSslContexts.forServer(
new File(certificateChain), new File(privateKey), privateKeyPassword);
setClientAuthValues(sslContextBuilder, certContext);
} else {
sslContextBuilder =
GrpcSslContexts.forClient().trustManager(new SdsTrustManagerFactory(certContext));
if (privateKey != null && certificateChain != null) {
sslContextBuilder.keyManager(
new File(certificateChain), new File(privateKey), privateKeyPassword);
}
}
return sslContextBuilder.build();
}
}

View File

@ -30,9 +30,9 @@ import java.util.concurrent.Executors;
final class ServerSslContextProviderFactory
implements SslContextProviderFactory<DownstreamTlsContext> {
/** Creates an SslContextProvider from the given DownstreamTlsContext. */
/** Creates a SslContextProvider from the given DownstreamTlsContext. */
@Override
public SslContextProvider<DownstreamTlsContext> createSslContextProvider(
public SslContextProvider createSslContextProvider(
DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
checkArgument(
@ -40,11 +40,11 @@ final class ServerSslContextProviderFactory
"downstreamTlsContext should have CommonTlsContext");
if (CommonTlsContextUtil.hasAllSecretsUsingFilename(
downstreamTlsContext.getCommonTlsContext())) {
return SecretVolumeSslContextProvider.getProviderForServer(downstreamTlsContext);
return SecretVolumeServerSslContextProvider.getProvider(downstreamTlsContext);
} else if (CommonTlsContextUtil.hasAllSecretsUsingSds(
downstreamTlsContext.getCommonTlsContext())) {
try {
return SdsSslContextProvider.getProviderForServer(
return SdsServerSslContextProvider.getProvider(
downstreamTlsContext,
Bootstrapper.getInstance().readBootstrap().getNode(),
Executors.newSingleThreadExecutor(new ThreadFactoryBuilder()

View File

@ -16,7 +16,6 @@
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
@ -41,14 +40,11 @@ import java.util.logging.Logger;
* stream that is receiving the requested secret(s) or it could represent file-system based
* secret(s) that are dynamic.
*/
// TODO(sanjaypujare): replace generic K with DownstreamTlsContext & UpstreamTlsContext in
// separate client&server classes
public abstract class SslContextProvider<K> {
public abstract class SslContextProvider {
private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName());
protected final boolean server;
private final K source;
protected final TlsContextHolder tlsContextHolder;
public interface Callback {
/** Informs callee of new/updated SslContext. */
@ -58,36 +54,20 @@ public abstract class SslContextProvider<K> {
void onException(Throwable throwable);
}
protected SslContextProvider(K source, boolean server) {
if (server) {
checkArgument(source instanceof DownstreamTlsContext, "expecting DownstreamTlsContext");
} else {
checkArgument(source instanceof UpstreamTlsContext, "expecting UpstreamTlsContext");
}
this.source = source;
this.server = server;
}
public K getSource() {
return source;
SslContextProvider(TlsContextHolder tlsContextHolder) {
this.tlsContextHolder = checkNotNull(tlsContextHolder, "tlsContextHolder");
}
CommonTlsContext getCommonTlsContext() {
if (source instanceof UpstreamTlsContext) {
return ((UpstreamTlsContext) source).getCommonTlsContext();
} else if (source instanceof DownstreamTlsContext) {
return ((DownstreamTlsContext) source).getCommonTlsContext();
}
return null;
return tlsContextHolder.getCommonTlsContext();
}
protected void setClientAuthValues(
SslContextBuilder sslContextBuilder, CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException {
checkState(server, "server side SslContextProvider expected");
DownstreamTlsContext downstreamTlsContext = getDownstreamTlsContext();
if (localCertValidationContext != null) {
sslContextBuilder.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
DownstreamTlsContext downstreamTlsContext = (DownstreamTlsContext)getSource();
sslContextBuilder.clientAuth(
downstreamTlsContext.hasRequireClientCertificate()
? ClientAuth.REQUIRE
@ -97,6 +77,20 @@ public abstract class SslContextProvider<K> {
}
}
/** Returns the DownstreamTlsContext in this SslContextProvider if this is server side. **/
public DownstreamTlsContext getDownstreamTlsContext() {
checkState(tlsContextHolder instanceof DownstreamTlsContextHolder,
"expected DownstreamTlsContextHolder");
return ((DownstreamTlsContextHolder) tlsContextHolder).getDownstreamTlsContext();
}
/** Returns the UpstreamTlsContext in this SslContextProvider if this is client side. **/
public UpstreamTlsContext getUpstreamTlsContext() {
checkState(tlsContextHolder instanceof UpstreamTlsContextHolder,
"expected UpstreamTlsContextHolder");
return ((UpstreamTlsContextHolder) tlsContextHolder).getUpstreamTlsContext();
}
/** Closes this provider and releases any resources. */
void close() {}
@ -106,7 +100,7 @@ public abstract class SslContextProvider<K> {
*/
public abstract void addCallback(Callback callback, Executor executor);
protected void performCallback(
final void performCallback(
final SslContextGetter sslContextGetter, final Callback callback, Executor executor) {
checkNotNull(sslContextGetter, "sslContextGetter");
checkNotNull(callback, "callback");

View File

@ -0,0 +1,29 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
/**
* A holder of {@link UpstreamTlsContext} or
* {@link io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext}.
*/
public interface TlsContextHolder {
CommonTlsContext getCommonTlsContext();
}

View File

@ -22,11 +22,11 @@ import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
public interface TlsContextManager {
/** Creates a SslContextProvider. Used for retrieving a server-side SslContext. */
SslContextProvider<DownstreamTlsContext> findOrCreateServerSslContextProvider(
SslContextProvider findOrCreateServerSslContextProvider(
DownstreamTlsContext downstreamTlsContext);
/** Creates a SslContextProvider. Used for retrieving a client-side SslContext. */
SslContextProvider<UpstreamTlsContext> findOrCreateClientSslContextProvider(
SslContextProvider findOrCreateClientSslContextProvider(
UpstreamTlsContext upstreamTlsContext);
/**
@ -38,8 +38,7 @@ public interface TlsContextManager {
* <p>Caller must not release a reference more than once. It's advised that you clear the
* reference to the instance with the null returned by this method.
*/
SslContextProvider<UpstreamTlsContext> releaseClientSslContextProvider(
SslContextProvider<UpstreamTlsContext> sslContextProvider);
SslContextProvider releaseClientSslContextProvider(SslContextProvider sslContextProvider);
/**
* Releases an instance of the given server-side {@link SslContextProvider}.
@ -50,6 +49,5 @@ public interface TlsContextManager {
* <p>Caller must not release a reference more than once. It's advised that you clear the
* reference to the instance with the null returned by this method.
*/
SslContextProvider<DownstreamTlsContext> releaseServerSslContextProvider(
SslContextProvider<DownstreamTlsContext> sslContextProvider);
SslContextProvider releaseServerSslContextProvider(SslContextProvider sslContextProvider);
}

View File

@ -59,30 +59,32 @@ public final class TlsContextManagerImpl implements TlsContextManager {
}
@Override
public SslContextProvider<DownstreamTlsContext> findOrCreateServerSslContextProvider(
public SslContextProvider findOrCreateServerSslContextProvider(
DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
return mapForServers.get(downstreamTlsContext);
}
@Override
public SslContextProvider<UpstreamTlsContext> findOrCreateClientSslContextProvider(
public SslContextProvider findOrCreateClientSslContextProvider(
UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
return mapForClients.get(upstreamTlsContext);
}
@Override
public SslContextProvider<UpstreamTlsContext> releaseClientSslContextProvider(
SslContextProvider<UpstreamTlsContext> sslContextProvider) {
checkNotNull(sslContextProvider, "sslContextProvider");
return mapForClients.release(sslContextProvider);
public SslContextProvider releaseClientSslContextProvider(
SslContextProvider clientSslContextProvider) {
checkNotNull(clientSslContextProvider, "clientSslContextProvider");
return mapForClients.release(clientSslContextProvider.getUpstreamTlsContext(),
clientSslContextProvider);
}
@Override
public SslContextProvider<DownstreamTlsContext> releaseServerSslContextProvider(
SslContextProvider<DownstreamTlsContext> sslContextProvider) {
checkNotNull(sslContextProvider, "sslContextProvider");
return mapForServers.release(sslContextProvider);
public SslContextProvider releaseServerSslContextProvider(
SslContextProvider serverSslContextProvider) {
checkNotNull(serverSslContextProvider, "serverSslContextProvider");
return mapForServers.release(serverSslContextProvider.getDownstreamTlsContext(),
serverSslContextProvider);
}
}

View File

@ -0,0 +1,40 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
final class UpstreamTlsContextHolder implements TlsContextHolder {
private final UpstreamTlsContext upstreamTlsContext;
UpstreamTlsContextHolder(UpstreamTlsContext upstreamTlsContext) {
this.upstreamTlsContext = checkNotNull(upstreamTlsContext, "upstreamTlsContext");
}
public UpstreamTlsContext getUpstreamTlsContext() {
return upstreamTlsContext;
}
@Override
public CommonTlsContext getCommonTlsContext() {
return upstreamTlsContext.getCommonTlsContext();
}
}

View File

@ -356,9 +356,8 @@ public class CdsLoadBalancerTest {
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> mockSslContextProvider =
(SslContextProvider<UpstreamTlsContext>) mock(SslContextProvider.class);
doReturn(upstreamTlsContext).when(mockSslContextProvider).getSource();
SslContextProvider mockSslContextProvider = mock(SslContextProvider.class);
doReturn(upstreamTlsContext).when(mockSslContextProvider).getUpstreamTlsContext();
doReturn(mockSslContextProvider).when(mockTlsContextManager)
.findOrCreateClientSslContextProvider(same(upstreamTlsContext));
@ -373,8 +372,8 @@ public class CdsLoadBalancerTest {
assertThat(edsLbHelpers).hasSize(1);
assertThat(edsLoadBalancers).hasSize(1);
verify(mockTlsContextManager, never()).releaseClientSslContextProvider(
(SslContextProvider<UpstreamTlsContext>) any(SslContextProvider.class));
verify(mockTlsContextManager, never())
.releaseClientSslContextProvider(any(SslContextProvider.class));
Helper edsLbHelper1 = edsLbHelpers.poll();
ArrayList<EquivalentAddressGroup> eagList = new ArrayList<>();
@ -403,8 +402,8 @@ public class CdsLoadBalancerTest {
.setUpstreamTlsContext(upstreamTlsContext)
.build());
verify(mockTlsContextManager, never()).releaseClientSslContextProvider(
(SslContextProvider<UpstreamTlsContext>) any(SslContextProvider.class));
verify(mockTlsContextManager, never())
.releaseClientSslContextProvider(any(SslContextProvider.class));
verify(mockTlsContextManager, never()).findOrCreateClientSslContextProvider(
any(UpstreamTlsContext.class));
@ -414,9 +413,8 @@ public class CdsLoadBalancerTest {
UpstreamTlsContext upstreamTlsContext1 =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> mockSslContextProvider1 =
(SslContextProvider<UpstreamTlsContext>) mock(SslContextProvider.class);
doReturn(upstreamTlsContext1).when(mockSslContextProvider1).getSource();
SslContextProvider mockSslContextProvider1 = mock(SslContextProvider.class);
doReturn(upstreamTlsContext1).when(mockSslContextProvider1).getUpstreamTlsContext();
doReturn(mockSslContextProvider1).when(mockTlsContextManager)
.findOrCreateClientSslContextProvider(same(upstreamTlsContext1));
clusterWatcher1.onClusterChanged(

View File

@ -41,7 +41,7 @@ public class ClientSslContextProviderFactoryTest {
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> sslContextProvider =
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext);
assertThat(sslContextProvider).isNotNull();
}
@ -55,7 +55,7 @@ public class ClientSslContextProviderFactoryTest {
SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext);
try {
SslContextProvider<UpstreamTlsContext> unused =
SslContextProvider unused =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
@ -77,7 +77,7 @@ public class ClientSslContextProviderFactoryTest {
SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext);
try {
SslContextProvider<UpstreamTlsContext> unused =
SslContextProvider unused =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {

View File

@ -51,60 +51,53 @@ public class ReferenceCountingSslContextProviderMapTest {
@Test
public void referenceCountingMap_getAndRelease_closeCalled() throws InterruptedException {
SslContextProvider<Integer> valueFor3 = getTypedMock();
SslContextProvider valueFor3 = getTypedMock();
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3);
SslContextProvider<Integer> val = map.get(3);
SslContextProvider val = map.get(3);
assertThat(val).isSameInstanceAs(valueFor3);
verify(valueFor3, never()).close();
val = map.get(3);
assertThat(val).isSameInstanceAs(valueFor3);
// at this point ref-count is 2
when(valueFor3.getSource()).thenReturn(3);
assertThat(map.release(val)).isNull();
assertThat(map.release(3, val)).isNull();
verify(valueFor3, never()).close();
assertThat(map.release(val)).isNull(); // after this ref-count is 0
assertThat(map.release(3, val)).isNull(); // after this ref-count is 0
verify(valueFor3, times(1)).close();
}
@SuppressWarnings("unchecked")
private static SslContextProvider<Integer> getTypedMock() {
private static SslContextProvider getTypedMock() {
return mock(SslContextProvider.class);
}
@Test
public void referenceCountingMap_distinctElements() throws InterruptedException {
SslContextProvider<Integer> valueFor3 = getTypedMock();
SslContextProvider<Integer> valueFor4 = getTypedMock();
when(valueFor3.getSource()).thenReturn(3);
when(valueFor4.getSource()).thenReturn(4);
SslContextProvider valueFor3 = getTypedMock();
SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3);
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
SslContextProvider<Integer> val3 = map.get(3);
SslContextProvider val3 = map.get(3);
assertThat(val3).isSameInstanceAs(valueFor3);
SslContextProvider<Integer> val4 = map.get(4);
SslContextProvider val4 = map.get(4);
assertThat(val4).isSameInstanceAs(valueFor4);
assertThat(map.release(val3)).isNull();
assertThat(map.release(3, val3)).isNull();
verify(valueFor3, times(1)).close();
verify(valueFor4, never()).close();
assertThat(map.release(val4)).isNull();
assertThat(map.release(4, val4)).isNull();
verify(valueFor4, times(1)).close();
}
@Test
public void referenceCountingMap_releaseWrongElement_expectException()
throws InterruptedException {
SslContextProvider<Integer> valueFor3 = getTypedMock();
SslContextProvider<Integer> valueFor4 = getTypedMock();
when(valueFor3.getSource()).thenReturn(3);
when(valueFor4.getSource()).thenReturn(4);
SslContextProvider valueFor3 = getTypedMock();
SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3);
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
SslContextProvider<Integer> unused = map.get(3);
SslContextProvider<Integer> val4 = map.get(4);
SslContextProvider unused = map.get(3);
SslContextProvider val4 = map.get(4);
// now provide wrong key (3) and value (val4) combination
when(valueFor4.getSource()).thenReturn(3);
try {
map.release(val4);
map.release(3, val4);
fail("exception expected");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().contains("Releasing the wrong instance");
@ -113,16 +106,15 @@ public class ReferenceCountingSslContextProviderMapTest {
@Test
public void referenceCountingMap_excessRelease_expectException() throws InterruptedException {
SslContextProvider<Integer> valueFor4 = getTypedMock();
when(valueFor4.getSource()).thenReturn(4);
SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
SslContextProvider<Integer> val = map.get(4);
SslContextProvider val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4);
// at this point ref-count is 1
map.release(val);
map.release(4, val);
// at this point ref-count is 0
try {
map.release(val);
map.release(4, val);
fail("exception expected");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().contains("No cached instance found for 4");
@ -131,16 +123,15 @@ public class ReferenceCountingSslContextProviderMapTest {
@Test
public void referenceCountingMap_releaseAndGet_differentInstance() throws InterruptedException {
SslContextProvider<Integer> valueFor4 = getTypedMock();
when(valueFor4.getSource()).thenReturn(4);
SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
SslContextProvider<Integer> val = map.get(4);
SslContextProvider val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4);
// at this point ref-count is 1
map.release(val);
map.release(4, val);
// at this point ref-count is 0 and val is removed
// should get another instance for 4
SslContextProvider<Integer> valueFor4a = getTypedMock();
SslContextProvider valueFor4a = getTypedMock();
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4a);
val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4a);

View File

@ -40,7 +40,7 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link SdsSslContextProvider}. */
/** Unit tests for {@link SdsClientSslContextProvider}. */
@RunWith(JUnit4.class)
public class SdsSslContextProviderTest {
@ -62,10 +62,13 @@ public class SdsSslContextProviderTest {
server.shutdown();
}
/** Helper method to build SdsSslContextProvider from given names. */
private SdsSslContextProvider<?> getSdsSslContextProvider(
boolean server, String certName, String validationContextName,
Iterable<String> verifySubjectAltNames, Iterable<String> alpnProtocols) throws IOException {
/** Helper method to build SdsClientSslContextProvider from given names. */
private SdsClientSslContextProvider getSdsClientSslContextProvider(
String certName,
String validationContextName,
Iterable<String> verifySubjectAltNames,
Iterable<String> alpnProtocols)
throws IOException {
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextWithAdditionalValues(
@ -77,18 +80,37 @@ public class SdsSslContextProviderTest {
alpnProtocols,
/* channelType= */ "inproc");
return server
? SdsSslContextProvider.getProviderForServer(
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false),
node,
MoreExecutors.directExecutor(),
MoreExecutors.directExecutor())
: SdsSslContextProvider.getProviderForClient(
SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext),
node,
MoreExecutors.directExecutor(),
MoreExecutors.directExecutor());
return SdsClientSslContextProvider.getProvider(
SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(commonTlsContext),
node,
MoreExecutors.directExecutor(),
MoreExecutors.directExecutor());
}
/** Helper method to build SdsServerSslContextProvider from given names. */
private SdsServerSslContextProvider getSdsServerSslContextProvider(
String certName,
String validationContextName,
Iterable<String> verifySubjectAltNames,
Iterable<String> alpnProtocols)
throws IOException {
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextWithAdditionalValues(
certName,
/* certTargetUri= */ "inproc",
validationContextName,
/* validationContextTargetUri= */ "inproc",
verifySubjectAltNames,
alpnProtocols,
/* channelType= */ "inproc");
return SdsServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false),
node,
MoreExecutors.directExecutor(),
MoreExecutors.directExecutor());
}
@Test
@ -98,8 +120,8 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(/* server= */ true, "cert1", "valid1", null, null);
SdsServerSslContextProvider provider =
getSdsServerSslContextProvider("cert1", "valid1", null, null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
@ -113,9 +135,8 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor("valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(
/* server= */ false,
SdsClientSslContextProvider provider =
getSdsClientSslContextProvider(
/* certName= */ "cert1",
/* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null,
@ -131,10 +152,12 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor(/* name= */ "cert1"))
.thenReturn(getOneTlsCertSecret(/* name= */ "cert1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE));
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(
/* server= */ true, /* certName= */ "cert1", /* validationContextName= */ null,
/* verifySubjectAltNames= */ null, /* alpnProtocols= */ null);
SdsServerSslContextProvider provider =
getSdsServerSslContextProvider(
/* certName= */ "cert1",
/* validationContextName= */ null,
/* verifySubjectAltNames= */ null,
/* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
@ -146,10 +169,12 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(
/* server= */ false, /* certName= */ null, /* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null, null);
SdsClientSslContextProvider provider =
getSdsClientSslContextProvider(
/* certName= */ null,
/* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null,
null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
@ -161,10 +186,12 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(
/* server= */ true, /* certName= */ null, /* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null, /* alpnProtocols= */ null);
SdsServerSslContextProvider provider =
getSdsServerSslContextProvider(
/* certName= */ null,
/* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null,
/* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
@ -184,13 +211,11 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor("valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(
/* server= */ false,
SdsClientSslContextProvider provider =
getSdsClientSslContextProvider(
/* certName= */ "cert1",
/* validationContextName= */ "valid1",
Arrays.asList(
"spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"),
Arrays.asList("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"),
/* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
@ -205,9 +230,8 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor("valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(
/* server= */ false,
SdsClientSslContextProvider provider =
getSdsClientSslContextProvider(
/* certName= */ "cert1",
/* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null,
@ -226,9 +250,8 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(
/* server= */ true,
SdsServerSslContextProvider provider =
getSdsServerSslContextProvider(
/* certName= */ "cert1",
/* validationContextName= */ "valid1",
/* verifySubjectAltNames= */ null,

View File

@ -41,7 +41,7 @@ import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link SecretVolumeSslContextProvider}. */
/** Unit tests for {@link SecretVolumeClientSslContextProvider}. */
@RunWith(JUnit4.class)
public class SecretVolumeSslContextProviderTest {
@ -51,7 +51,7 @@ public class SecretVolumeSslContextProviderTest {
public void validateCertificateContext_nullAndNotOptional_throwsException() {
// expect exception when certContext is null and not optional
try {
SecretVolumeSslContextProvider.validateCertificateContext(
CommonTlsContextUtil.validateCertificateContext(
/* certContext= */ null, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
@ -64,8 +64,7 @@ public class SecretVolumeSslContextProviderTest {
// expect exception when certContext has no CA and not optional
CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance();
try {
SecretVolumeSslContextProvider.validateCertificateContext(
certContext, /* optional= */ false);
CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("certContext is required");
@ -76,7 +75,7 @@ public class SecretVolumeSslContextProviderTest {
public void validateCertificateContext_nullAndOptional() {
// certContext argument can be null when optional
CertificateValidationContext certContext =
SecretVolumeSslContextProvider.validateCertificateContext(
CommonTlsContextUtil.validateCertificateContext(
/* certContext= */ null, /* optional= */ true);
assertThat(certContext).isNull();
}
@ -85,9 +84,7 @@ public class SecretVolumeSslContextProviderTest {
public void validateCertificateContext_missingTrustCaOptional() {
// certContext argument can have missing CA when optional
CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance();
assertThat(
SecretVolumeSslContextProvider.validateCertificateContext(
certContext, /* optional= */ true))
assertThat(CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ true))
.isNull();
}
@ -99,8 +96,7 @@ public class SecretVolumeSslContextProviderTest {
.setTrustedCa(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
SecretVolumeSslContextProvider.validateCertificateContext(
certContext, /* optional= */ false);
CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -114,9 +110,7 @@ public class SecretVolumeSslContextProviderTest {
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename("bar"))
.build();
assertThat(
SecretVolumeSslContextProvider.validateCertificateContext(
certContext, /* optional= */ false))
assertThat(CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false))
.isSameInstanceAs(certContext);
}
@ -124,7 +118,7 @@ public class SecretVolumeSslContextProviderTest {
public void validateTlsCertificate_nullAndNotOptional_throwsException() {
// expect exception when tlsCertificate is null and not optional
try {
SecretVolumeSslContextProvider.validateTlsCertificate(
CommonTlsContextUtil.validateTlsCertificate(
/* tlsCertificate= */ null, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
@ -135,7 +129,7 @@ public class SecretVolumeSslContextProviderTest {
@Test
public void validateTlsCertificate_nullOptional() {
assertThat(
SecretVolumeSslContextProvider.validateTlsCertificate(
CommonTlsContextUtil.validateTlsCertificate(
/* tlsCertificate= */ null, /* optional= */ true))
.isNull();
}
@ -144,10 +138,7 @@ public class SecretVolumeSslContextProviderTest {
public void validateTlsCertificate_defaultInstance_returnsNull() {
// tlsCertificate is not null but has no value (default instance): expect null
TlsCertificate tlsCert = TlsCertificate.getDefaultInstance();
assertThat(
SecretVolumeSslContextProvider.validateTlsCertificate(
tlsCert, /* optional= */ true))
.isNull();
assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true)).isNull();
}
@Test
@ -158,7 +149,7 @@ public class SecretVolumeSslContextProviderTest {
.setPrivateKey(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ false);
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -173,7 +164,7 @@ public class SecretVolumeSslContextProviderTest {
.setPrivateKey(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true);
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -188,7 +179,7 @@ public class SecretVolumeSslContextProviderTest {
.setCertificateChain(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ false);
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -203,7 +194,7 @@ public class SecretVolumeSslContextProviderTest {
.setCertificateChain(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true);
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -217,9 +208,7 @@ public class SecretVolumeSslContextProviderTest {
.setCertificateChain(DataSource.newBuilder().setFilename("foo"))
.setPrivateKey(DataSource.newBuilder().setFilename("bar"))
.build();
assertThat(
SecretVolumeSslContextProvider.validateTlsCertificate(
tlsCert, /* optional= */ true))
assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true))
.isSameInstanceAs(tlsCert);
}
@ -230,9 +219,7 @@ public class SecretVolumeSslContextProviderTest {
.setCertificateChain(DataSource.newBuilder().setFilename("foo"))
.setPrivateKey(DataSource.newBuilder().setFilename("bar"))
.build();
assertThat(
SecretVolumeSslContextProvider.validateTlsCertificate(
tlsCert, /* optional= */ false))
assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false))
.isSameInstanceAs(tlsCert);
}
@ -245,7 +232,7 @@ public class SecretVolumeSslContextProviderTest {
.setPrivateKey(DataSource.newBuilder().setFilename("bar"))
.build();
try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true);
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -261,7 +248,7 @@ public class SecretVolumeSslContextProviderTest {
.setCertificateChain(DataSource.newBuilder().setFilename("bar"))
.build();
try {
SecretVolumeSslContextProvider.validateTlsCertificate(tlsCert, /* optional= */ true);
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -272,7 +259,7 @@ public class SecretVolumeSslContextProviderTest {
public void getProviderForServer_defaultTlsCertificate_throwsException() {
TlsCertificate tlsCert = TlsCertificate.getDefaultInstance();
try {
SecretVolumeSslContextProvider.getProviderForServer(
SecretVolumeServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null),
/* requireClientCert= */ false));
@ -294,7 +281,7 @@ public class SecretVolumeSslContextProviderTest {
.setTrustedCa(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
SecretVolumeSslContextProvider.getProviderForServer(
SecretVolumeServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext),
/* requireClientCert= */ false));
@ -308,7 +295,7 @@ public class SecretVolumeSslContextProviderTest {
public void getProviderForClient_defaultCertContext_throwsException() {
CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance();
try {
SecretVolumeSslContextProvider.getProviderForClient(
SecretVolumeClientSslContextProvider.getProvider(
buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(
/* tlsCertificate= */ null, certContext)));
@ -330,7 +317,7 @@ public class SecretVolumeSslContextProviderTest {
.setTrustedCa(DataSource.newBuilder().setFilename("foo"))
.build();
try {
SecretVolumeSslContextProvider.getProviderForClient(
SecretVolumeClientSslContextProvider.getProvider(
buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext)));
Assert.fail("no exception thrown");
@ -351,7 +338,7 @@ public class SecretVolumeSslContextProviderTest {
.setTrustedCa(DataSource.newBuilder().setFilename("foo"))
.build();
try {
SecretVolumeSslContextProvider.getProviderForClient(
SecretVolumeClientSslContextProvider.getProvider(
buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext)));
Assert.fail("no exception thrown");
@ -360,22 +347,6 @@ public class SecretVolumeSslContextProviderTest {
}
}
/** Helper method to build SecretVolumeSslContextProvider from given files. */
private static SecretVolumeSslContextProvider<?> getSslContextSecretVolumeSecretProvider(
boolean server,
String certChainFilename,
String privateKeyFilename,
String trustedCaFilename) {
return server
? SecretVolumeSslContextProvider.getProviderForServer(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
privateKeyFilename, certChainFilename, trustedCaFilename))
: SecretVolumeSslContextProvider.getProviderForClient(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
privateKeyFilename, certChainFilename, trustedCaFilename));
}
/**
* Helper method to build SecretVolumeSslContextProvider, call buildSslContext on it and
* check returned SslContext.
@ -383,10 +354,22 @@ public class SecretVolumeSslContextProviderTest {
private static void sslContextForEitherWithBothCertAndTrust(
boolean server, String pemFile, String keyFile, String caFile)
throws IOException, CertificateException, CertStoreException {
SecretVolumeSslContextProvider<?> provider =
getSslContextSecretVolumeSecretProvider(server, pemFile, keyFile, caFile);
SslContext sslContext = null;
if (server) {
SecretVolumeServerSslContextProvider provider =
SecretVolumeServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
keyFile, pemFile, caFile));
SslContext sslContext = provider.buildSslContextFromSecrets();
sslContext = provider.buildSslContextFromSecrets();
} else {
SecretVolumeClientSslContextProvider provider =
SecretVolumeClientSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
keyFile, pemFile, caFile));
sslContext = provider.buildSslContextFromSecrets();
}
doChecksOnSslContext(server, sslContext, /* expectedApnProtos= */ null);
}
@ -469,7 +452,7 @@ public class SecretVolumeSslContextProviderTest {
* Helper method to get the value thru directExecutor callback. Because of directExecutor this is
* a synchronous callback - so need to provide a listener.
*/
static TestCallback getValueThruCallback(SslContextProvider<?> provider) {
static TestCallback getValueThruCallback(SslContextProvider provider) {
TestCallback testCallback = new TestCallback();
provider.addCallback(testCallback, MoreExecutors.directExecutor());
return testCallback;
@ -477,9 +460,10 @@ public class SecretVolumeSslContextProviderTest {
@Test
public void getProviderForServer_both_callsback() throws IOException {
SecretVolumeSslContextProvider<?> provider =
getSslContextSecretVolumeSecretProvider(
true, SERVER_1_PEM_FILE, SERVER_1_KEY_FILE, CA_PEM_FILE);
SecretVolumeServerSslContextProvider provider =
SecretVolumeServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE));
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
@ -487,9 +471,10 @@ public class SecretVolumeSslContextProviderTest {
@Test
public void getProviderForClient_both_callsback() throws IOException {
SecretVolumeSslContextProvider<?> provider =
getSslContextSecretVolumeSecretProvider(
false, CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE);
SecretVolumeClientSslContextProvider provider =
SecretVolumeClientSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE));
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
@ -498,9 +483,10 @@ public class SecretVolumeSslContextProviderTest {
// note this test generates stack-trace but can be safely ignored
@Test
public void getProviderForClient_both_callsback_setException() throws IOException {
SecretVolumeSslContextProvider<?> provider =
getSslContextSecretVolumeSecretProvider(
false, CLIENT_PEM_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SecretVolumeClientSslContextProvider provider =
SecretVolumeClientSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_PEM_FILE, CLIENT_PEM_FILE, CA_PEM_FILE));
TestCallback testCallback = getValueThruCallback(provider);
assertThat(testCallback.updatedSslContext).isNull();
assertThat(testCallback.updatedThrowable).isInstanceOf(IllegalArgumentException.class);

View File

@ -41,7 +41,7 @@ public class ServerSslContextProviderFactoryTest {
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
SslContextProvider<DownstreamTlsContext> sslContextProvider =
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext);
assertThat(sslContextProvider).isNotNull();
}
@ -56,7 +56,7 @@ public class ServerSslContextProviderFactoryTest {
commonTlsContext, /* requireClientCert= */ false);
try {
SslContextProvider<DownstreamTlsContext> unused =
SslContextProvider unused =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
@ -76,7 +76,7 @@ public class ServerSslContextProviderFactoryTest {
commonTlsContext, /* requireClientCert= */ false);
try {
SslContextProvider<DownstreamTlsContext> unused =
SslContextProvider unused =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {

View File

@ -49,11 +49,9 @@ public class TlsContextManagerTest {
@Rule public final MockitoRule mockitoRule = MockitoJUnit.rule();
@Mock
SslContextProviderFactory<UpstreamTlsContext> mockClientFactory;
@Mock SslContextProviderFactory<UpstreamTlsContext> mockClientFactory;
@Mock
SslContextProviderFactory<DownstreamTlsContext> mockServerFactory;
@Mock SslContextProviderFactory<DownstreamTlsContext> mockServerFactory;
@Before
public void clearInstance() throws NoSuchFieldException, IllegalAccessException {
@ -69,11 +67,11 @@ public class TlsContextManagerTest {
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
SslContextProvider<DownstreamTlsContext> serverSecretProvider =
SslContextProvider serverSecretProvider =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider).isNotNull();
SslContextProvider<DownstreamTlsContext> serverSecretProvider1 =
SslContextProvider serverSecretProvider1 =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider1).isSameInstanceAs(serverSecretProvider);
}
@ -85,11 +83,11 @@ public class TlsContextManagerTest {
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
SslContextProvider<UpstreamTlsContext> clientSecretProvider =
SslContextProvider clientSecretProvider =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider).isNotNull();
SslContextProvider<UpstreamTlsContext> clientSecretProvider1 =
SslContextProvider clientSecretProvider1 =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider1).isSameInstanceAs(clientSecretProvider);
}
@ -101,14 +99,14 @@ public class TlsContextManagerTest {
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
SslContextProvider<DownstreamTlsContext> serverSecretProvider =
SslContextProvider serverSecretProvider =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider).isNotNull();
DownstreamTlsContext downstreamTlsContext1 =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_0_KEY_FILE, SERVER_0_PEM_FILE, CA_PEM_FILE);
SslContextProvider<DownstreamTlsContext> serverSecretProvider1 =
SslContextProvider serverSecretProvider1 =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext1);
assertThat(serverSecretProvider1).isNotNull();
assertThat(serverSecretProvider1).isNotSameInstanceAs(serverSecretProvider);
@ -121,7 +119,7 @@ public class TlsContextManagerTest {
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
SslContextProvider<UpstreamTlsContext> clientSecretProvider =
SslContextProvider clientSecretProvider =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider).isNotNull();
@ -129,7 +127,7 @@ public class TlsContextManagerTest {
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> clientSecretProvider1 =
SslContextProvider clientSecretProvider1 =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1);
assertThat(clientSecretProvider1).isNotSameInstanceAs(clientSecretProvider);
}
@ -143,13 +141,13 @@ public class TlsContextManagerTest {
TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory);
@SuppressWarnings("unchecked")
SslContextProvider<DownstreamTlsContext> mockProvider = mock(SslContextProvider.class);
SslContextProvider mockProvider = mock(SslContextProvider.class);
when(mockServerFactory.createSslContextProvider(downstreamTlsContext)).thenReturn(mockProvider);
SslContextProvider<DownstreamTlsContext> serverSecretProvider =
SslContextProvider serverSecretProvider =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider).isSameInstanceAs(mockProvider);
verify(mockProvider, never()).close();
when(mockProvider.getSource()).thenReturn(downstreamTlsContext);
when(mockProvider.getDownstreamTlsContext()).thenReturn(downstreamTlsContext);
tlsContextManagerImpl.releaseServerSslContextProvider(mockProvider);
verify(mockProvider, times(1)).close();
}
@ -163,13 +161,13 @@ public class TlsContextManagerTest {
TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory);
@SuppressWarnings("unchecked")
SslContextProvider<UpstreamTlsContext> mockProvider = mock(SslContextProvider.class);
SslContextProvider mockProvider = mock(SslContextProvider.class);
when(mockClientFactory.createSslContextProvider(upstreamTlsContext)).thenReturn(mockProvider);
SslContextProvider<UpstreamTlsContext> clientSecretProvider =
SslContextProvider clientSecretProvider =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider).isSameInstanceAs(mockProvider);
verify(mockProvider, never()).close();
when(mockProvider.getSource()).thenReturn(upstreamTlsContext);
when(mockProvider.getUpstreamTlsContext()).thenReturn(upstreamTlsContext);
tlsContextManagerImpl.releaseClientSslContextProvider(mockProvider);
verify(mockProvider, times(1)).close();
}