From e6ab16733421c33e52f0d5d8755d41ec249ca6e4 Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Fri, 21 Aug 2020 14:08:39 -0700 Subject: [PATCH] xds: Add CertProviderSslContextProviders to Client&Server SslContextProviderFactories (#7338) --- .../io/grpc/xds/EnvoyServerProtoData.java | 5 +- .../CertProviderClientSslContextProvider.java | 14 +- .../CertProviderServerSslContextProvider.java | 14 +- .../certprovider/CertificateProvider.java | 7 +- .../CertificateProviderProvider.java | 4 +- .../CertificateProviderRegistry.java | 2 +- .../CertificateProviderStore.java | 2 +- .../sds/ClientSslContextProviderFactory.java | 28 ++- .../internal/sds/CommonTlsContextUtil.java | 30 ++- .../SecretVolumeClientSslContextProvider.java | 4 + .../SecretVolumeServerSslContextProvider.java | 4 + .../sds/ServerSslContextProviderFactory.java | 28 ++- .../CommonCertProviderTestUtils.java | 3 +- .../certprovider/TestCertificateProvider.java | 13 +- .../ClientSslContextProviderFactoryTest.java | 227 +++++++++++++++++- .../ServerSslContextProviderFactoryTest.java | 191 ++++++++++++++- 16 files changed, 525 insertions(+), 51 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index c0ab9083fb..9911f011ed 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -72,7 +72,7 @@ public final class EnvoyServerProtoData { public static final class UpstreamTlsContext extends BaseTlsContext { @VisibleForTesting - UpstreamTlsContext(CommonTlsContext commonTlsContext) { + public UpstreamTlsContext(CommonTlsContext commonTlsContext) { super(commonTlsContext); } @@ -93,7 +93,8 @@ public final class EnvoyServerProtoData { private final boolean requireClientCertificate; @VisibleForTesting - DownstreamTlsContext(CommonTlsContext commonTlsContext, boolean requireClientCertificate) { + public DownstreamTlsContext( + CommonTlsContext commonTlsContext, boolean requireClientCertificate) { super(commonTlsContext); this.requireClientCertificate = requireClientCertificate; } diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java index 562f27431c..1dc7be1be3 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java @@ -23,6 +23,7 @@ import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; +import io.grpc.Internal; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; @@ -33,7 +34,8 @@ import java.security.cert.X509Certificate; import java.util.Map; /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ -final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { +@Internal +public final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { private CertProviderClientSslContextProvider( Node node, @@ -70,20 +72,22 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP } /** Creates CertProviderClientSslContextProvider. */ - static final class Factory { + @Internal + public static final class Factory { private static final Factory DEFAULT_INSTANCE = new Factory(CertificateProviderStore.getInstance()); private final CertificateProviderStore certificateProviderStore; - @VisibleForTesting Factory(CertificateProviderStore certificateProviderStore) { + @VisibleForTesting public Factory(CertificateProviderStore certificateProviderStore) { this.certificateProviderStore = certificateProviderStore; } - static Factory getInstance() { + public static Factory getInstance() { return DEFAULT_INSTANCE; } - CertProviderClientSslContextProvider getProvider( + /** Creates a {@link CertProviderClientSslContextProvider}. */ + public CertProviderClientSslContextProvider getProvider( UpstreamTlsContext upstreamTlsContext, Node node, Map certProviders) { diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java index 5d1a6399b8..78e825f60f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java @@ -23,6 +23,7 @@ import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; +import io.grpc.Internal; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; @@ -36,7 +37,8 @@ import java.security.cert.X509Certificate; import java.util.Map; /** A server SslContext provider using CertificateProviderInstance to fetch secrets. */ -final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider { +@Internal +public final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider { private CertProviderServerSslContextProvider( Node node, @@ -73,20 +75,22 @@ final class CertProviderServerSslContextProvider extends CertProviderSslContextP } /** Creates CertProviderServerSslContextProvider. */ - static final class Factory { + @Internal + public static final class Factory { private static final Factory DEFAULT_INSTANCE = new Factory(CertificateProviderStore.getInstance()); private final CertificateProviderStore certificateProviderStore; - @VisibleForTesting Factory(CertificateProviderStore certificateProviderStore) { + @VisibleForTesting public Factory(CertificateProviderStore certificateProviderStore) { this.certificateProviderStore = certificateProviderStore; } - static Factory getInstance() { + public static Factory getInstance() { return DEFAULT_INSTANCE; } - CertProviderServerSslContextProvider getProvider( + /** Creates a {@link CertProviderServerSslContextProvider}. */ + public CertProviderServerSslContextProvider getProvider( DownstreamTlsContext downstreamTlsContext, Node node, Map certProviders) { diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java index b5d149777e..04ed997fa5 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java @@ -23,6 +23,7 @@ import io.grpc.Status; import io.grpc.xds.internal.sds.Closeable; import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -48,7 +49,7 @@ public abstract class CertificateProvider implements Closeable { } @VisibleForTesting - static final class DistributorWatcher implements Watcher { + public static final class DistributorWatcher implements Watcher { private PrivateKey privateKey; private List certChain; private List trustedRoots; @@ -70,6 +71,10 @@ public abstract class CertificateProvider implements Closeable { downstreamWatchers.remove(watcher); } + @VisibleForTesting public Set getDownstreamWatchers() { + return Collections.unmodifiableSet(downstreamWatchers); + } + private void sendLastCertificateUpdate(Watcher watcher) { watcher.updateCertificate(privateKey, certChain); } diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderProvider.java index 92b2d4d6aa..a426542eea 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderProvider.java @@ -16,13 +16,15 @@ package io.grpc.xds.internal.certprovider; +import io.grpc.Internal; import io.grpc.xds.internal.certprovider.CertificateProvider.Watcher; /** * Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may * move this out of the internal package and make this an official API in the future. */ -interface CertificateProviderProvider { +@Internal +public interface CertificateProviderProvider { /** Returns the unique name of the {@link CertificateProvider} plugin. */ String getName(); diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java index 36db37e5db..04a8f73f26 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java @@ -31,7 +31,7 @@ public final class CertificateProviderRegistry { new LinkedHashMap<>(); @VisibleForTesting - CertificateProviderRegistry() { + public CertificateProviderRegistry() { } /** Returns the singleton registry. */ diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java index 18f8a11ec2..43143ebb3a 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java @@ -139,7 +139,7 @@ public final class CertificateProviderStore { } @VisibleForTesting - CertificateProviderStore(CertificateProviderRegistry certificateProviderRegistry) { + public CertificateProviderStore(CertificateProviderRegistry certificateProviderRegistry) { this.certificateProviderRegistry = certificateProviderRegistry; certProviderMap = new ReferenceCountingMap<>(new CertProviderFactory()); } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java index 84b2f8284a..8e593fd560 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java @@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.util.concurrent.ThreadFactoryBuilder; import io.grpc.xds.Bootstrapper; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProvider; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import java.io.IOException; import java.util.concurrent.Executors; @@ -29,6 +30,20 @@ import java.util.concurrent.Executors; final class ClientSslContextProviderFactory implements ValueFactory { + private final Bootstrapper bootstrapper; + private final CertProviderClientSslContextProvider.Factory + certProviderClientSslContextProviderFactory; + + ClientSslContextProviderFactory() { + this(Bootstrapper.getInstance(), CertProviderClientSslContextProvider.Factory.getInstance()); + } + + ClientSslContextProviderFactory( + Bootstrapper bootstrapper, CertProviderClientSslContextProvider.Factory factory) { + this.bootstrapper = bootstrapper; + this.certProviderClientSslContextProviderFactory = factory; + } + /** Creates an SslContextProvider from the given UpstreamTlsContext. */ @Override public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) { @@ -52,8 +67,17 @@ final class ClientSslContextProviderFactory } catch (IOException ioe) { throw new RuntimeException(ioe); } + } else if (CommonTlsContextUtil.hasCertProviderInstance( + upstreamTlsContext.getCommonTlsContext())) { + try { + return certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapper.readBootstrap().getNode().toEnvoyProtoNode(), + bootstrapper.readBootstrap().getCertProviders()); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } } - throw new UnsupportedOperationException( - "UpstreamTlsContext to have all filenames or all SdsConfig"); + throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!"); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java index 54dd0e7117..5ffb356e18 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java @@ -23,6 +23,7 @@ import static com.google.common.base.Preconditions.checkState; import io.envoyproxy.envoy.config.core.v3.DataSource.SpecifierCase; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.ValidationContextTypeCase; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; import javax.annotation.Nullable; @@ -34,18 +35,31 @@ final class CommonTlsContextUtil { /** Returns true only if given CommonTlsContext uses no SdsSecretConfigs. */ static boolean hasAllSecretsUsingFilename(CommonTlsContext commonTlsContext) { - checkNotNull(commonTlsContext, "commonTlsContext"); - // return true if it has no SdsSecretConfig(s) - return (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() == 0) - && !commonTlsContext.hasValidationContextSdsSecretConfig(); + return commonTlsContext != null + && (commonTlsContext.getTlsCertificatesCount() > 0 + || commonTlsContext.hasValidationContext()); } /** Returns true only if given CommonTlsContext uses only SdsSecretConfigs. */ static boolean hasAllSecretsUsingSds(CommonTlsContext commonTlsContext) { - checkNotNull(commonTlsContext, "commonTlsContext"); - // return true if it has only SdsSecretConfig(s) - return (commonTlsContext.getTlsCertificatesCount() == 0) - && !commonTlsContext.hasValidationContext(); + return commonTlsContext != null + && (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0 + || commonTlsContext.hasValidationContextSdsSecretConfig()); + } + + static boolean hasCertProviderInstance(CommonTlsContext commonTlsContext) { + return commonTlsContext != null + && (commonTlsContext.hasTlsCertificateCertificateProviderInstance() + || hasCertProviderValidationContext(commonTlsContext)); + } + + private static boolean hasCertProviderValidationContext(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasCombinedValidationContext()) { + CombinedCertificateValidationContext combinedCertificateValidationContext = + commonTlsContext.getCombinedValidationContext(); + return combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance(); + } + return commonTlsContext.hasValidationContextCertificateProviderInstance(); } @Nullable diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java index fc3befd5f1..8698c6cc09 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java @@ -16,6 +16,7 @@ package io.grpc.xds.internal.sds; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext; import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext; @@ -60,6 +61,9 @@ final class SecretVolumeClientSslContextProvider extends SslContextProvider { static SecretVolumeClientSslContextProvider getProvider(UpstreamTlsContext upstreamTlsContext) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); + checkArgument( + commonTlsContext.getTlsCertificateSdsSecretConfigsCount() == 0, + "unexpected TlsCertificateSdsSecretConfigs"); CertificateValidationContext certificateValidationContext = getCertificateValidationContext(commonTlsContext); // first validate diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java index 943b385c6a..f1bb93a651 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java @@ -16,6 +16,7 @@ package io.grpc.xds.internal.sds; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext; import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext; @@ -61,6 +62,9 @@ final class SecretVolumeServerSslContextProvider extends SslContextProvider { DownstreamTlsContext downstreamTlsContext) { checkNotNull(downstreamTlsContext, "downstreamTlsContext"); CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); + checkArgument( + commonTlsContext.getTlsCertificateSdsSecretConfigsCount() == 0, + "unexpected TlsCertificateSdsSecretConfigs"); TlsCertificate tlsCertificate = null; if (commonTlsContext.getTlsCertificatesCount() > 0) { tlsCertificate = commonTlsContext.getTlsCertificates(0); diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java index 1cd2cfa8e9..6ae5308930 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java @@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.util.concurrent.ThreadFactoryBuilder; import io.grpc.xds.Bootstrapper; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; +import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import java.io.IOException; import java.util.concurrent.Executors; @@ -29,6 +30,20 @@ import java.util.concurrent.Executors; final class ServerSslContextProviderFactory implements ValueFactory { + private final Bootstrapper bootstrapper; + private final CertProviderServerSslContextProvider.Factory + certProviderServerSslContextProviderFactory; + + ServerSslContextProviderFactory() { + this(Bootstrapper.getInstance(), CertProviderServerSslContextProvider.Factory.getInstance()); + } + + ServerSslContextProviderFactory( + Bootstrapper bootstrapper, CertProviderServerSslContextProvider.Factory factory) { + this.bootstrapper = bootstrapper; + this.certProviderServerSslContextProviderFactory = factory; + } + /** Creates a SslContextProvider from the given DownstreamTlsContext. */ @Override public SslContextProvider create( @@ -54,8 +69,17 @@ final class ServerSslContextProviderFactory } catch (IOException ioe) { throw new RuntimeException(ioe); } + } else if (CommonTlsContextUtil.hasCertProviderInstance( + downstreamTlsContext.getCommonTlsContext())) { + try { + return certProviderServerSslContextProviderFactory.getProvider( + downstreamTlsContext, + bootstrapper.readBootstrap().getNode().toEnvoyProtoNode(), + bootstrapper.readBootstrap().getCertProviders()); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } } - throw new UnsupportedOperationException( - "DownstreamTlsContext to have all filenames or all SdsConfig"); + throw new UnsupportedOperationException("Unsupported configurations in DownstreamTlsContext!"); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CommonCertProviderTestUtils.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CommonCertProviderTestUtils.java index 34347b68e2..ce494e5efa 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CommonCertProviderTestUtils.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CommonCertProviderTestUtils.java @@ -54,7 +54,8 @@ public class CommonCertProviderTestUtils { "-+END\\s+.*PRIVATE\\s+KEY[^-]*-+", // Footer Pattern.CASE_INSENSITIVE); - static Bootstrapper.BootstrapInfo getTestBootstrapInfo() throws IOException { + /** Creates a test bootstrap info object. */ + public static Bootstrapper.BootstrapInfo getTestBootstrapInfo() throws IOException { String rawData = "{\n" + " \"xds_servers\": [],\n" diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/TestCertificateProvider.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/TestCertificateProvider.java index 406ae4f0bb..9253d071fb 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/TestCertificateProvider.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/TestCertificateProvider.java @@ -22,12 +22,13 @@ public class TestCertificateProvider extends CertificateProvider { int closeCalled = 0; int startCalled = 0; - TestCertificateProvider( - DistributorWatcher watcher, - boolean notifyCertUpdates, - Object config, - CertificateProviderProvider certificateProviderProvider, - boolean throwExceptionForCertUpdates) { + /** Creates a TestCertificateProvider instance. */ + public TestCertificateProvider( + DistributorWatcher watcher, + boolean notifyCertUpdates, + Object config, + CertificateProviderProvider certificateProviderProvider, + boolean throwExceptionForCertUpdates) { super(watcher, notifyCertUpdates); if (throwExceptionForCertUpdates && notifyCertUpdates) { throw new UnsupportedOperationException("Provider does not support Certificate Updates."); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java index 8fe011f0ee..c1c2b30e68 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java @@ -20,20 +20,54 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableSet; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.grpc.xds.Bootstrapper; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProvider; +import io.grpc.xds.internal.certprovider.CertificateProvider; +import io.grpc.xds.internal.certprovider.CertificateProviderProvider; +import io.grpc.xds.internal.certprovider.CertificateProviderRegistry; +import io.grpc.xds.internal.certprovider.CertificateProviderStore; +import io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils; +import io.grpc.xds.internal.certprovider.TestCertificateProvider; +import java.io.IOException; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; /** Unit tests for {@link ClientSslContextProviderFactory}. */ @RunWith(JUnit4.class) public class ClientSslContextProviderFactoryTest { - ClientSslContextProviderFactory clientSslContextProviderFactory = - new ClientSslContextProviderFactory(); + Bootstrapper bootstrapper; + CertificateProviderRegistry certificateProviderRegistry; + CertificateProviderStore certificateProviderStore; + CertProviderClientSslContextProvider.Factory certProviderClientSslContextProviderFactory; + ClientSslContextProviderFactory clientSslContextProviderFactory; + + @Before + public void setUp() { + bootstrapper = mock(Bootstrapper.class); + certificateProviderRegistry = new CertificateProviderRegistry(); + certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry); + certProviderClientSslContextProviderFactory = + new CertProviderClientSslContextProvider.Factory(certificateProviderStore); + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + bootstrapper, certProviderClientSslContextProviderFactory); + } @Test public void createSslContextProvider_allFilenames() { @@ -55,13 +89,10 @@ public class ClientSslContextProviderFactoryTest { CommonTlsContextTestsUtil.buildUpstreamTlsContext(commonTlsContext); try { - SslContextProvider unused = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(upstreamTlsContext); Assert.fail("no exception thrown"); - } catch (UnsupportedOperationException expected) { - assertThat(expected) - .hasMessageThat() - .isEqualTo("UpstreamTlsContext to have all filenames or all SdsConfig"); + } catch (IllegalArgumentException expected) { + assertThat(expected).hasMessageThat().isEqualTo("unexpected TlsCertificateSdsSecretConfigs"); } } @@ -80,10 +111,188 @@ public class ClientSslContextProviderFactoryTest { SslContextProvider unused = clientSslContextProviderFactory.create(upstreamTlsContext); Assert.fail("no exception thrown"); + } catch (IllegalStateException expected) { + assertThat(expected).hasMessageThat().isEqualTo("incorrect ValidationContextTypeCase"); + } + } + + @Test + public void createCertProviderClientSslContextProvider() throws IOException { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo(); + when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo); + SslContextProvider sslContextProvider = + clientSslContextProviderFactory.create(upstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + } + + @Test + public void createCertProviderClientSslContextProvider_onlyRootCert() throws IOException { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + /* certInstanceName= */ null, + /* certName= */ null, + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo(); + when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo); + SslContextProvider sslContextProvider = + clientSslContextProviderFactory.create(upstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + } + + @Test + public void createCertProviderClientSslContextProvider_withStaticContext() throws IOException { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + CertificateValidationContext staticCertValidationContext = + CertificateValidationContext.newBuilder() + .addAllMatchSubjectAltNames( + ImmutableSet.of( + StringMatcher.newBuilder().setExact("foo").build(), + StringMatcher.newBuilder().setExact("bar").build())) + .build(); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + /* certInstanceName= */ null, + /* certName= */ null, + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + staticCertValidationContext); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo(); + when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo); + SslContextProvider sslContextProvider = + clientSslContextProviderFactory.create(upstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + } + + @Test + public void createCertProviderClientSslContextProvider_2providers() throws IOException { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[2]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + + createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "file_watcher", 1); + + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + "file_provider", + "root-default", + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo(); + when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo); + SslContextProvider sslContextProvider = + clientSslContextProviderFactory.create(upstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[1]); + } + + @Test + public void createCertProviderClientSslContextProvider_ioException() throws IOException { + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + when(bootstrapper.readBootstrap()).thenThrow(new IOException("test IOException")); + try { + clientSslContextProviderFactory.create(upstreamTlsContext); + Assert.fail("no exception thrown"); + } catch (RuntimeException expected) { + assertThat(expected).hasMessageThat().isEqualTo("java.io.IOException: test IOException"); + } + } + + @Test + public void createEmptyCommonTlsContext_exception() throws IOException { + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(null, null, null); + try { + clientSslContextProviderFactory.create(upstreamTlsContext); + Assert.fail("no exception thrown"); } catch (UnsupportedOperationException expected) { assertThat(expected) .hasMessageThat() - .isEqualTo("UpstreamTlsContext to have all filenames or all SdsConfig"); + .isEqualTo("Unsupported configurations in UpstreamTlsContext!"); } } + + @Test + public void createNullCommonTlsContext_exception() throws IOException { + UpstreamTlsContext upstreamTlsContext = new UpstreamTlsContext(null); + try { + clientSslContextProviderFactory.create(upstreamTlsContext); + Assert.fail("no exception thrown"); + } catch (NullPointerException expected) { + assertThat(expected) + .hasMessageThat() + .isEqualTo("upstreamTlsContext should have CommonTlsContext"); + } + } + + static void createAndRegisterProviderProvider( + CertificateProviderRegistry certificateProviderRegistry, + final CertificateProvider.DistributorWatcher[] watcherCaptor, + String testca, + final int i) { + final CertificateProviderProvider mockProviderProviderTestCa = + mock(CertificateProviderProvider.class); + when(mockProviderProviderTestCa.getName()).thenReturn(testca); + + when(mockProviderProviderTestCa.createCertificateProvider( + any(Object.class), any(CertificateProvider.DistributorWatcher.class), eq(true))) + .thenAnswer( + new Answer() { + @Override + public CertificateProvider answer(InvocationOnMock invocation) throws Throwable { + Object[] args = invocation.getArguments(); + CertificateProvider.DistributorWatcher watcher = + (CertificateProvider.DistributorWatcher) args[1]; + watcherCaptor[i] = watcher; + return new TestCertificateProvider( + watcher, true, args[0], mockProviderProviderTestCa, false); + } + }); + certificateProviderRegistry.register(mockProviderProviderTestCa); + } + + static void verifyWatcher( + SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor) { + assertThat(watcherCaptor).isNotNull(); + assertThat(watcherCaptor.getDownstreamWatchers()).hasSize(1); + assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) + .isSameInstanceAs(sslContextProvider); + } } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java index 1c2b58f7b2..4ab957bdb4 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java @@ -17,13 +17,28 @@ package io.grpc.xds.internal.sds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.createAndRegisterProviderProvider; +import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.verifyWatcher; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableSet; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.grpc.xds.Bootstrapper; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; +import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider; +import io.grpc.xds.internal.certprovider.CertificateProvider; +import io.grpc.xds.internal.certprovider.CertificateProviderRegistry; +import io.grpc.xds.internal.certprovider.CertificateProviderStore; +import io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils; +import java.io.IOException; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -32,8 +47,23 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class ServerSslContextProviderFactoryTest { - ServerSslContextProviderFactory serverSslContextProviderFactory = - new ServerSslContextProviderFactory(); + Bootstrapper bootstrapper; + CertificateProviderRegistry certificateProviderRegistry; + CertificateProviderStore certificateProviderStore; + CertProviderServerSslContextProvider.Factory certProviderServerSslContextProviderFactory; + ServerSslContextProviderFactory serverSslContextProviderFactory; + + @Before + public void setUp() { + bootstrapper = mock(Bootstrapper.class); + certificateProviderRegistry = new CertificateProviderRegistry(); + certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry); + certProviderServerSslContextProviderFactory = + new CertProviderServerSslContextProvider.Factory(certificateProviderStore); + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + bootstrapper, certProviderServerSslContextProviderFactory); + } @Test public void createSslContextProvider_allFilenames() { @@ -59,10 +89,8 @@ public class ServerSslContextProviderFactoryTest { SslContextProvider unused = serverSslContextProviderFactory.create(downstreamTlsContext); Assert.fail("no exception thrown"); - } catch (UnsupportedOperationException expected) { - assertThat(expected) - .hasMessageThat() - .isEqualTo("DownstreamTlsContext to have all filenames or all SdsConfig"); + } catch (IllegalArgumentException expected) { + assertThat(expected).hasMessageThat().isEqualTo("unexpected TlsCertificateSdsSecretConfigs"); } } @@ -79,10 +107,159 @@ public class ServerSslContextProviderFactoryTest { SslContextProvider unused = serverSslContextProviderFactory.create(downstreamTlsContext); Assert.fail("no exception thrown"); + } catch (IllegalStateException expected) { + assertThat(expected).hasMessageThat().isEqualTo("incorrect ValidationContextTypeCase"); + } + } + + @Test + public void createCertProviderServerSslContextProvider() throws IOException { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null, + /* requireClientCert= */ true); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo(); + when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo); + SslContextProvider sslContextProvider = + serverSslContextProviderFactory.create(downstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + } + + @Test + public void createCertProviderServerSslContextProvider_onlyCertInstance() throws IOException { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + /* rootInstanceName= */ null, + /* rootCertName= */ null, + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null, + /* requireClientCert= */ true); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo(); + when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo); + SslContextProvider sslContextProvider = + serverSslContextProviderFactory.create(downstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + } + + @Test + public void createCertProviderServerSslContextProvider_withStaticContext() throws IOException { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + CertificateValidationContext staticCertValidationContext = + CertificateValidationContext.newBuilder() + .addAllMatchSubjectAltNames( + ImmutableSet.of( + StringMatcher.newBuilder().setExact("foo").build(), + StringMatcher.newBuilder().setExact("bar").build())) + .build(); + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + staticCertValidationContext, + /* requireClientCert= */ true); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo(); + when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo); + SslContextProvider sslContextProvider = + serverSslContextProviderFactory.create(downstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + } + + @Test + public void createCertProviderServerSslContextProvider_2providers() throws IOException { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[2]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + + createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "file_watcher", 1); + + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + "file_provider", + "root-default", + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null, + /* requireClientCert= */ true); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo(); + when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo); + SslContextProvider sslContextProvider = + serverSslContextProviderFactory.create(downstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[1]); + } + + @Test + public void createCertProviderServerSslContextProvider_ioException() throws IOException { + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null, + /* requireClientCert= */ true); + when(bootstrapper.readBootstrap()).thenThrow(new IOException("test IOException")); + try { + serverSslContextProviderFactory.create(downstreamTlsContext); + Assert.fail("no exception thrown"); + } catch (RuntimeException expected) { + assertThat(expected).hasMessageThat().isEqualTo("java.io.IOException: test IOException"); + } + } + + @Test + public void createEmptyCommonTlsContext_exception() throws IOException { + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(null, null, null); + try { + serverSslContextProviderFactory.create(downstreamTlsContext); + Assert.fail("no exception thrown"); } catch (UnsupportedOperationException expected) { assertThat(expected) .hasMessageThat() - .isEqualTo("DownstreamTlsContext to have all filenames or all SdsConfig"); + .isEqualTo("Unsupported configurations in DownstreamTlsContext!"); + } + } + + @Test + public void createNullCommonTlsContext_exception() throws IOException { + DownstreamTlsContext downstreamTlsContext = new DownstreamTlsContext(null, true); + try { + serverSslContextProviderFactory.create(downstreamTlsContext); + Assert.fail("no exception thrown"); + } catch (NullPointerException expected) { + assertThat(expected) + .hasMessageThat() + .isEqualTo("downstreamTlsContext should have CommonTlsContext"); } } }