xds: Add CertProviderSslContextProviders to Client&Server SslContextProviderFactories (#7338)

This commit is contained in:
sanjaypujare 2020-08-21 14:08:39 -07:00 committed by GitHub
parent 80480e69ef
commit e6ab167334
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 525 additions and 51 deletions

View File

@ -72,7 +72,7 @@ public final class EnvoyServerProtoData {
public static final class UpstreamTlsContext extends BaseTlsContext {
@VisibleForTesting
UpstreamTlsContext(CommonTlsContext commonTlsContext) {
public UpstreamTlsContext(CommonTlsContext commonTlsContext) {
super(commonTlsContext);
}
@ -93,7 +93,8 @@ public final class EnvoyServerProtoData {
private final boolean requireClientCertificate;
@VisibleForTesting
DownstreamTlsContext(CommonTlsContext commonTlsContext, boolean requireClientCertificate) {
public DownstreamTlsContext(
CommonTlsContext commonTlsContext, boolean requireClientCertificate) {
super(commonTlsContext);
this.requireClientCertificate = requireClientCertificate;
}

View File

@ -23,6 +23,7 @@ import io.envoyproxy.envoy.config.core.v3.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
import io.grpc.Internal;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.Bootstrapper.CertificateProviderInfo;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
@ -33,7 +34,8 @@ import java.security.cert.X509Certificate;
import java.util.Map;
/** A client SslContext provider using CertificateProviderInstance to fetch secrets. */
final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider {
@Internal
public final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider {
private CertProviderClientSslContextProvider(
Node node,
@ -70,20 +72,22 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP
}
/** Creates CertProviderClientSslContextProvider. */
static final class Factory {
@Internal
public static final class Factory {
private static final Factory DEFAULT_INSTANCE =
new Factory(CertificateProviderStore.getInstance());
private final CertificateProviderStore certificateProviderStore;
@VisibleForTesting Factory(CertificateProviderStore certificateProviderStore) {
@VisibleForTesting public Factory(CertificateProviderStore certificateProviderStore) {
this.certificateProviderStore = certificateProviderStore;
}
static Factory getInstance() {
public static Factory getInstance() {
return DEFAULT_INSTANCE;
}
CertProviderClientSslContextProvider getProvider(
/** Creates a {@link CertProviderClientSslContextProvider}. */
public CertProviderClientSslContextProvider getProvider(
UpstreamTlsContext upstreamTlsContext,
Node node,
Map<String, CertificateProviderInfo> certProviders) {

View File

@ -23,6 +23,7 @@ import io.envoyproxy.envoy.config.core.v3.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
import io.grpc.Internal;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.Bootstrapper.CertificateProviderInfo;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
@ -36,7 +37,8 @@ import java.security.cert.X509Certificate;
import java.util.Map;
/** A server SslContext provider using CertificateProviderInstance to fetch secrets. */
final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider {
@Internal
public final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider {
private CertProviderServerSslContextProvider(
Node node,
@ -73,20 +75,22 @@ final class CertProviderServerSslContextProvider extends CertProviderSslContextP
}
/** Creates CertProviderServerSslContextProvider. */
static final class Factory {
@Internal
public static final class Factory {
private static final Factory DEFAULT_INSTANCE =
new Factory(CertificateProviderStore.getInstance());
private final CertificateProviderStore certificateProviderStore;
@VisibleForTesting Factory(CertificateProviderStore certificateProviderStore) {
@VisibleForTesting public Factory(CertificateProviderStore certificateProviderStore) {
this.certificateProviderStore = certificateProviderStore;
}
static Factory getInstance() {
public static Factory getInstance() {
return DEFAULT_INSTANCE;
}
CertProviderServerSslContextProvider getProvider(
/** Creates a {@link CertProviderServerSslContextProvider}. */
public CertProviderServerSslContextProvider getProvider(
DownstreamTlsContext downstreamTlsContext,
Node node,
Map<String, CertificateProviderInfo> certProviders) {

View File

@ -23,6 +23,7 @@ import io.grpc.Status;
import io.grpc.xds.internal.sds.Closeable;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -48,7 +49,7 @@ public abstract class CertificateProvider implements Closeable {
}
@VisibleForTesting
static final class DistributorWatcher implements Watcher {
public static final class DistributorWatcher implements Watcher {
private PrivateKey privateKey;
private List<X509Certificate> certChain;
private List<X509Certificate> trustedRoots;
@ -70,6 +71,10 @@ public abstract class CertificateProvider implements Closeable {
downstreamWatchers.remove(watcher);
}
@VisibleForTesting public Set<Watcher> getDownstreamWatchers() {
return Collections.unmodifiableSet(downstreamWatchers);
}
private void sendLastCertificateUpdate(Watcher watcher) {
watcher.updateCertificate(privateKey, certChain);
}

View File

@ -16,13 +16,15 @@
package io.grpc.xds.internal.certprovider;
import io.grpc.Internal;
import io.grpc.xds.internal.certprovider.CertificateProvider.Watcher;
/**
* Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may
* move this out of the internal package and make this an official API in the future.
*/
interface CertificateProviderProvider {
@Internal
public interface CertificateProviderProvider {
/** Returns the unique name of the {@link CertificateProvider} plugin. */
String getName();

View File

@ -31,7 +31,7 @@ public final class CertificateProviderRegistry {
new LinkedHashMap<>();
@VisibleForTesting
CertificateProviderRegistry() {
public CertificateProviderRegistry() {
}
/** Returns the singleton registry. */

View File

@ -139,7 +139,7 @@ public final class CertificateProviderStore {
}
@VisibleForTesting
CertificateProviderStore(CertificateProviderRegistry certificateProviderRegistry) {
public CertificateProviderStore(CertificateProviderRegistry certificateProviderRegistry) {
this.certificateProviderRegistry = certificateProviderRegistry;
certProviderMap = new ReferenceCountingMap<>(new CertProviderFactory());
}

View File

@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProvider;
import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import java.io.IOException;
import java.util.concurrent.Executors;
@ -29,6 +30,20 @@ import java.util.concurrent.Executors;
final class ClientSslContextProviderFactory
implements ValueFactory<UpstreamTlsContext, SslContextProvider> {
private final Bootstrapper bootstrapper;
private final CertProviderClientSslContextProvider.Factory
certProviderClientSslContextProviderFactory;
ClientSslContextProviderFactory() {
this(Bootstrapper.getInstance(), CertProviderClientSslContextProvider.Factory.getInstance());
}
ClientSslContextProviderFactory(
Bootstrapper bootstrapper, CertProviderClientSslContextProvider.Factory factory) {
this.bootstrapper = bootstrapper;
this.certProviderClientSslContextProviderFactory = factory;
}
/** Creates an SslContextProvider from the given UpstreamTlsContext. */
@Override
public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) {
@ -52,8 +67,17 @@ final class ClientSslContextProviderFactory
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
} else if (CommonTlsContextUtil.hasCertProviderInstance(
upstreamTlsContext.getCommonTlsContext())) {
try {
return certProviderClientSslContextProviderFactory.getProvider(
upstreamTlsContext,
bootstrapper.readBootstrap().getNode().toEnvoyProtoNode(),
bootstrapper.readBootstrap().getCertProviders());
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
}
throw new UnsupportedOperationException(
"UpstreamTlsContext to have all filenames or all SdsConfig");
throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!");
}
}

View File

@ -23,6 +23,7 @@ import static com.google.common.base.Preconditions.checkState;
import io.envoyproxy.envoy.config.core.v3.DataSource.SpecifierCase;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.ValidationContextTypeCase;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate;
import javax.annotation.Nullable;
@ -34,18 +35,31 @@ final class CommonTlsContextUtil {
/** Returns true only if given CommonTlsContext uses no SdsSecretConfigs. */
static boolean hasAllSecretsUsingFilename(CommonTlsContext commonTlsContext) {
checkNotNull(commonTlsContext, "commonTlsContext");
// return true if it has no SdsSecretConfig(s)
return (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() == 0)
&& !commonTlsContext.hasValidationContextSdsSecretConfig();
return commonTlsContext != null
&& (commonTlsContext.getTlsCertificatesCount() > 0
|| commonTlsContext.hasValidationContext());
}
/** Returns true only if given CommonTlsContext uses only SdsSecretConfigs. */
static boolean hasAllSecretsUsingSds(CommonTlsContext commonTlsContext) {
checkNotNull(commonTlsContext, "commonTlsContext");
// return true if it has only SdsSecretConfig(s)
return (commonTlsContext.getTlsCertificatesCount() == 0)
&& !commonTlsContext.hasValidationContext();
return commonTlsContext != null
&& (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0
|| commonTlsContext.hasValidationContextSdsSecretConfig());
}
static boolean hasCertProviderInstance(CommonTlsContext commonTlsContext) {
return commonTlsContext != null
&& (commonTlsContext.hasTlsCertificateCertificateProviderInstance()
|| hasCertProviderValidationContext(commonTlsContext));
}
private static boolean hasCertProviderValidationContext(CommonTlsContext commonTlsContext) {
if (commonTlsContext.hasCombinedValidationContext()) {
CombinedCertificateValidationContext combinedCertificateValidationContext =
commonTlsContext.getCombinedValidationContext();
return combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance();
}
return commonTlsContext.hasValidationContextCertificateProviderInstance();
}
@Nullable

View File

@ -16,6 +16,7 @@
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext;
@ -60,6 +61,9 @@ final class SecretVolumeClientSslContextProvider extends SslContextProvider {
static SecretVolumeClientSslContextProvider getProvider(UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
checkArgument(
commonTlsContext.getTlsCertificateSdsSecretConfigsCount() == 0,
"unexpected TlsCertificateSdsSecretConfigs");
CertificateValidationContext certificateValidationContext =
getCertificateValidationContext(commonTlsContext);
// first validate

View File

@ -16,6 +16,7 @@
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.getCertificateValidationContext;
import static io.grpc.xds.internal.sds.CommonTlsContextUtil.validateCertificateContext;
@ -61,6 +62,9 @@ final class SecretVolumeServerSslContextProvider extends SslContextProvider {
DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext();
checkArgument(
commonTlsContext.getTlsCertificateSdsSecretConfigsCount() == 0,
"unexpected TlsCertificateSdsSecretConfigs");
TlsCertificate tlsCertificate = null;
if (commonTlsContext.getTlsCertificatesCount() > 0) {
tlsCertificate = commonTlsContext.getTlsCertificates(0);

View File

@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider;
import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import java.io.IOException;
import java.util.concurrent.Executors;
@ -29,6 +30,20 @@ import java.util.concurrent.Executors;
final class ServerSslContextProviderFactory
implements ValueFactory<DownstreamTlsContext, SslContextProvider> {
private final Bootstrapper bootstrapper;
private final CertProviderServerSslContextProvider.Factory
certProviderServerSslContextProviderFactory;
ServerSslContextProviderFactory() {
this(Bootstrapper.getInstance(), CertProviderServerSslContextProvider.Factory.getInstance());
}
ServerSslContextProviderFactory(
Bootstrapper bootstrapper, CertProviderServerSslContextProvider.Factory factory) {
this.bootstrapper = bootstrapper;
this.certProviderServerSslContextProviderFactory = factory;
}
/** Creates a SslContextProvider from the given DownstreamTlsContext. */
@Override
public SslContextProvider create(
@ -54,8 +69,17 @@ final class ServerSslContextProviderFactory
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
} else if (CommonTlsContextUtil.hasCertProviderInstance(
downstreamTlsContext.getCommonTlsContext())) {
try {
return certProviderServerSslContextProviderFactory.getProvider(
downstreamTlsContext,
bootstrapper.readBootstrap().getNode().toEnvoyProtoNode(),
bootstrapper.readBootstrap().getCertProviders());
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
}
throw new UnsupportedOperationException(
"DownstreamTlsContext to have all filenames or all SdsConfig");
throw new UnsupportedOperationException("Unsupported configurations in DownstreamTlsContext!");
}
}

View File

@ -54,7 +54,8 @@ public class CommonCertProviderTestUtils {
"-+END\\s+.*PRIVATE\\s+KEY[^-]*-+", // Footer
Pattern.CASE_INSENSITIVE);
static Bootstrapper.BootstrapInfo getTestBootstrapInfo() throws IOException {
/** Creates a test bootstrap info object. */
public static Bootstrapper.BootstrapInfo getTestBootstrapInfo() throws IOException {
String rawData =
"{\n"
+ " \"xds_servers\": [],\n"

View File

@ -22,12 +22,13 @@ public class TestCertificateProvider extends CertificateProvider {
int closeCalled = 0;
int startCalled = 0;
TestCertificateProvider(
DistributorWatcher watcher,
boolean notifyCertUpdates,
Object config,
CertificateProviderProvider certificateProviderProvider,
boolean throwExceptionForCertUpdates) {
/** Creates a TestCertificateProvider instance. */
public TestCertificateProvider(
DistributorWatcher watcher,
boolean notifyCertUpdates,
Object config,
CertificateProviderProvider certificateProviderProvider,
boolean throwExceptionForCertUpdates) {
super(watcher, notifyCertUpdates);
if (throwExceptionForCertUpdates && notifyCertUpdates) {
throw new UnsupportedOperationException("Provider does not support Certificate Updates.");

View File

@ -20,20 +20,54 @@ import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProvider;
import io.grpc.xds.internal.certprovider.CertificateProvider;
import io.grpc.xds.internal.certprovider.CertificateProviderProvider;
import io.grpc.xds.internal.certprovider.CertificateProviderRegistry;
import io.grpc.xds.internal.certprovider.CertificateProviderStore;
import io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils;
import io.grpc.xds.internal.certprovider.TestCertificateProvider;
import java.io.IOException;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** Unit tests for {@link ClientSslContextProviderFactory}. */
@RunWith(JUnit4.class)
public class ClientSslContextProviderFactoryTest {
ClientSslContextProviderFactory clientSslContextProviderFactory =
new ClientSslContextProviderFactory();
Bootstrapper bootstrapper;
CertificateProviderRegistry certificateProviderRegistry;
CertificateProviderStore certificateProviderStore;
CertProviderClientSslContextProvider.Factory certProviderClientSslContextProviderFactory;
ClientSslContextProviderFactory clientSslContextProviderFactory;
@Before
public void setUp() {
bootstrapper = mock(Bootstrapper.class);
certificateProviderRegistry = new CertificateProviderRegistry();
certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry);
certProviderClientSslContextProviderFactory =
new CertProviderClientSslContextProvider.Factory(certificateProviderStore);
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapper, certProviderClientSslContextProviderFactory);
}
@Test
public void createSslContextProvider_allFilenames() {
@ -55,13 +89,10 @@ public class ClientSslContextProviderFactoryTest {
CommonTlsContextTestsUtil.buildUpstreamTlsContext(commonTlsContext);
try {
SslContextProvider unused =
clientSslContextProviderFactory.create(upstreamTlsContext);
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("UpstreamTlsContext to have all filenames or all SdsConfig");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("unexpected TlsCertificateSdsSecretConfigs");
}
}
@ -80,10 +111,188 @@ public class ClientSslContextProviderFactoryTest {
SslContextProvider unused =
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (IllegalStateException expected) {
assertThat(expected).hasMessageThat().isEqualTo("incorrect ValidationContextTypeCase");
}
}
@Test
public void createCertProviderClientSslContextProvider() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderClientSslContextProvider_onlyRootCert() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
/* certInstanceName= */ null,
/* certName= */ null,
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderClientSslContextProvider_withStaticContext() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
/* certInstanceName= */ null,
/* certName= */ null,
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderClientSslContextProvider_2providers() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[2];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "file_watcher", 1);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"file_provider",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}
@Test
public void createCertProviderClientSslContextProvider_ioException() throws IOException {
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
when(bootstrapper.readBootstrap()).thenThrow(new IOException("test IOException"));
try {
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (RuntimeException expected) {
assertThat(expected).hasMessageThat().isEqualTo("java.io.IOException: test IOException");
}
}
@Test
public void createEmptyCommonTlsContext_exception() throws IOException {
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(null, null, null);
try {
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("UpstreamTlsContext to have all filenames or all SdsConfig");
.isEqualTo("Unsupported configurations in UpstreamTlsContext!");
}
}
@Test
public void createNullCommonTlsContext_exception() throws IOException {
UpstreamTlsContext upstreamTlsContext = new UpstreamTlsContext(null);
try {
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (NullPointerException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("upstreamTlsContext should have CommonTlsContext");
}
}
static void createAndRegisterProviderProvider(
CertificateProviderRegistry certificateProviderRegistry,
final CertificateProvider.DistributorWatcher[] watcherCaptor,
String testca,
final int i) {
final CertificateProviderProvider mockProviderProviderTestCa =
mock(CertificateProviderProvider.class);
when(mockProviderProviderTestCa.getName()).thenReturn(testca);
when(mockProviderProviderTestCa.createCertificateProvider(
any(Object.class), any(CertificateProvider.DistributorWatcher.class), eq(true)))
.thenAnswer(
new Answer<CertificateProvider>() {
@Override
public CertificateProvider answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
CertificateProvider.DistributorWatcher watcher =
(CertificateProvider.DistributorWatcher) args[1];
watcherCaptor[i] = watcher;
return new TestCertificateProvider(
watcher, true, args[0], mockProviderProviderTestCa, false);
}
});
certificateProviderRegistry.register(mockProviderProviderTestCa);
}
static void verifyWatcher(
SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor) {
assertThat(watcherCaptor).isNotNull();
assertThat(watcherCaptor.getDownstreamWatchers()).hasSize(1);
assertThat(watcherCaptor.getDownstreamWatchers().iterator().next())
.isSameInstanceAs(sslContextProvider);
}
}

View File

@ -17,13 +17,28 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.createAndRegisterProviderProvider;
import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.verifyWatcher;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider;
import io.grpc.xds.internal.certprovider.CertificateProvider;
import io.grpc.xds.internal.certprovider.CertificateProviderRegistry;
import io.grpc.xds.internal.certprovider.CertificateProviderStore;
import io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils;
import java.io.IOException;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@ -32,8 +47,23 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class ServerSslContextProviderFactoryTest {
ServerSslContextProviderFactory serverSslContextProviderFactory =
new ServerSslContextProviderFactory();
Bootstrapper bootstrapper;
CertificateProviderRegistry certificateProviderRegistry;
CertificateProviderStore certificateProviderStore;
CertProviderServerSslContextProvider.Factory certProviderServerSslContextProviderFactory;
ServerSslContextProviderFactory serverSslContextProviderFactory;
@Before
public void setUp() {
bootstrapper = mock(Bootstrapper.class);
certificateProviderRegistry = new CertificateProviderRegistry();
certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry);
certProviderServerSslContextProviderFactory =
new CertProviderServerSslContextProvider.Factory(certificateProviderStore);
serverSslContextProviderFactory =
new ServerSslContextProviderFactory(
bootstrapper, certProviderServerSslContextProviderFactory);
}
@Test
public void createSslContextProvider_allFilenames() {
@ -59,10 +89,8 @@ public class ServerSslContextProviderFactoryTest {
SslContextProvider unused =
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("DownstreamTlsContext to have all filenames or all SdsConfig");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("unexpected TlsCertificateSdsSecretConfigs");
}
}
@ -79,10 +107,159 @@ public class ServerSslContextProviderFactoryTest {
SslContextProvider unused =
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (IllegalStateException expected) {
assertThat(expected).hasMessageThat().isEqualTo("incorrect ValidationContextTypeCase");
}
}
@Test
public void createCertProviderServerSslContextProvider() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null,
/* requireClientCert= */ true);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderServerSslContextProvider_onlyCertInstance() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
/* rootInstanceName= */ null,
/* rootCertName= */ null,
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null,
/* requireClientCert= */ true);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderServerSslContextProvider_withStaticContext() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext,
/* requireClientCert= */ true);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderServerSslContextProvider_2providers() throws IOException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[2];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "file_watcher", 1);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"file_provider",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null,
/* requireClientCert= */ true);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonCertProviderTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}
@Test
public void createCertProviderServerSslContextProvider_ioException() throws IOException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null,
/* requireClientCert= */ true);
when(bootstrapper.readBootstrap()).thenThrow(new IOException("test IOException"));
try {
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (RuntimeException expected) {
assertThat(expected).hasMessageThat().isEqualTo("java.io.IOException: test IOException");
}
}
@Test
public void createEmptyCommonTlsContext_exception() throws IOException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(null, null, null);
try {
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("DownstreamTlsContext to have all filenames or all SdsConfig");
.isEqualTo("Unsupported configurations in DownstreamTlsContext!");
}
}
@Test
public void createNullCommonTlsContext_exception() throws IOException {
DownstreamTlsContext downstreamTlsContext = new DownstreamTlsContext(null, true);
try {
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (NullPointerException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("downstreamTlsContext should have CommonTlsContext");
}
}
}