xds: clean up to remove the cert-instance-override related code (#7986)

This commit is contained in:
sanjaypujare 2021-03-18 10:22:04 -07:00 committed by GitHub
parent e8d935e5c9
commit cd3b0c4412
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 111 deletions

View File

@ -34,33 +34,26 @@ import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
*/ */
public final class TlsContextManagerImpl implements TlsContextManager { public final class TlsContextManagerImpl implements TlsContextManager {
public static final String GOOGLE_CLOUD_PRIVATE_SPIFFE = "google_cloud_private_spiffe";
private static TlsContextManagerImpl instance; private static TlsContextManagerImpl instance;
private static final boolean CERT_INSTANCE_OVERRIDE =
Boolean.parseBoolean(System.getenv("GRPC_XDS_CERT_INSTANCE_OVERRIDE"));
private final ReferenceCountingMap<UpstreamTlsContext, SslContextProvider> mapForClients; private final ReferenceCountingMap<UpstreamTlsContext, SslContextProvider> mapForClients;
private final ReferenceCountingMap<DownstreamTlsContext, SslContextProvider> mapForServers; private final ReferenceCountingMap<DownstreamTlsContext, SslContextProvider> mapForServers;
private final boolean hasCertInstanceOverride;
/** Create a TlsContextManagerImpl instance using the passed in {@link Bootstrapper}. */ /** Create a TlsContextManagerImpl instance using the passed in {@link Bootstrapper}. */
@VisibleForTesting public TlsContextManagerImpl(Bootstrapper bootstrapper) { @VisibleForTesting public TlsContextManagerImpl(Bootstrapper bootstrapper) {
this( this(
new ClientSslContextProviderFactory(bootstrapper), new ClientSslContextProviderFactory(bootstrapper),
new ServerSslContextProviderFactory(bootstrapper), CERT_INSTANCE_OVERRIDE); new ServerSslContextProviderFactory(bootstrapper));
} }
@VisibleForTesting @VisibleForTesting
TlsContextManagerImpl( TlsContextManagerImpl(
ValueFactory<UpstreamTlsContext, SslContextProvider> clientFactory, ValueFactory<UpstreamTlsContext, SslContextProvider> clientFactory,
ValueFactory<DownstreamTlsContext, SslContextProvider> serverFactory, ValueFactory<DownstreamTlsContext, SslContextProvider> serverFactory) {
boolean certInstanceOverride) {
checkNotNull(clientFactory, "clientFactory"); checkNotNull(clientFactory, "clientFactory");
checkNotNull(serverFactory, "serverFactory"); checkNotNull(serverFactory, "serverFactory");
mapForClients = new ReferenceCountingMap<>(clientFactory); mapForClients = new ReferenceCountingMap<>(clientFactory);
mapForServers = new ReferenceCountingMap<>(serverFactory); mapForServers = new ReferenceCountingMap<>(serverFactory);
this.hasCertInstanceOverride = certInstanceOverride;
} }
/** Gets the TlsContextManagerImpl singleton. */ /** Gets the TlsContextManagerImpl singleton. */
@ -76,7 +69,6 @@ public final class TlsContextManagerImpl implements TlsContextManager {
DownstreamTlsContext downstreamTlsContext) { DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext"); checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext.Builder builder = downstreamTlsContext.getCommonTlsContext().toBuilder(); CommonTlsContext.Builder builder = downstreamTlsContext.getCommonTlsContext().toBuilder();
builder = performCertInstanceOverride(builder);
downstreamTlsContext = downstreamTlsContext =
new DownstreamTlsContext( new DownstreamTlsContext(
builder.build(), downstreamTlsContext.isRequireClientCertificate()); builder.build(), downstreamTlsContext.isRequireClientCertificate());
@ -88,38 +80,10 @@ public final class TlsContextManagerImpl implements TlsContextManager {
UpstreamTlsContext upstreamTlsContext) { UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext"); checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder(); CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder();
builder = performCertInstanceOverride(builder);
upstreamTlsContext = new UpstreamTlsContext(builder.build()); upstreamTlsContext = new UpstreamTlsContext(builder.build());
return mapForClients.get(upstreamTlsContext); return mapForClients.get(upstreamTlsContext);
} }
@VisibleForTesting
CommonTlsContext.Builder performCertInstanceOverride(CommonTlsContext.Builder builder) {
if (hasCertInstanceOverride) {
if (builder.getTlsCertificateSdsSecretConfigsCount() > 0) {
builder.setTlsCertificateCertificateProviderInstance(
CommonTlsContext.CertificateProviderInstance.newBuilder()
.setInstanceName(GOOGLE_CLOUD_PRIVATE_SPIFFE));
}
if (builder.hasCombinedValidationContext()) {
CommonTlsContext.CombinedCertificateValidationContext.Builder ccvcBuilder =
builder.getCombinedValidationContextBuilder();
if (ccvcBuilder.hasValidationContextSdsSecretConfig()) {
ccvcBuilder =
ccvcBuilder.setValidationContextCertificateProviderInstance(
CommonTlsContext.CertificateProviderInstance.newBuilder()
.setInstanceName(GOOGLE_CLOUD_PRIVATE_SPIFFE));
builder.setCombinedValidationContext(ccvcBuilder);
}
} else if (builder.hasValidationContextSdsSecretConfig()) {
builder.setValidationContextCertificateProviderInstance(
CommonTlsContext.CertificateProviderInstance.newBuilder()
.setInstanceName(GOOGLE_CLOUD_PRIVATE_SPIFFE));
}
}
return builder;
}
@Override @Override
public SslContextProvider releaseClientSslContextProvider( public SslContextProvider releaseClientSslContextProvider(
SslContextProvider clientSslContextProvider) { SslContextProvider clientSslContextProvider) {

View File

@ -30,10 +30,6 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.envoyproxy.envoy.config.core.v3.DataSource;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.SdsSecretConfig;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
@ -143,7 +139,7 @@ public class TlsContextManagerTest {
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory, false); new TlsContextManagerImpl(mockClientFactory, mockServerFactory);
SslContextProvider mockProvider = mock(SslContextProvider.class); SslContextProvider mockProvider = mock(SslContextProvider.class);
when(mockServerFactory.create(downstreamTlsContext)).thenReturn(mockProvider); when(mockServerFactory.create(downstreamTlsContext)).thenReturn(mockProvider);
SslContextProvider serverSecretProvider = SslContextProvider serverSecretProvider =
@ -162,7 +158,7 @@ public class TlsContextManagerTest {
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory, false); new TlsContextManagerImpl(mockClientFactory, mockServerFactory);
SslContextProvider mockProvider = mock(SslContextProvider.class); SslContextProvider mockProvider = mock(SslContextProvider.class);
when(mockClientFactory.create(upstreamTlsContext)).thenReturn(mockProvider); when(mockClientFactory.create(upstreamTlsContext)).thenReturn(mockProvider);
SslContextProvider clientSecretProvider = SslContextProvider clientSecretProvider =
@ -173,71 +169,4 @@ public class TlsContextManagerTest {
tlsContextManagerImpl.releaseClientSslContextProvider(mockProvider); tlsContextManagerImpl.releaseClientSslContextProvider(mockProvider);
verify(mockProvider, times(1)).close(); verify(mockProvider, times(1)).close();
} }
@Test
public void certInstanceOverrideForTlsCert() {
TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory, true);
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForTlsCertificate(
/* name= */ "name", /* targetUri= */ "unix:/tmp/sds/path", CA_PEM_FILE);
CommonTlsContext.Builder origBuilder = commonTlsContext.toBuilder();
CommonTlsContext.Builder modBuilder =
tlsContextManagerImpl.performCertInstanceOverride(origBuilder);
assertThat(modBuilder.hasValidationContextCertificateProviderInstance()).isFalse();
assertThat(modBuilder.hasCombinedValidationContext()).isFalse();
assertThat(modBuilder.hasTlsCertificateCertificateProviderInstance()).isTrue();
CommonTlsContext.CertificateProviderInstance instance =
modBuilder.getTlsCertificateCertificateProviderInstance();
assertThat(instance.getInstanceName()).isEqualTo("google_cloud_private_spiffe");
}
@Test
public void certInstanceOverrideForValidationContext() {
TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory, true);
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForValidationContext(
/* name= */ "name",
/* targetUri= */ "unix:/tmp/sds/path",
CLIENT_KEY_FILE,
CLIENT_PEM_FILE);
CommonTlsContext.Builder origBuilder = commonTlsContext.toBuilder();
CommonTlsContext.Builder modBuilder =
tlsContextManagerImpl.performCertInstanceOverride(origBuilder);
assertThat(modBuilder.hasTlsCertificateCertificateProviderInstance()).isFalse();
assertThat(modBuilder.hasCombinedValidationContext()).isFalse();
assertThat(modBuilder.hasValidationContextCertificateProviderInstance()).isTrue();
CommonTlsContext.CertificateProviderInstance instance =
modBuilder.getValidationContextCertificateProviderInstance();
assertThat(instance.getInstanceName()).isEqualTo("google_cloud_private_spiffe");
}
@Test
public void certInstanceOverrideForCombinedValidationContext() {
TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory, true);
CertificateValidationContext staticCertContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename("/tmp/a.pem"))
.build();
CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
builder =
CommonTlsContextTestsUtil.addCertificateValidationContext(
builder, "name", /* targetUri= */ "unix:/tmp/sds/path", "uds", staticCertContext);
CommonTlsContext.Builder modBuilder =
tlsContextManagerImpl.performCertInstanceOverride(builder);
assertThat(modBuilder.hasTlsCertificateCertificateProviderInstance()).isFalse();
assertThat(modBuilder.hasCombinedValidationContext()).isTrue();
assertThat(modBuilder.hasValidationContextCertificateProviderInstance()).isFalse();
CommonTlsContext.CombinedCertificateValidationContext combined =
modBuilder.getCombinedValidationContext();
CommonTlsContext.CertificateProviderInstance instance =
combined.getValidationContextCertificateProviderInstance();
assertThat(instance.getInstanceName()).isEqualTo("google_cloud_private_spiffe");
SdsSecretConfig validationContextSdsConfig = combined.getValidationContextSdsSecretConfig();
assertThat(validationContextSdsConfig.getName()).isEqualTo("name");
}
} }