diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java index 19ad67186c..ccaf86e3f6 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java @@ -51,7 +51,19 @@ final class ClientSslContextProviderFactory checkNotNull( upstreamTlsContext.getCommonTlsContext(), "upstreamTlsContext should have CommonTlsContext"); - if (CommonTlsContextUtil.hasAllSecretsUsingFilename(upstreamTlsContext.getCommonTlsContext())) { + if (CommonTlsContextUtil.hasCertProviderInstance( + upstreamTlsContext.getCommonTlsContext())) { + try { + Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap(); + return certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNode(), + bootstrapInfo.getCertProviders()); + } catch (XdsInitializationException e) { + throw new RuntimeException(e); + } + } else if (CommonTlsContextUtil.hasAllSecretsUsingFilename( + upstreamTlsContext.getCommonTlsContext())) { return SecretVolumeClientSslContextProvider.getProvider(upstreamTlsContext); } else if (CommonTlsContextUtil.hasAllSecretsUsingSds( upstreamTlsContext.getCommonTlsContext())) { @@ -67,17 +79,6 @@ final class ClientSslContextProviderFactory } catch (XdsInitializationException e) { throw new RuntimeException(e); } - } else if (CommonTlsContextUtil.hasCertProviderInstance( - upstreamTlsContext.getCommonTlsContext())) { - try { - Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap(); - return certProviderClientSslContextProviderFactory.getProvider( - upstreamTlsContext, - bootstrapInfo.getNode().toEnvoyProtoNode(), - bootstrapInfo.getCertProviders()); - } catch (XdsInitializationException e) { - throw new RuntimeException(e); - } } throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!"); } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java index 3a33c975c0..657bbd7dfd 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java @@ -52,7 +52,18 @@ final class ServerSslContextProviderFactory checkNotNull( downstreamTlsContext.getCommonTlsContext(), "downstreamTlsContext should have CommonTlsContext"); - if (CommonTlsContextUtil.hasAllSecretsUsingFilename( + if (CommonTlsContextUtil.hasCertProviderInstance( + downstreamTlsContext.getCommonTlsContext())) { + try { + Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap(); + return certProviderServerSslContextProviderFactory.getProvider( + downstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNode(), + bootstrapInfo.getCertProviders()); + } catch (XdsInitializationException e) { + throw new RuntimeException(e); + } + } else if (CommonTlsContextUtil.hasAllSecretsUsingFilename( downstreamTlsContext.getCommonTlsContext())) { return SecretVolumeServerSslContextProvider.getProvider(downstreamTlsContext); } else if (CommonTlsContextUtil.hasAllSecretsUsingSds( @@ -69,17 +80,6 @@ final class ServerSslContextProviderFactory } catch (XdsInitializationException e) { throw new RuntimeException(e); } - } else if (CommonTlsContextUtil.hasCertProviderInstance( - downstreamTlsContext.getCommonTlsContext())) { - try { - Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap(); - return certProviderServerSslContextProviderFactory.getProvider( - downstreamTlsContext, - bootstrapInfo.getNode().toEnvoyProtoNode(), - bootstrapInfo.getCertProviders()); - } catch (XdsInitializationException e) { - throw new RuntimeException(e); - } } throw new UnsupportedOperationException("Unsupported configurations in DownstreamTlsContext!"); } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java index 46f0685b5a..397f1c332c 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java @@ -26,8 +26,10 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableSet; +import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.xds.Bootstrapper; import io.grpc.xds.CommonBootstrapperTestUtils; @@ -139,6 +141,33 @@ public class ClientSslContextProviderFactoryTest { verifyWatcher(sslContextProvider, watcherCaptor[0]); } + @Test + public void bothPresent_expectCertProviderClientSslContextProvider() + throws XdsInitializationException { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + + CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder(); + builder = addFilenames(builder, "foo.pem", "foo.key", "root.pem"); + upstreamTlsContext = new UpstreamTlsContext(builder.build()); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); + when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo); + SslContextProvider sslContextProvider = + clientSslContextProviderFactory.create(upstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + } + @Test public void createCertProviderClientSslContextProvider_onlyRootCert() throws XdsInitializationException { @@ -301,4 +330,27 @@ public class ClientSslContextProviderFactoryTest { assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) .isSameInstanceAs(sslContextProvider); } + + static CommonTlsContext.Builder addFilenames( + CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa) { + TlsCertificate tlsCert = + TlsCertificate.newBuilder() + .setCertificateChain(DataSource.newBuilder().setFilename(certChain)) + .setPrivateKey(DataSource.newBuilder().setFilename(privateKey)) + .build(); + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder() + .setTrustedCa(DataSource.newBuilder().setFilename(trustCa)) + .build(); + CommonTlsContext.CertificateProviderInstance certificateProviderInstance = + builder.getValidationContextCertificateProviderInstance(); + CommonTlsContext.CombinedCertificateValidationContext.Builder combinedBuilder = + CommonTlsContext.CombinedCertificateValidationContext.newBuilder(); + combinedBuilder + .setDefaultValidationContext(certContext) + .setValidationContextCertificateProviderInstance(certificateProviderInstance); + return builder + .addTlsCertificates(tlsCert) + .setCombinedValidationContext(combinedBuilder.build()); + } } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java index 00d32268b4..0a68197173 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java @@ -31,6 +31,7 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.xds.Bootstrapper; import io.grpc.xds.CommonBootstrapperTestUtils; +import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.XdsInitializationException; import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider; @@ -136,6 +137,37 @@ public class ServerSslContextProviderFactoryTest { verifyWatcher(sslContextProvider, watcherCaptor[0]); } + @Test + public void bothPresent_expectCertProviderServerSslContextProvider() + throws XdsInitializationException { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null, + /* requireClientCert= */ true); + + CommonTlsContext.Builder builder = downstreamTlsContext.getCommonTlsContext().toBuilder(); + builder = + ClientSslContextProviderFactoryTest.addFilenames(builder, "foo.pem", "foo.key", "root.pem"); + downstreamTlsContext = + new EnvoyServerProtoData.DownstreamTlsContext( + builder.build(), downstreamTlsContext.isRequireClientCertificate()); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); + when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo); + SslContextProvider sslContextProvider = + serverSslContextProviderFactory.create(downstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + } + @Test public void createCertProviderServerSslContextProvider_onlyCertInstance() throws XdsInitializationException {