xds: reorder processing of tlsContext to prioritize CertProviderInstance (#7592)

This commit is contained in:
sanjaypujare 2020-11-04 12:57:20 -08:00 committed by GitHub
parent d52b359631
commit d7764d7e32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 109 additions and 24 deletions

View File

@ -51,7 +51,19 @@ final class ClientSslContextProviderFactory
checkNotNull(
upstreamTlsContext.getCommonTlsContext(),
"upstreamTlsContext should have CommonTlsContext");
if (CommonTlsContextUtil.hasAllSecretsUsingFilename(upstreamTlsContext.getCommonTlsContext())) {
if (CommonTlsContextUtil.hasCertProviderInstance(
upstreamTlsContext.getCommonTlsContext())) {
try {
Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap();
return certProviderClientSslContextProviderFactory.getProvider(
upstreamTlsContext,
bootstrapInfo.getNode().toEnvoyProtoNode(),
bootstrapInfo.getCertProviders());
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
} else if (CommonTlsContextUtil.hasAllSecretsUsingFilename(
upstreamTlsContext.getCommonTlsContext())) {
return SecretVolumeClientSslContextProvider.getProvider(upstreamTlsContext);
} else if (CommonTlsContextUtil.hasAllSecretsUsingSds(
upstreamTlsContext.getCommonTlsContext())) {
@ -67,17 +79,6 @@ final class ClientSslContextProviderFactory
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
} else if (CommonTlsContextUtil.hasCertProviderInstance(
upstreamTlsContext.getCommonTlsContext())) {
try {
Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap();
return certProviderClientSslContextProviderFactory.getProvider(
upstreamTlsContext,
bootstrapInfo.getNode().toEnvoyProtoNode(),
bootstrapInfo.getCertProviders());
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
}
throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!");
}

View File

@ -52,7 +52,18 @@ final class ServerSslContextProviderFactory
checkNotNull(
downstreamTlsContext.getCommonTlsContext(),
"downstreamTlsContext should have CommonTlsContext");
if (CommonTlsContextUtil.hasAllSecretsUsingFilename(
if (CommonTlsContextUtil.hasCertProviderInstance(
downstreamTlsContext.getCommonTlsContext())) {
try {
Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap();
return certProviderServerSslContextProviderFactory.getProvider(
downstreamTlsContext,
bootstrapInfo.getNode().toEnvoyProtoNode(),
bootstrapInfo.getCertProviders());
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
} else if (CommonTlsContextUtil.hasAllSecretsUsingFilename(
downstreamTlsContext.getCommonTlsContext())) {
return SecretVolumeServerSslContextProvider.getProvider(downstreamTlsContext);
} else if (CommonTlsContextUtil.hasAllSecretsUsingSds(
@ -69,17 +80,6 @@ final class ServerSslContextProviderFactory
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
} else if (CommonTlsContextUtil.hasCertProviderInstance(
downstreamTlsContext.getCommonTlsContext())) {
try {
Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap();
return certProviderServerSslContextProviderFactory.getProvider(
downstreamTlsContext,
bootstrapInfo.getNode().toEnvoyProtoNode(),
bootstrapInfo.getCertProviders());
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
}
throw new UnsupportedOperationException("Unsupported configurations in DownstreamTlsContext!");
}

View File

@ -26,8 +26,10 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.envoyproxy.envoy.config.core.v3.DataSource;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.CommonBootstrapperTestUtils;
@ -139,6 +141,33 @@ public class ClientSslContextProviderFactoryTest {
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void bothPresent_expectCertProviderClientSslContextProvider()
throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder();
builder = addFilenames(builder, "foo.pem", "foo.key", "root.pem");
upstreamTlsContext = new UpstreamTlsContext(builder.build());
Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderClientSslContextProvider_onlyRootCert()
throws XdsInitializationException {
@ -301,4 +330,27 @@ public class ClientSslContextProviderFactoryTest {
assertThat(watcherCaptor.getDownstreamWatchers().iterator().next())
.isSameInstanceAs(sslContextProvider);
}
static CommonTlsContext.Builder addFilenames(
CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa) {
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename(certChain))
.setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
.build();
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
.build();
CommonTlsContext.CertificateProviderInstance certificateProviderInstance =
builder.getValidationContextCertificateProviderInstance();
CommonTlsContext.CombinedCertificateValidationContext.Builder combinedBuilder =
CommonTlsContext.CombinedCertificateValidationContext.newBuilder();
combinedBuilder
.setDefaultValidationContext(certContext)
.setValidationContextCertificateProviderInstance(certificateProviderInstance);
return builder
.addTlsCertificates(tlsCert)
.setCombinedValidationContext(combinedBuilder.build());
}
}

View File

@ -31,6 +31,7 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.CommonBootstrapperTestUtils;
import io.grpc.xds.EnvoyServerProtoData;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.XdsInitializationException;
import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider;
@ -136,6 +137,37 @@ public class ServerSslContextProviderFactoryTest {
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void bothPresent_expectCertProviderServerSslContextProvider()
throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null,
/* requireClientCert= */ true);
CommonTlsContext.Builder builder = downstreamTlsContext.getCommonTlsContext().toBuilder();
builder =
ClientSslContextProviderFactoryTest.addFilenames(builder, "foo.pem", "foo.key", "root.pem");
downstreamTlsContext =
new EnvoyServerProtoData.DownstreamTlsContext(
builder.build(), downstreamTlsContext.isRequireClientCertificate());
Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderServerSslContextProvider_onlyCertInstance()
throws XdsInitializationException {