xds: Add CertProviderSslContextProviders to Client&Server SslContextProviderFactories (#7338)

This commit is contained in:
sanjaypujare 2020-08-21 14:08:39 -07:00 committed by GitHub
parent 80480e69ef
commit e6ab167334
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 525 additions and 51 deletions

View File

@ -72,7 +72,7 @@ public final class EnvoyServerProtoData {
public static final class UpstreamTlsContext extends BaseTlsContext { public static final class UpstreamTlsContext extends BaseTlsContext {
@VisibleForTesting @VisibleForTesting
UpstreamTlsContext(CommonTlsContext commonTlsContext) { public UpstreamTlsContext(CommonTlsContext commonTlsContext) {
super(commonTlsContext); super(commonTlsContext);
} }
@ -93,7 +93,8 @@ public final class EnvoyServerProtoData {
private final boolean requireClientCertificate; private final boolean requireClientCertificate;
@VisibleForTesting @VisibleForTesting
DownstreamTlsContext(CommonTlsContext commonTlsContext, boolean requireClientCertificate) { public DownstreamTlsContext(
CommonTlsContext commonTlsContext, boolean requireClientCertificate) {
super(commonTlsContext); super(commonTlsContext);
this.requireClientCertificate = requireClientCertificate; this.requireClientCertificate = requireClientCertificate;
} }

View File

@ -23,6 +23,7 @@ import io.envoyproxy.envoy.config.core.v3.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
import io.grpc.Internal;
import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.Bootstrapper.CertificateProviderInfo;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
@ -33,7 +34,8 @@ import java.security.cert.X509Certificate;
import java.util.Map; import java.util.Map;
/** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */
final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { @Internal
public final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider {
private CertProviderClientSslContextProvider( private CertProviderClientSslContextProvider(
Node node, Node node,
@ -70,20 +72,22 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP
} }
/** Creates CertProviderClientSslContextProvider. */ /** Creates CertProviderClientSslContextProvider. */
static final class Factory { @Internal
public static final class Factory {
private static final Factory DEFAULT_INSTANCE = private static final Factory DEFAULT_INSTANCE =
new Factory(CertificateProviderStore.getInstance()); new Factory(CertificateProviderStore.getInstance());
private final CertificateProviderStore certificateProviderStore; private final CertificateProviderStore certificateProviderStore;
@VisibleForTesting Factory(CertificateProviderStore certificateProviderStore) { @VisibleForTesting public Factory(CertificateProviderStore certificateProviderStore) {
this.certificateProviderStore = certificateProviderStore; this.certificateProviderStore = certificateProviderStore;
} }
static Factory getInstance() { public static Factory getInstance() {
return DEFAULT_INSTANCE; return DEFAULT_INSTANCE;
} }
CertProviderClientSslContextProvider getProvider( /** Creates a {@link CertProviderClientSslContextProvider}. */
public CertProviderClientSslContextProvider getProvider(
UpstreamTlsContext upstreamTlsContext, UpstreamTlsContext upstreamTlsContext,
Node node, Node node,
Map<String, CertificateProviderInfo> certProviders) { Map<String, CertificateProviderInfo> certProviders) {

View File

@ -23,6 +23,7 @@ import io.envoyproxy.envoy.config.core.v3.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
import io.grpc.Internal;
import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.Bootstrapper.CertificateProviderInfo;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
@ -36,7 +37,8 @@ import java.security.cert.X509Certificate;
import java.util.Map; import java.util.Map;
/** A server SslContext provider using CertificateProviderInstance to fetch secrets. */ /** A server SslContext provider using CertificateProviderInstance to fetch secrets. */
final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider { @Internal
public final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider {
private CertProviderServerSslContextProvider( private CertProviderServerSslContextProvider(
Node node, Node node,
@ -73,20 +75,22 @@ final class CertProviderServerSslContextProvider extends CertProviderSslContextP
} }
/** Creates CertProviderServerSslContextProvider. */ /** Creates CertProviderServerSslContextProvider. */
static final class Factory { @Internal
public static final class Factory {
private static final Factory DEFAULT_INSTANCE = private static final Factory DEFAULT_INSTANCE =
new Factory(CertificateProviderStore.getInstance()); new Factory(CertificateProviderStore.getInstance());
private final CertificateProviderStore certificateProviderStore; private final CertificateProviderStore certificateProviderStore;
@VisibleForTesting Factory(CertificateProviderStore certificateProviderStore) { @VisibleForTesting public Factory(CertificateProviderStore certificateProviderStore) {
this.certificateProviderStore = certificateProviderStore; this.certificateProviderStore = certificateProviderStore;
} }
static Factory getInstance() { public static Factory getInstance() {
return DEFAULT_INSTANCE; return DEFAULT_INSTANCE;
} }
CertProviderServerSslContextProvider getProvider( /** Creates a {@link CertProviderServerSslContextProvider}. */
public CertProviderServerSslContextProvider getProvider(
DownstreamTlsContext downstreamTlsContext, DownstreamTlsContext downstreamTlsContext,
Node node, Node node,
Map<String, CertificateProviderInfo> certProviders) { Map<String, CertificateProviderInfo> certProviders) {

View File

@ -23,6 +23,7 @@ import io.grpc.Status;
import io.grpc.xds.internal.sds.Closeable; import io.grpc.xds.internal.sds.Closeable;
import java.security.PrivateKey; import java.security.PrivateKey;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -48,7 +49,7 @@ public abstract class CertificateProvider implements Closeable {
} }
@VisibleForTesting @VisibleForTesting
static final class DistributorWatcher implements Watcher { public static final class DistributorWatcher implements Watcher {
private PrivateKey privateKey; private PrivateKey privateKey;
private List<X509Certificate> certChain; private List<X509Certificate> certChain;
private List<X509Certificate> trustedRoots; private List<X509Certificate> trustedRoots;
@ -70,6 +71,10 @@ public abstract class CertificateProvider implements Closeable {
downstreamWatchers.remove(watcher); downstreamWatchers.remove(watcher);
} }
@VisibleForTesting public Set<Watcher> getDownstreamWatchers() {
return Collections.unmodifiableSet(downstreamWatchers);
}
private void sendLastCertificateUpdate(Watcher watcher) { private void sendLastCertificateUpdate(Watcher watcher) {
watcher.updateCertificate(privateKey, certChain); watcher.updateCertificate(privateKey, certChain);
} }

View File

@ -16,13 +16,15 @@
package io.grpc.xds.internal.certprovider; package io.grpc.xds.internal.certprovider;
import io.grpc.Internal;
import io.grpc.xds.internal.certprovider.CertificateProvider.Watcher; import io.grpc.xds.internal.certprovider.CertificateProvider.Watcher;
/** /**
* Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may * Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may
* move this out of the internal package and make this an official API in the future. * move this out of the internal package and make this an official API in the future.
*/ */
interface CertificateProviderProvider { @Internal
public interface CertificateProviderProvider {
/** Returns the unique name of the {@link CertificateProvider} plugin. */ /** Returns the unique name of the {@link CertificateProvider} plugin. */
String getName(); String getName();

View File

@ -31,7 +31,7 @@ public final class CertificateProviderRegistry {
new LinkedHashMap<>(); new LinkedHashMap<>();
@VisibleForTesting @VisibleForTesting
CertificateProviderRegistry() { public CertificateProviderRegistry() {
} }
/** Returns the singleton registry. */ /** Returns the singleton registry. */

View File

@ -139,7 +139,7 @@ public final class CertificateProviderStore {
} }
@VisibleForTesting @VisibleForTesting
CertificateProviderStore(CertificateProviderRegistry certificateProviderRegistry) { public CertificateProviderStore(CertificateProviderRegistry certificateProviderRegistry) {
this.certificateProviderRegistry = certificateProviderRegistry; this.certificateProviderRegistry = certificateProviderRegistry;
certProviderMap = new ReferenceCountingMap<>(new CertProviderFactory()); certProviderMap = new ReferenceCountingMap<>(new CertProviderFactory());
} }

View File

@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.xds.Bootstrapper; import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProvider;
import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
@ -29,6 +30,20 @@ import java.util.concurrent.Executors;
final class ClientSslContextProviderFactory final class ClientSslContextProviderFactory
implements ValueFactory<UpstreamTlsContext, SslContextProvider> { implements ValueFactory<UpstreamTlsContext, SslContextProvider> {
private final Bootstrapper bootstrapper;
private final CertProviderClientSslContextProvider.Factory
certProviderClientSslContextProviderFactory;
ClientSslContextProviderFactory() {
this(Bootstrapper.getInstance(), CertProviderClientSslContextProvider.Factory.getInstance());
}
ClientSslContextProviderFactory(
Bootstrapper bootstrapper, CertProviderClientSslContextProvider.Factory factory) {
this.bootstrapper = bootstrapper;
this.certProviderClientSslContextProviderFactory = factory;
}
/** Creates an SslContextProvider from the given UpstreamTlsContext. */ /** Creates an SslContextProvider from the given UpstreamTlsContext. */
@Override @Override
public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) { public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) {
@ -52,8 +67,17 @@ final class ClientSslContextProviderFactory
} catch (IOException ioe) { } catch (IOException ioe) {
throw new RuntimeException(ioe); throw new RuntimeException(ioe);
} }
} else if (CommonTlsContextUtil.hasCertProviderInstance(
upstreamTlsContext.getCommonTlsContext())) {
try {
return certProviderClientSslContextProviderFactory.getProvider(
upstreamTlsContext,
bootstrapper.readBootstrap().getNode().toEnvoyProtoNode(),
bootstrapper.readBootstrap().getCertProviders());
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
} }
throw new UnsupportedOperationException( throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!");
"UpstreamTlsContext to have all filenames or all SdsConfig");
} }
} }

View File

@ -23,6 +23,7 @@ import static com.google.common.base.Preconditions.checkState;
import io.envoyproxy.envoy.config.core.v3.DataSource.SpecifierCase; import io.envoyproxy.envoy.config.core.v3.DataSource.SpecifierCase;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.ValidationContextTypeCase; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.ValidationContextTypeCase;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -34,18 +35,31 @@ final class CommonTlsContextUtil {
/** Returns true only if given CommonTlsContext uses no SdsSecretConfigs. */ /** Returns true only if given CommonTlsContext uses no SdsSecretConfigs. */
static boolean hasAllSecretsUsingFilename(CommonTlsContext commonTlsContext) { static boolean hasAllSecretsUsingFilename(CommonTlsContext commonTlsContext) {
checkNotNull(commonTlsContext, "commonTlsContext"); return commonTlsContext != null
// return true if it has no SdsSecretConfig(s) && (commonTlsContext.getTlsCertificatesCount() > 0
return (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() == 0) || commonTlsContext.hasValidationContext());
&& !commonTlsContext.hasValidationContextSdsSecretConfig();
} }
/** Returns true only if given CommonTlsContext uses only SdsSecretConfigs. */ /** Returns true only if given CommonTlsContext uses only SdsSecretConfigs. */
static boolean hasAllSecretsUsingSds(CommonTlsContext commonTlsContext) { static boolean hasAllSecretsUsingSds(CommonTlsContext commonTlsContext) {
checkNotNull(commonTlsContext, "commonTlsContext"); return commonTlsContext != null
// return true if it has only SdsSecretConfig(s) && (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0
return (commonTlsContext.getTlsCertificatesCount() == 0) || commonTlsContext.hasValidationContextSdsSecretConfig());
&& !commonTlsContext.hasValidationContext(); }
static boolean hasCertProviderInstance(CommonTlsContext commonTlsContext) {
return commonTlsContext != null
&& (commonTlsContext.hasTlsCertificateCertificateProviderInstance()
|| hasCertProviderValidationContext(commonTlsContext));
}
private static boolean hasCertProviderValidationContext(CommonTlsContext commonTlsContext) {
if (commonTlsContext.hasCombinedValidationContext()) {
CombinedCertificateValidationContext combinedCertificateValidationContext =
commonTlsContext.getCombinedValidationContext();
return combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance();
}
return commonTlsContext.hasValidationContextCertificateProviderInstance();
} }
@Nullable @Nullable

View File

@ -16,6 +16,7 @@
package io.grpc.xds.internal.sds; package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext; import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext; import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext;
@ -60,6 +61,9 @@ final class SecretVolumeClientSslContextProvider extends SslContextProvider {
static SecretVolumeClientSslContextProvider getProvider(UpstreamTlsContext upstreamTlsContext) { static SecretVolumeClientSslContextProvider getProvider(UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext"); checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
checkArgument(
commonTlsContext.getTlsCertificateSdsSecretConfigsCount() == 0,
"unexpected TlsCertificateSdsSecretConfigs");
CertificateValidationContext certificateValidationContext = CertificateValidationContext certificateValidationContext =
getCertificateValidationContext(commonTlsContext); getCertificateValidationContext(commonTlsContext);
// first validate // first validate

View File

@ -16,6 +16,7 @@
package io.grpc.xds.internal.sds; package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext; import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext; import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext;
@ -61,6 +62,9 @@ final class SecretVolumeServerSslContextProvider extends SslContextProvider {
DownstreamTlsContext downstreamTlsContext) { DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext"); checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext();
checkArgument(
commonTlsContext.getTlsCertificateSdsSecretConfigsCount() == 0,
"unexpected TlsCertificateSdsSecretConfigs");
TlsCertificate tlsCertificate = null; TlsCertificate tlsCertificate = null;
if (commonTlsContext.getTlsCertificatesCount() > 0) { if (commonTlsContext.getTlsCertificatesCount() > 0) {
tlsCertificate = commonTlsContext.getTlsCertificates(0); tlsCertificate = commonTlsContext.getTlsCertificates(0);

View File

@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.xds.Bootstrapper; import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider;
import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
@ -29,6 +30,20 @@ import java.util.concurrent.Executors;
final class ServerSslContextProviderFactory final class ServerSslContextProviderFactory
implements ValueFactory<DownstreamTlsContext, SslContextProvider> { implements ValueFactory<DownstreamTlsContext, SslContextProvider> {
private final Bootstrapper bootstrapper;
private final CertProviderServerSslContextProvider.Factory
certProviderServerSslContextProviderFactory;
ServerSslContextProviderFactory() {
this(Bootstrapper.getInstance(), CertProviderServerSslContextProvider.Factory.getInstance());
}
ServerSslContextProviderFactory(
Bootstrapper bootstrapper, CertProviderServerSslContextProvider.Factory factory) {
this.bootstrapper = bootstrapper;
this.certProviderServerSslContextProviderFactory = factory;
}
/** Creates a SslContextProvider from the given DownstreamTlsContext. */ /** Creates a SslContextProvider from the given DownstreamTlsContext. */
@Override @Override
public SslContextProvider create( public SslContextProvider create(
@ -54,8 +69,17 @@ final class ServerSslContextProviderFactory
} catch (IOException ioe) { } catch (IOException ioe) {
throw new RuntimeException(ioe); throw new RuntimeException(ioe);
} }
} else if (CommonTlsContextUtil.hasCertProviderInstance(
downstreamTlsContext.getCommonTlsContext())) {
try {
return certProviderServerSslContextProviderFactory.getProvider(
downstreamTlsContext,
bootstrapper.readBootstrap().getNode().toEnvoyProtoNode(),
bootstrapper.readBootstrap().getCertProviders());
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
} }
throw new UnsupportedOperationException( throw new UnsupportedOperationException("Unsupported configurations in DownstreamTlsContext!");
"DownstreamTlsContext to have all filenames or all SdsConfig");
} }
} }

View File

@ -54,7 +54,8 @@ public class CommonCertProviderTestUtils {
"-+END\\s+.*PRIVATE\\s+KEY[^-]*-+", // Footer "-+END\\s+.*PRIVATE\\s+KEY[^-]*-+", // Footer
Pattern.CASE_INSENSITIVE); Pattern.CASE_INSENSITIVE);
static Bootstrapper.BootstrapInfo getTestBootstrapInfo() throws IOException { /** Creates a test bootstrap info object. */
public static Bootstrapper.BootstrapInfo getTestBootstrapInfo() throws IOException {
String rawData = String rawData =
"{\n" "{\n"
+ " \"xds_servers\": [],\n" + " \"xds_servers\": [],\n"

View File

@ -22,12 +22,13 @@ public class TestCertificateProvider extends CertificateProvider {
int closeCalled = 0; int closeCalled = 0;
int startCalled = 0; int startCalled = 0;
TestCertificateProvider( /** Creates a TestCertificateProvider instance. */
DistributorWatcher watcher, public TestCertificateProvider(
boolean notifyCertUpdates, DistributorWatcher watcher,
Object config, boolean notifyCertUpdates,
CertificateProviderProvider certificateProviderProvider, Object config,
boolean throwExceptionForCertUpdates) { CertificateProviderProvider certificateProviderProvider,
boolean throwExceptionForCertUpdates) {
super(watcher, notifyCertUpdates); super(watcher, notifyCertUpdates);
if (throwExceptionForCertUpdates && notifyCertUpdates) { if (throwExceptionForCertUpdates && notifyCertUpdates) {
throw new UnsupportedOperationException("Provider does not support Certificate Updates."); throw new UnsupportedOperationException("Provider does not support Certificate Updates.");

View File

@ -20,20 +20,54 @@ import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProvider;
import io.grpc.xds.internal.certprovider.CertificateProvider;
import io.grpc.xds.internal.certprovider.CertificateProviderProvider;
import io.grpc.xds.internal.certprovider.CertificateProviderRegistry;
import io.grpc.xds.internal.certprovider.CertificateProviderStore;
import io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils;
import io.grpc.xds.internal.certprovider.TestCertificateProvider;
import java.io.IOException;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** Unit tests for {@link ClientSslContextProviderFactory}. */ /** Unit tests for {@link ClientSslContextProviderFactory}. */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class ClientSslContextProviderFactoryTest { public class ClientSslContextProviderFactoryTest {
ClientSslContextProviderFactory clientSslContextProviderFactory = Bootstrapper bootstrapper;
new ClientSslContextProviderFactory(); CertificateProviderRegistry certificateProviderRegistry;
CertificateProviderStore certificateProviderStore;
CertProviderClientSslContextProvider.Factory certProviderClientSslContextProviderFactory;
ClientSslContextProviderFactory clientSslContextProviderFactory;
@Before
public void setUp() {
bootstrapper = mock(Bootstrapper.class);
certificateProviderRegistry = new CertificateProviderRegistry();
certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry);
certProviderClientSslContextProviderFactory =
new CertProviderClientSslContextProvider.Factory(certificateProviderStore);
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapper, certProviderClientSslContextProviderFactory);
}
@Test @Test
public void createSslContextProvider_allFilenames() { public void createSslContextProvider_allFilenames() {
@ -55,13 +89,10 @@ public class ClientSslContextProviderFactoryTest {
CommonTlsContextTestsUtil.buildUpstreamTlsContext(commonTlsContext); CommonTlsContextTestsUtil.buildUpstreamTlsContext(commonTlsContext);
try { try {
SslContextProvider unused = clientSslContextProviderFactory.create(upstreamTlsContext);
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected) assertThat(expected).hasMessageThat().isEqualTo("unexpected TlsCertificateSdsSecretConfigs");
.hasMessageThat()
.isEqualTo("UpstreamTlsContext to have all filenames or all SdsConfig");
} }
} }
@ -80,10 +111,188 @@ public class ClientSslContextProviderFactoryTest {
SslContextProvider unused = SslContextProvider unused =
clientSslContextProviderFactory.create(upstreamTlsContext); clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalStateException expected) {
assertThat(expected).hasMessageThat().isEqualTo("incorrect ValidationContextTypeCase");
}
}
@Test
public void createCertProviderClientSslContextProvider() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderClientSslContextProvider_onlyRootCert() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
/* certInstanceName= */ null,
/* certName= */ null,
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderClientSslContextProvider_withStaticContext() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
/* certInstanceName= */ null,
/* certName= */ null,
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderClientSslContextProvider_2providers() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[2];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "file_watcher", 1);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"file_provider",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}
@Test
public void createCertProviderClientSslContextProvider_ioException() throws IOException {
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
when(bootstrapper.readBootstrap()).thenThrow(new IOException("test IOException"));
try {
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (RuntimeException expected) {
assertThat(expected).hasMessageThat().isEqualTo("java.io.IOException: test IOException");
}
}
@Test
public void createEmptyCommonTlsContext_exception() throws IOException {
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(null, null, null);
try {
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (UnsupportedOperationException expected) {
assertThat(expected) assertThat(expected)
.hasMessageThat() .hasMessageThat()
.isEqualTo("UpstreamTlsContext to have all filenames or all SdsConfig"); .isEqualTo("Unsupported configurations in UpstreamTlsContext!");
} }
} }
@Test
public void createNullCommonTlsContext_exception() throws IOException {
UpstreamTlsContext upstreamTlsContext = new UpstreamTlsContext(null);
try {
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (NullPointerException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("upstreamTlsContext should have CommonTlsContext");
}
}
static void createAndRegisterProviderProvider(
CertificateProviderRegistry certificateProviderRegistry,
final CertificateProvider.DistributorWatcher[] watcherCaptor,
String testca,
final int i) {
final CertificateProviderProvider mockProviderProviderTestCa =
mock(CertificateProviderProvider.class);
when(mockProviderProviderTestCa.getName()).thenReturn(testca);
when(mockProviderProviderTestCa.createCertificateProvider(
any(Object.class), any(CertificateProvider.DistributorWatcher.class), eq(true)))
.thenAnswer(
new Answer<CertificateProvider>() {
@Override
public CertificateProvider answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
CertificateProvider.DistributorWatcher watcher =
(CertificateProvider.DistributorWatcher) args[1];
watcherCaptor[i] = watcher;
return new TestCertificateProvider(
watcher, true, args[0], mockProviderProviderTestCa, false);
}
});
certificateProviderRegistry.register(mockProviderProviderTestCa);
}
static void verifyWatcher(
SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor) {
assertThat(watcherCaptor).isNotNull();
assertThat(watcherCaptor.getDownstreamWatchers()).hasSize(1);
assertThat(watcherCaptor.getDownstreamWatchers().iterator().next())
.isSameInstanceAs(sslContextProvider);
}
} }

View File

@ -17,13 +17,28 @@
package io.grpc.xds.internal.sds; package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.createAndRegisterProviderProvider;
import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.verifyWatcher;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider;
import io.grpc.xds.internal.certprovider.CertificateProvider;
import io.grpc.xds.internal.certprovider.CertificateProviderRegistry;
import io.grpc.xds.internal.certprovider.CertificateProviderStore;
import io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils;
import java.io.IOException;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
@ -32,8 +47,23 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class ServerSslContextProviderFactoryTest { public class ServerSslContextProviderFactoryTest {
ServerSslContextProviderFactory serverSslContextProviderFactory = Bootstrapper bootstrapper;
new ServerSslContextProviderFactory(); CertificateProviderRegistry certificateProviderRegistry;
CertificateProviderStore certificateProviderStore;
CertProviderServerSslContextProvider.Factory certProviderServerSslContextProviderFactory;
ServerSslContextProviderFactory serverSslContextProviderFactory;
@Before
public void setUp() {
bootstrapper = mock(Bootstrapper.class);
certificateProviderRegistry = new CertificateProviderRegistry();
certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry);
certProviderServerSslContextProviderFactory =
new CertProviderServerSslContextProvider.Factory(certificateProviderStore);
serverSslContextProviderFactory =
new ServerSslContextProviderFactory(
bootstrapper, certProviderServerSslContextProviderFactory);
}
@Test @Test
public void createSslContextProvider_allFilenames() { public void createSslContextProvider_allFilenames() {
@ -59,10 +89,8 @@ public class ServerSslContextProviderFactoryTest {
SslContextProvider unused = SslContextProvider unused =
serverSslContextProviderFactory.create(downstreamTlsContext); serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected) assertThat(expected).hasMessageThat().isEqualTo("unexpected TlsCertificateSdsSecretConfigs");
.hasMessageThat()
.isEqualTo("DownstreamTlsContext to have all filenames or all SdsConfig");
} }
} }
@ -79,10 +107,159 @@ public class ServerSslContextProviderFactoryTest {
SslContextProvider unused = SslContextProvider unused =
serverSslContextProviderFactory.create(downstreamTlsContext); serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalStateException expected) {
assertThat(expected).hasMessageThat().isEqualTo("incorrect ValidationContextTypeCase");
}
}
@Test
public void createCertProviderServerSslContextProvider() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null,
/* requireClientCert= */ true);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderServerSslContextProvider_onlyCertInstance() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
/* rootInstanceName= */ null,
/* rootCertName= */ null,
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null,
/* requireClientCert= */ true);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderServerSslContextProvider_withStaticContext() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext,
/* requireClientCert= */ true);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderServerSslContextProvider_2providers() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[2];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "file_watcher", 1);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"file_provider",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null,
/* requireClientCert= */ true);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}
@Test
public void createCertProviderServerSslContextProvider_ioException() throws IOException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null,
/* requireClientCert= */ true);
when(bootstrapper.readBootstrap()).thenThrow(new IOException("test IOException"));
try {
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (RuntimeException expected) {
assertThat(expected).hasMessageThat().isEqualTo("java.io.IOException: test IOException");
}
}
@Test
public void createEmptyCommonTlsContext_exception() throws IOException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(null, null, null);
try {
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (UnsupportedOperationException expected) {
assertThat(expected) assertThat(expected)
.hasMessageThat() .hasMessageThat()
.isEqualTo("DownstreamTlsContext to have all filenames or all SdsConfig"); .isEqualTo("Unsupported configurations in DownstreamTlsContext!");
}
}
@Test
public void createNullCommonTlsContext_exception() throws IOException {
DownstreamTlsContext downstreamTlsContext = new DownstreamTlsContext(null, true);
try {
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (NullPointerException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("downstreamTlsContext should have CommonTlsContext");
} }
} }
} }