From 22603810b97ed2a3f7ff82aed1ec7fe136aea7ff Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Wed, 8 Sep 2021 23:06:21 +0000 Subject: [PATCH] xds: use the new cert-provider instances if present (#8494) --- .../java/io/grpc/xds/ClientXdsClient.java | 93 ++++++++++-------- .../CertProviderClientSslContextProvider.java | 28 ++---- .../CertProviderServerSslContextProvider.java | 28 ++---- .../CertProviderSslContextProvider.java | 48 ++++++++++ .../internal/sds/CommonTlsContextUtil.java | 14 ++- .../io/grpc/xds/ClientXdsClientDataTest.java | 32 +++++-- .../io/grpc/xds/ClientXdsClientTestBase.java | 41 +++++++- .../io/grpc/xds/ClientXdsClientV2Test.java | 6 ++ .../io/grpc/xds/ClientXdsClientV3Test.java | 16 ++++ ...tProviderClientSslContextProviderTest.java | 84 +++++++++++++++++ ...tProviderServerSslContextProviderTest.java | 94 +++++++++++++++++++ .../sds/CommonTlsContextTestsUtil.java | 83 ++++++++++++++++ 12 files changed, 473 insertions(+), 94 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 8384551597..21cf78b126 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -456,10 +456,6 @@ final class ClientXdsClient extends AbstractXdsClient { if (commonTlsContext.hasTlsParams()) { throw new ResourceInvalidException("common-tls-context with tls_params is not supported"); } - if (commonTlsContext.hasValidationContext()) { - throw new ResourceInvalidException( - "common-tls-context with validation_context is not supported"); - } if (commonTlsContext.hasValidationContextSdsSecretConfig()) { throw new ResourceInvalidException( "common-tls-context with validation_context_sds_secret_config is not supported"); @@ -473,54 +469,50 @@ final class ClientXdsClient extends AbstractXdsClient { "common-tls-context with validation_context_certificate_provider_instance is not" + " supported"); } - String certInstanceName = null; - if (!commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { + String certInstanceName = getIdentityCertInstanceName(commonTlsContext); + if (certInstanceName == null) { if (server) { throw new ResourceInvalidException( - "tls_certificate_certificate_provider_instance is required in downstream-tls-context"); + "tls_certificate_provider_instance is required in downstream-tls-context"); } if (commonTlsContext.getTlsCertificatesCount() > 0) { throw new ResourceInvalidException( - "common-tls-context with tls_certificates is not supported"); + "tls_certificate_provider_instance is unset"); } if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) { throw new ResourceInvalidException( - "common-tls-context with tls_certificate_sds_secret_configs is not supported"); + "tls_certificate_provider_instance is unset"); } if (commonTlsContext.hasTlsCertificateCertificateProvider()) { throw new ResourceInvalidException( - "common-tls-context with tls_certificate_certificate_provider is not supported"); + "tls_certificate_provider_instance is unset"); } - } else { - certInstanceName = commonTlsContext.getTlsCertificateCertificateProviderInstance() - .getInstanceName(); + } else if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) { + throw new ResourceInvalidException( + "CertificateProvider instance name '" + certInstanceName + + "' not defined in the bootstrap file."); } - if (certInstanceName != null) { - if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) { - throw new ResourceInvalidException( - "CertificateProvider instance name '" + certInstanceName - + "' not defined in the bootstrap file."); - } - } - String rootCaInstanceName = null; - if (!commonTlsContext.hasCombinedValidationContext()) { + String rootCaInstanceName = getRootCertInstanceName(commonTlsContext); + if (rootCaInstanceName == null) { if (!server) { throw new ResourceInvalidException( - "combined_validation_context is required in upstream-tls-context"); + "ca_certificate_provider_instance is required in upstream-tls-context"); } } else { - CommonTlsContext.CombinedCertificateValidationContext combinedCertificateValidationContext - = commonTlsContext.getCombinedValidationContext(); - if (!combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance()) { + if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) { throw new ResourceInvalidException( - "validation_context_certificate_provider_instance is required in" - + " combined_validation_context"); + "ca_certificate_provider_instance name '" + rootCaInstanceName + + "' not defined in the bootstrap file."); } - rootCaInstanceName = combinedCertificateValidationContext - .getValidationContextCertificateProviderInstance().getInstanceName(); - if (combinedCertificateValidationContext.hasDefaultValidationContext()) { - CertificateValidationContext certificateValidationContext - = combinedCertificateValidationContext.getDefaultValidationContext(); + CertificateValidationContext certificateValidationContext = null; + if (commonTlsContext.hasValidationContext()) { + certificateValidationContext = commonTlsContext.getValidationContext(); + } else if (commonTlsContext.hasCombinedValidationContext() && commonTlsContext + .getCombinedValidationContext().hasDefaultValidationContext()) { + certificateValidationContext = commonTlsContext.getCombinedValidationContext() + .getDefaultValidationContext(); + } + if (certificateValidationContext != null) { if (certificateValidationContext.getMatchSubjectAltNamesCount() > 0 && server) { throw new ResourceInvalidException( "match_subject_alt_names only allowed in upstream_tls_context"); @@ -547,13 +539,38 @@ final class ClientXdsClient extends AbstractXdsClient { } } } - if (rootCaInstanceName != null) { - if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) { - throw new ResourceInvalidException( - "ValidationContextProvider instance name '" + rootCaInstanceName - + "' not defined in the bootstrap file."); + } + + private static String getIdentityCertInstanceName(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasTlsCertificateProviderInstance()) { + return commonTlsContext.getTlsCertificateProviderInstance().getInstanceName(); + } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { + return commonTlsContext.getTlsCertificateCertificateProviderInstance().getInstanceName(); + } + return null; + } + + private static String getRootCertInstanceName(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasValidationContext()) { + if (commonTlsContext.getValidationContext().hasCaCertificateProviderInstance()) { + return commonTlsContext.getValidationContext().getCaCertificateProviderInstance() + .getInstanceName(); + } + } else if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combinedCertificateValidationContext + = commonTlsContext.getCombinedValidationContext(); + if (combinedCertificateValidationContext.hasDefaultValidationContext() + && combinedCertificateValidationContext.getDefaultValidationContext() + .hasCaCertificateProviderInstance()) { + return combinedCertificateValidationContext.getDefaultValidationContext() + .getCaCertificateProviderInstance().getInstanceName(); + } else if (combinedCertificateValidationContext + .hasValidationContextCertificateProviderInstance()) { + return combinedCertificateValidationContext + .getValidationContextCertificateProviderInstance().getInstanceName(); } } + return null; } private static void checkForUniqueness(Set uniqueSet, diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java index 2ee21e7db6..ce9ef3de68 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java @@ -22,7 +22,6 @@ import com.google.common.annotations.VisibleForTesting; import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.grpc.Internal; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; @@ -94,27 +93,12 @@ public final class CertProviderClientSslContextProvider extends CertProviderSslC @Nullable Map certProviders) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); - CommonTlsContext.CertificateProviderInstance rootCertInstance = null; - CertificateValidationContext staticCertValidationContext = null; - if (commonTlsContext.hasCombinedValidationContext()) { - CombinedCertificateValidationContext combinedValidationContext = - commonTlsContext.getCombinedValidationContext(); - if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = - combinedValidationContext.getValidationContextCertificateProviderInstance(); - } - if (combinedValidationContext.hasDefaultValidationContext()) { - staticCertValidationContext = combinedValidationContext.getDefaultValidationContext(); - } - } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = commonTlsContext.getValidationContextCertificateProviderInstance(); - } else if (commonTlsContext.hasValidationContext()) { - staticCertValidationContext = commonTlsContext.getValidationContext(); - } - CommonTlsContext.CertificateProviderInstance certInstance = null; - if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { - certInstance = commonTlsContext.getTlsCertificateCertificateProviderInstance(); - } + CertificateValidationContext staticCertValidationContext = getStaticValidationContext( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance rootCertInstance = getRootCertProviderInstance( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance certInstance = getCertProviderInstance( + commonTlsContext); return new CertProviderClientSslContextProvider( node, certProviders, diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java index 1f33e1de78..a7f0849d00 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java @@ -22,7 +22,6 @@ import com.google.common.annotations.VisibleForTesting; import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.grpc.Internal; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; @@ -97,27 +96,12 @@ public final class CertProviderServerSslContextProvider extends CertProviderSslC @Nullable Map certProviders) { checkNotNull(downstreamTlsContext, "downstreamTlsContext"); CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); - CommonTlsContext.CertificateProviderInstance rootCertInstance = null; - CertificateValidationContext staticCertValidationContext = null; - if (commonTlsContext.hasCombinedValidationContext()) { - CombinedCertificateValidationContext combinedValidationContext = - commonTlsContext.getCombinedValidationContext(); - if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = - combinedValidationContext.getValidationContextCertificateProviderInstance(); - } - if (combinedValidationContext.hasDefaultValidationContext()) { - staticCertValidationContext = combinedValidationContext.getDefaultValidationContext(); - } - } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = commonTlsContext.getValidationContextCertificateProviderInstance(); - } else if (commonTlsContext.hasValidationContext()) { - staticCertValidationContext = commonTlsContext.getValidationContext(); - } - CommonTlsContext.CertificateProviderInstance certInstance = null; - if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { - certInstance = commonTlsContext.getTlsCertificateCertificateProviderInstance(); - } + CertificateValidationContext staticCertValidationContext = getStaticValidationContext( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance rootCertInstance = getRootCertProviderInstance( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance certInstance = getCertProviderInstance( + commonTlsContext); return new CertProviderServerSslContextProvider( node, certProviders, diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java index 1af9e1670d..1ec5876419 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java @@ -18,9 +18,11 @@ package io.grpc.xds.internal.certprovider; import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; +import io.grpc.xds.internal.sds.CommonTlsContextUtil; import io.grpc.xds.internal.sds.DynamicSslContextProvider; import java.security.PrivateKey; import java.security.cert.X509Certificate; @@ -88,6 +90,52 @@ abstract class CertProviderSslContextProvider extends DynamicSslContextProvider return certProviders != null ? certProviders.get(pluginInstanceName) : null; } + @Nullable + protected static CertificateProviderInstance getCertProviderInstance( + CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasTlsCertificateProviderInstance()) { + return CommonTlsContextUtil.convert(commonTlsContext.getTlsCertificateProviderInstance()); + } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { + return commonTlsContext.getTlsCertificateCertificateProviderInstance(); + } + return null; + } + + @Nullable + protected static CertificateValidationContext getStaticValidationContext( + CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasValidationContext()) { + return commonTlsContext.getValidationContext(); + } else if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combinedValidationContext = + commonTlsContext.getCombinedValidationContext(); + if (combinedValidationContext.hasDefaultValidationContext()) { + return combinedValidationContext.getDefaultValidationContext(); + } + } + return null; + } + + @Nullable + protected static CommonTlsContext.CertificateProviderInstance getRootCertProviderInstance( + CommonTlsContext commonTlsContext) { + CertificateValidationContext certValidationContext = getStaticValidationContext( + commonTlsContext); + if (certValidationContext != null && certValidationContext.hasCaCertificateProviderInstance()) { + return CommonTlsContextUtil.convert(certValidationContext.getCaCertificateProviderInstance()); + } + if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combinedValidationContext = + commonTlsContext.getCombinedValidationContext(); + if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { + return combinedValidationContext.getValidationContextCertificateProviderInstance(); + } + } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { + return commonTlsContext.getValidationContextCertificateProviderInstance(); + } + return null; + } + @Override public final void updateCertificate(PrivateKey key, List certChain) { savedKey = key; diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java index 234989ad11..0c28c79ee2 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java @@ -16,11 +16,12 @@ package io.grpc.xds.internal.sds; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; /** Class for utility functions for {@link CommonTlsContext}. */ -final class CommonTlsContextUtil { +public final class CommonTlsContextUtil { private CommonTlsContextUtil() {} @@ -38,4 +39,15 @@ final class CommonTlsContextUtil { } return commonTlsContext.hasValidationContextCertificateProviderInstance(); } + + /** + * Converts {@link CertificateProviderPluginInstance} to + * {@link CommonTlsContext.CertificateProviderInstance}. + */ + public static CommonTlsContext.CertificateProviderInstance convert( + CertificateProviderPluginInstance pluginInstance) { + return CommonTlsContext.CertificateProviderInstance.newBuilder() + .setInstanceName(pluginInstance.getInstanceName()) + .setCertificateName(pluginInstance.getCertificateName()).build(); + } } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index fb5c349f12..80cd2a8046 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -77,6 +77,7 @@ import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; @@ -1551,7 +1552,7 @@ public class ClientXdsClientDataTest { .setValidationContext(CertificateValidationContext.getDefaultInstance()) .build(); thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context with validation_context is not supported"); + thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @@ -1603,14 +1604,26 @@ public class ClientXdsClientDataTest { .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "tls_certificate_certificate_provider_instance is required in downstream-tls-context"); + "tls_certificate_provider_instance is required in downstream-tls-context"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, true); } + @Test + @SuppressWarnings("deprecation") + public void validateCommonTlsContext_tlsNewCertificateProviderInstance() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName("name1").build()) + .build(); + ClientXdsClient + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); + } + @Test @SuppressWarnings("deprecation") public void validateCommonTlsContext_tlsCertificateProviderInstance() - throws ResourceInvalidException { + throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setTlsCertificateCertificateProviderInstance( CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) @@ -1662,7 +1675,7 @@ public class ClientXdsClientDataTest { .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "ValidationContextProvider instance name 'bad-name' not defined in the bootstrap file."); + "ca_certificate_provider_instance name 'bad-name' not defined in the bootstrap file."); ClientXdsClient .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); } @@ -1674,7 +1687,7 @@ public class ClientXdsClientDataTest { .addTlsCertificates(TlsCertificate.getDefaultInstance()) .build(); thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context with tls_certificates is not supported"); + thrown.expectMessage("tls_certificate_provider_instance is unset"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @@ -1686,7 +1699,7 @@ public class ClientXdsClientDataTest { .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "common-tls-context with tls_certificate_sds_secret_configs is not supported"); + "tls_certificate_provider_instance is unset"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @@ -1700,7 +1713,7 @@ public class ClientXdsClientDataTest { .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "common-tls-context with tls_certificate_certificate_provider is not supported"); + "tls_certificate_provider_instance is unset"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @@ -1710,7 +1723,7 @@ public class ClientXdsClientDataTest { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .build(); thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("combined_validation_context is required in upstream-tls-context"); + thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @@ -1723,8 +1736,7 @@ public class ClientXdsClientDataTest { .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "validation_context_certificate_provider_instance is required in " - + "combined_validation_context"); + "ca_certificate_provider_instance is required in upstream-tls-context"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java index 9e4d92fb34..55bd6ba3e9 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java @@ -39,6 +39,7 @@ import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import io.envoyproxy.envoy.config.route.v3.FilterConfig; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.BindableService; import io.grpc.Context; @@ -1353,6 +1354,42 @@ public abstract class ClientXdsClientTestBase { verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); } + /** + * CDS response containing new UpstreamTlsContext for a cluster. + */ + @Test + @SuppressWarnings("deprecation") + public void cdsResponseWithNewUpstreamTlsContext() { + Assume.assumeTrue(useProtocolV3()); + DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + + // Management server sends back CDS response with UpstreamTlsContext. + Any clusterEds = + Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", + null, true, + mf.buildNewUpstreamTlsContext("cert-instance-name", "cert1"), + "envoy.transport_sockets.tls", null)); + List clusters = ImmutableList.of( + Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", + "dns-service-bar.googleapis.com", 443, "round_robin", null, false, null, null)), + clusterEds, + Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, false, + null, "envoy.transport_sockets.tls", null))); + call.sendResponse(CDS, clusters, VERSION_1, "0000"); + + // Client sent an ACK CDS request. + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher, times(1)).onChanged(cdsUpdateCaptor.capture()); + CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + CertificateProviderPluginInstance certificateProviderInstance = + cdsUpdate.upstreamTlsContext().getCommonTlsContext().getValidationContext() + .getCaCertificateProviderInstance(); + assertThat(certificateProviderInstance.getInstanceName()).isEqualTo("cert-instance-name"); + assertThat(certificateProviderInstance.getCertificateName()).isEqualTo("cert1"); + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterEds, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + } + /** * CDS response containing bad UpstreamTlsContext for a cluster. */ @@ -1373,7 +1410,7 @@ public abstract class ClientXdsClientTestBase { "CDS response Cluster 'cluster.googleapis.com' validation error: " + "Cluster cluster.googleapis.com: malformed UpstreamTlsContext: " + "io.grpc.xds.ClientXdsClient$ResourceInvalidException: " - + "combined_validation_context is required in upstream-tls-context")); + + "ca_certificate_provider_instance is required in upstream-tls-context")); verifyNoInteractions(cdsResourceWatcher); } @@ -2400,6 +2437,8 @@ public abstract class ClientXdsClientTestBase { protected abstract Message buildUpstreamTlsContext(String instanceName, String certName); + protected abstract Message buildNewUpstreamTlsContext(String instanceName, String certName); + protected abstract Message buildCircuitBreakers(int highPriorityMaxRequests, int defaultPriorityMaxRequests); diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java index 409613aecf..39f5d1a1a2 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java @@ -515,6 +515,12 @@ public class ClientXdsClientV2Test extends ClientXdsClientTestBase { .build(); } + @Override + protected Message buildNewUpstreamTlsContext(String instanceName, String certName) { + return buildUpstreamTlsContext(instanceName, certName); + } + + @Override protected Message buildCircuitBreakers(int highPriorityMaxRequests, int defaultPriorityMaxRequests) { diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java index eddba1040d..dfd407ef01 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java @@ -77,6 +77,8 @@ import io.envoyproxy.envoy.extensions.filters.http.fault.v3.HTTPFault; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; import io.envoyproxy.envoy.service.discovery.v3.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase; @@ -555,6 +557,20 @@ public class ClientXdsClientV3Test extends ClientXdsClientTestBase { .build(); } + @Override + protected Message buildNewUpstreamTlsContext(String instanceName, String certName) { + CommonTlsContext.Builder commonTlsContextBuilder = CommonTlsContext.newBuilder(); + if (instanceName != null && certName != null) { + commonTlsContextBuilder.setValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName(instanceName) + .setCertificateName(certName).build())); + } + return UpstreamTlsContext.newBuilder() + .setCommonTlsContext(commonTlsContextBuilder) + .build(); + } + @Override protected Message buildCircuitBreakers(int highPriorityMaxRequests, int defaultPriorityMaxRequests) { diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java index 00b2901464..1eed5488aa 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java @@ -87,6 +87,27 @@ public class CertProviderClientSslContextProviderTest { bootstrapInfo.getCertProviders()); } + /** Helper method to build CertProviderClientSslContextProvider. */ + private CertProviderClientSslContextProvider getNewSslContextProvider( + String certInstanceName, + String rootInstanceName, + Bootstrapper.BootstrapInfo bootstrapInfo, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + return certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNode(), + bootstrapInfo.getCertProviders()); + } + @Test public void testProviderForClient_mtls() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = @@ -150,6 +171,69 @@ public class CertProviderClientSslContextProviderTest { assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } + @Test + public void testProviderForClient_mtls_newXds() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getNewSslContextProvider( + "gcp_id", + "gcp_id", + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNotNull(); + assertThat(provider.savedCertChain).isNotNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate root cert update + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // just do root cert update: sslContext should still be the same + watcherCaptor[0].updateTrustedRoots( + ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e.different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + @Test public void testProviderForClient_queueExecutor() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java index ef801ccc2c..783ce2b11f 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java @@ -31,12 +31,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.MoreExecutors; import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.xds.Bootstrapper; import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProviderTest.QueuedExecutor; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback; +import java.util.Arrays; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -81,6 +83,30 @@ public class CertProviderServerSslContextProviderTest { bootstrapInfo.getCertProviders()); } + /** Helper method to build CertProviderServerSslContextProvider. */ + private CertProviderServerSslContextProvider getNewSslContextProvider( + String certInstanceName, + String rootInstanceName, + Bootstrapper.BootstrapInfo bootstrapInfo, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, + boolean requireClientCert) { + EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildNewDownstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext, + requireClientCert); + return certProviderServerSslContextProviderFactory.getProvider( + downstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNode(), + bootstrapInfo.getCertProviders()); + } + + @Test public void testProviderForServer_mtls() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = @@ -145,6 +171,74 @@ public class CertProviderServerSslContextProviderTest { assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } + @Test + public void testProviderForServer_mtls_newXds() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertificateValidationContext staticCertValidationContext = + CertificateValidationContext.newBuilder().addAllMatchSubjectAltNames(Arrays + .asList(StringMatcher.newBuilder().setExact("foo.com").build(), + StringMatcher.newBuilder().setExact("bar.com").build())).build(); + CertProviderServerSslContextProvider provider = + getNewSslContextProvider( + "gcp_id", + "gcp_id", + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + staticCertValidationContext, + /* requireClientCert= */ true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_0_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); + assertThat(provider.savedKey).isNotNull(); + assertThat(provider.savedCertChain).isNotNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate root cert update + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // just do root cert update: sslContext should still be the same + watcherCaptor[0].updateTrustedRoots( + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e.different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + @Test public void testProviderForServer_queueExecutor() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java index 81fbda9bde..840cced424 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java @@ -22,6 +22,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.io.CharStreams; import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.BoolValue; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; @@ -234,6 +235,30 @@ public class CommonTlsContextTestsUtil { return builder.build(); } + private static CommonTlsContext buildNewCommonTlsContextForCertProviderInstance( + String certInstanceName, + String certName, + String rootInstanceName, + String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); + if (certInstanceName != null) { + builder = + builder.setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder() + .setInstanceName(certInstanceName) + .setCertificateName(certName)); + } + builder = + addNewCertificateValidationContext( + builder, rootInstanceName, rootCertName, staticCertValidationContext); + if (alpnProtocols != null) { + builder.addAllAlpnProtocols(alpnProtocols); + } + return builder.build(); + } + @SuppressWarnings("deprecation") private static CommonTlsContext.Builder addCertificateValidationContext( CommonTlsContext.Builder builder, @@ -259,6 +284,26 @@ public class CommonTlsContextTestsUtil { return builder; } + private static CommonTlsContext.Builder addNewCertificateValidationContext( + CommonTlsContext.Builder builder, + String rootInstanceName, + String rootCertName, + CertificateValidationContext staticCertValidationContext) { + if (rootInstanceName != null) { + CertificateProviderPluginInstance providerInstance = + CertificateProviderPluginInstance.newBuilder() + .setInstanceName(rootInstanceName) + .setCertificateName(rootCertName) + .build(); + CertificateValidationContext.Builder validationContextBuilder = + staticCertValidationContext != null ? staticCertValidationContext.toBuilder() + : CertificateValidationContext.newBuilder(); + return builder.setValidationContext( + validationContextBuilder.setCaCertificateProviderInstance(providerInstance)); + } + return builder; + } + /** Helper method to build UpstreamTlsContext for CertProvider tests. */ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContextForCertProviderInstance( @@ -278,6 +323,25 @@ public class CommonTlsContextTestsUtil { staticCertValidationContext)); } + /** Helper method to build UpstreamTlsContext for CertProvider tests. */ + public static EnvoyServerProtoData.UpstreamTlsContext + buildNewUpstreamTlsContextForCertProviderInstance( + @Nullable String certInstanceName, + @Nullable String certName, + @Nullable String rootInstanceName, + @Nullable String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + return buildUpstreamTlsContext( + buildNewCommonTlsContextForCertProviderInstance( + certInstanceName, + certName, + rootInstanceName, + rootCertName, + alpnProtocols, + staticCertValidationContext)); + } + /** Helper method to build DownstreamTlsContext for CertProvider tests. */ public static EnvoyServerProtoData.DownstreamTlsContext buildDownstreamTlsContextForCertProviderInstance( @@ -298,6 +362,25 @@ public class CommonTlsContextTestsUtil { staticCertValidationContext), requireClientCert); } + /** Helper method to build DownstreamTlsContext for CertProvider tests. */ + public static EnvoyServerProtoData.DownstreamTlsContext + buildNewDownstreamTlsContextForCertProviderInstance( + @Nullable String certInstanceName, + @Nullable String certName, + @Nullable String rootInstanceName, + @Nullable String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, + boolean requireClientCert) { + return buildInternalDownstreamTlsContext( + buildNewCommonTlsContextForCertProviderInstance( + certInstanceName, + certName, + rootInstanceName, + rootCertName, + alpnProtocols, + staticCertValidationContext), requireClientCert); + } /** Perform some simple checks on sslContext. */ public static void doChecksOnSslContext(boolean server, SslContext sslContext,