From 39c49b04080739fdb3f4b45a72efb80863a580ec Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Mon, 17 Aug 2020 09:45:13 -0700 Subject: [PATCH] xds: add CertProviderSslContextProvider support (#7309) --- .../main/java/io/grpc/xds/Bootstrapper.java | 7 +- .../main/java/io/grpc/xds/EnvoyProtoData.java | 3 +- .../CertProviderClientSslContextProvider.java | 123 ++++++++ .../CertProviderSslContextProvider.java | 155 +++++++++ .../CertificateProviderStore.java | 2 +- .../sds/DynamicSslContextProvider.java | 143 +++++++++ .../sds/SdsClientSslContextProvider.java | 2 +- .../internal/sds/SdsProtocolNegotiators.java | 12 +- .../sds/SdsServerSslContextProvider.java | 2 +- .../internal/sds/SdsSslContextProvider.java | 113 +------ .../SecretVolumeClientSslContextProvider.java | 8 +- .../SecretVolumeServerSslContextProvider.java | 8 +- .../xds/internal/sds/SslContextProvider.java | 25 +- .../sds/trust/SdsTrustManagerFactory.java | 28 +- .../sds/trust/SdsX509TrustManager.java | 27 +- ...tProviderClientSslContextProviderTest.java | 293 ++++++++++++++++++ .../CertificateProviderStoreTest.java | 37 +-- .../CommonCertProviderTestUtils.java | 177 +++++++++++ .../certprovider/TestCertificateProvider.java | 84 +++++ .../sds/CommonTlsContextTestsUtil.java | 123 ++++++++ .../sds/SdsSslContextProviderTest.java | 28 +- .../SecretVolumeSslContextProviderTest.java | 47 +-- .../sds/trust/SdsTrustManagerFactoryTest.java | 104 +++++++ 23 files changed, 1301 insertions(+), 250 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/sds/DynamicSslContextProvider.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/certprovider/CommonCertProviderTestUtils.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/certprovider/TestCertificateProvider.java diff --git a/xds/src/main/java/io/grpc/xds/Bootstrapper.java b/xds/src/main/java/io/grpc/xds/Bootstrapper.java index c2a29ab10c..a0bacd4190 100644 --- a/xds/src/main/java/io/grpc/xds/Bootstrapper.java +++ b/xds/src/main/java/io/grpc/xds/Bootstrapper.java @@ -76,9 +76,10 @@ public abstract class Bootstrapper { */ public abstract BootstrapInfo readBootstrap() throws IOException; + /** Parses a raw string into {@link BootstrapInfo}. */ @VisibleForTesting @SuppressWarnings("deprecation") - static BootstrapInfo parseConfig(String rawData) throws IOException { + public static BootstrapInfo parseConfig(String rawData) throws IOException { XdsLogger logger = XdsLogger.withPrefix(LOG_PREFIX); logger.log(XdsLogLevel.INFO, "Reading bootstrap information"); @SuppressWarnings("unchecked") @@ -264,11 +265,11 @@ public abstract class Bootstrapper { this.config = checkNotNull(config, "config"); } - String getPluginName() { + public String getPluginName() { return pluginName; } - Map getConfig() { + public Map getConfig() { return config; } } diff --git a/xds/src/main/java/io/grpc/xds/EnvoyProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyProtoData.java index 14a62f29f5..cbb73423a1 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyProtoData.java @@ -325,7 +325,8 @@ final class EnvoyProtoData { return listeningAddresses; } - io.envoyproxy.envoy.config.core.v3.Node toEnvoyProtoNode() { + @VisibleForTesting + public io.envoyproxy.envoy.config.core.v3.Node toEnvoyProtoNode() { io.envoyproxy.envoy.config.core.v3.Node.Builder builder = io.envoyproxy.envoy.config.core.v3.Node.newBuilder(); builder.setId(id); diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java new file mode 100644 index 0000000000..562f27431c --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java @@ -0,0 +1,123 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import io.envoyproxy.envoy.config.core.v3.Node; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.xds.Bootstrapper.CertificateProviderInfo; +import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory; +import io.netty.handler.ssl.SslContextBuilder; +import java.security.cert.CertStoreException; +import java.security.cert.X509Certificate; +import java.util.Map; + +/** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ +final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { + + private CertProviderClientSslContextProvider( + Node node, + Map certProviders, + CommonTlsContext.CertificateProviderInstance certInstance, + CommonTlsContext.CertificateProviderInstance rootCertInstance, + CertificateValidationContext staticCertValidationContext, + UpstreamTlsContext upstreamTlsContext, + CertificateProviderStore certificateProviderStore) { + super( + node, + certProviders, + certInstance, + checkNotNull(rootCertInstance, "Client SSL requires rootCertInstance"), + staticCertValidationContext, + upstreamTlsContext, + certificateProviderStore); + } + + @Override + protected final SslContextBuilder getSslContextBuilder( + CertificateValidationContext certificateValidationContextdationContext) + throws CertStoreException { + SslContextBuilder sslContextBuilder = + GrpcSslContexts.forClient() + .trustManager( + new SdsTrustManagerFactory( + savedTrustedRoots.toArray(new X509Certificate[0]), + certificateValidationContextdationContext)); + if (isMtls()) { + sslContextBuilder.keyManager(savedKey, savedCertChain); + } + return sslContextBuilder; + } + + /** Creates CertProviderClientSslContextProvider. */ + static final class Factory { + private static final Factory DEFAULT_INSTANCE = + new Factory(CertificateProviderStore.getInstance()); + private final CertificateProviderStore certificateProviderStore; + + @VisibleForTesting Factory(CertificateProviderStore certificateProviderStore) { + this.certificateProviderStore = certificateProviderStore; + } + + static Factory getInstance() { + return DEFAULT_INSTANCE; + } + + CertProviderClientSslContextProvider getProvider( + UpstreamTlsContext upstreamTlsContext, + Node node, + Map certProviders) { + checkNotNull(upstreamTlsContext, "upstreamTlsContext"); + CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); + CommonTlsContext.CertificateProviderInstance rootCertInstance = null; + CertificateValidationContext staticCertValidationContext = null; + if (commonTlsContext.hasCombinedValidationContext()) { + CombinedCertificateValidationContext combinedValidationContext = + commonTlsContext.getCombinedValidationContext(); + if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { + rootCertInstance = + combinedValidationContext.getValidationContextCertificateProviderInstance(); + } + if (combinedValidationContext.hasDefaultValidationContext()) { + staticCertValidationContext = combinedValidationContext.getDefaultValidationContext(); + } + } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { + rootCertInstance = commonTlsContext.getValidationContextCertificateProviderInstance(); + } else if (commonTlsContext.hasValidationContext()) { + staticCertValidationContext = commonTlsContext.getValidationContext(); + } + CommonTlsContext.CertificateProviderInstance certInstance = null; + if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { + certInstance = commonTlsContext.getTlsCertificateCertificateProviderInstance(); + } + return new CertProviderClientSslContextProvider( + node, + certProviders, + certInstance, + rootCertInstance, + staticCertValidationContext, + upstreamTlsContext, + certificateProviderStore); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java new file mode 100644 index 0000000000..5f03e3becc --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java @@ -0,0 +1,155 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import io.envoyproxy.envoy.config.core.v3.Node; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; +import io.grpc.xds.Bootstrapper.CertificateProviderInfo; +import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; +import io.grpc.xds.internal.sds.DynamicSslContextProvider; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; + +/** Base class for {@link CertProviderClientSslContextProvider}. */ +abstract class CertProviderSslContextProvider extends DynamicSslContextProvider implements + CertificateProvider.Watcher { + + @Nullable private final CertificateProviderStore.Handle certHandle; + @Nullable private final CertificateProviderStore.Handle rootCertHandle; + @Nullable private final CertificateProviderInstance certInstance; + @Nullable private final CertificateProviderInstance rootCertInstance; + @Nullable protected PrivateKey savedKey; + @Nullable protected List savedCertChain; + @Nullable protected List savedTrustedRoots; + + protected CertProviderSslContextProvider( + Node node, + Map certProviders, + CertificateProviderInstance certInstance, + CertificateProviderInstance rootCertInstance, + CertificateValidationContext staticCertValidationContext, + BaseTlsContext tlsContext, + CertificateProviderStore certificateProviderStore) { + super(tlsContext, staticCertValidationContext); + this.certInstance = certInstance; + this.rootCertInstance = rootCertInstance; + String certInstanceName = null; + if (certInstance != null && certInstance.isInitialized()) { + certInstanceName = certInstance.getInstanceName(); + CertificateProviderInfo certProviderInstanceConfig = + getCertProviderConfig(certProviders, certInstanceName); + certHandle = + certificateProviderStore.createOrGetProvider( + certInstance.getCertificateName(), + certProviderInstanceConfig.getPluginName(), + certProviderInstanceConfig.getConfig(), + this, + true); + } else { + certHandle = null; + } + if (rootCertInstance != null + && rootCertInstance.isInitialized() + && !rootCertInstance.getInstanceName().equals(certInstanceName)) { + CertificateProviderInfo certProviderInstanceConfig = + getCertProviderConfig(certProviders, rootCertInstance.getInstanceName()); + rootCertHandle = + certificateProviderStore.createOrGetProvider( + rootCertInstance.getCertificateName(), + certProviderInstanceConfig.getPluginName(), + certProviderInstanceConfig.getConfig(), + this, + true); + } else { + rootCertHandle = null; + } + } + + private CertificateProviderInfo getCertProviderConfig( + Map certProviders, String pluginInstanceName) { + return certProviders.get(pluginInstanceName); + } + + @Override + public final void updateCertificate(PrivateKey key, List certChain) { + savedKey = key; + savedCertChain = certChain; + updateSslContextWhenReady(); + } + + @Override + public final void updateTrustedRoots(List trustedRoots) { + savedTrustedRoots = trustedRoots; + updateSslContextWhenReady(); + } + + private void updateSslContextWhenReady() { + if (isMtls()) { + if (savedKey != null && savedTrustedRoots != null) { + updateSslContext(); + clearKeysAndCerts(); + } + } else if (isClientSideTls()) { + if (savedTrustedRoots != null) { + updateSslContext(); + clearKeysAndCerts(); + } + } else if (isServerSideTls()) { + if (savedKey != null) { + updateSslContext(); + clearKeysAndCerts(); + } + } + } + + private void clearKeysAndCerts() { + savedKey = null; + savedTrustedRoots = null; + savedCertChain = null; + } + + protected final boolean isMtls() { + return certInstance != null && rootCertInstance != null; + } + + protected final boolean isClientSideTls() { + return rootCertInstance != null && certInstance == null; + } + + protected final boolean isServerSideTls() { + return certInstance != null && rootCertInstance == null; + } + + @Override + protected final CertificateValidationContext generateCertificateValidationContext() { + return staticCertificateValidationContext; + } + + @Override + public final void close() { + if (certHandle != null) { + certHandle.close(); + } + if (rootCertHandle != null) { + rootCertHandle.close(); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java index 5b09a58127..18f8a11ec2 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java @@ -129,7 +129,7 @@ public final class CertificateProviderStore { CertificateProviderProvider certProviderProvider = certificateProviderRegistry.getProvider(key.pluginName); if (certProviderProvider == null) { - throw new IllegalArgumentException("Provider not found."); + throw new IllegalArgumentException("Provider not found for " + key.pluginName); } CertificateProvider certProvider = certProviderProvider.createCertificateProvider( key.config, new CertificateProvider.DistributorWatcher(), key.notifyCertUpdates); diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/DynamicSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/DynamicSslContextProvider.java new file mode 100644 index 0000000000..7f40b822f0 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/sds/DynamicSslContextProvider.java @@ -0,0 +1,143 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.sds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.grpc.Status; +import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; +import io.netty.handler.ssl.ApplicationProtocolConfig; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.IOException; +import java.security.cert.CertStoreException; +import java.security.cert.CertificateException; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; + +/** Base class for dynamic {@link SslContextProvider}s. */ +public abstract class DynamicSslContextProvider extends SslContextProvider { + + protected final List pendingCallbacks = new ArrayList<>(); + @Nullable protected final CertificateValidationContext staticCertificateValidationContext; + @Nullable protected SslContext sslContext; + + protected DynamicSslContextProvider( + BaseTlsContext tlsContext, CertificateValidationContext staticCertValidationContext) { + super(tlsContext); + this.staticCertificateValidationContext = staticCertValidationContext; + } + + @Nullable + public SslContext getSslContext() { + return sslContext; + } + + protected abstract CertificateValidationContext generateCertificateValidationContext(); + + /** Gets a server or client side SslContextBuilder. */ + protected abstract SslContextBuilder getSslContextBuilder( + CertificateValidationContext certificateValidationContext) + throws CertificateException, IOException, CertStoreException; + + // this gets called only when requested secrets are ready... + protected final void updateSslContext() { + try { + CertificateValidationContext localCertValidationContext = + generateCertificateValidationContext(); + SslContextBuilder sslContextBuilder = getSslContextBuilder(localCertValidationContext); + CommonTlsContext commonTlsContext = getCommonTlsContext(); + if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) { + List alpnList = commonTlsContext.getAlpnProtocolsList(); + ApplicationProtocolConfig apn = + new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + alpnList); + sslContextBuilder.applicationProtocolConfig(apn); + } + List pendingCallbacksCopy = null; + SslContext sslContextCopy = null; + synchronized (pendingCallbacks) { + sslContext = sslContextBuilder.build(); + sslContextCopy = sslContext; + pendingCallbacksCopy = clonePendingCallbacksAndClear(); + } + makePendingCallbacks(sslContextCopy, pendingCallbacksCopy); + } catch (Exception e) { + onError(Status.fromThrowable(e)); + throw new RuntimeException(e); + } + } + + protected final void callPerformCallback( + Callback callback, final SslContext sslContextCopy) { + performCallback( + new SslContextGetter() { + @Override + public SslContext get() { + return sslContextCopy; + } + }, + callback + ); + } + + @Override + public final void addCallback(Callback callback) { + checkNotNull(callback, "callback"); + // if there is a computed sslContext just send it + SslContext sslContextCopy = null; + synchronized (pendingCallbacks) { + if (sslContext != null) { + sslContextCopy = sslContext; + } else { + pendingCallbacks.add(callback); + } + } + if (sslContextCopy != null) { + callPerformCallback(callback, sslContextCopy); + } + } + + private final void makePendingCallbacks( + SslContext sslContextCopy, List pendingCallbacksCopy) { + for (Callback callback : pendingCallbacksCopy) { + callPerformCallback(callback, sslContextCopy); + } + } + + /** Propagates error to all the callback receivers. */ + public final void onError(Status error) { + for (Callback callback : clonePendingCallbacksAndClear()) { + callback.onException(error.asException()); + } + } + + private List clonePendingCallbacksAndClear() { + synchronized (pendingCallbacks) { + List copy = ImmutableList.copyOf(pendingCallbacks); + pendingCallbacks.clear(); + return copy; + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsClientSslContextProvider.java index 943d205fbf..eef2faf2e2 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsClientSslContextProvider.java @@ -90,7 +90,7 @@ final class SdsClientSslContextProvider extends SdsSslContextProvider { } @Override - SslContextBuilder getSslContextBuilder( + protected final SslContextBuilder getSslContextBuilder( CertificateValidationContext localCertValidationContext) throws CertificateException, IOException, CertStoreException { SslContextBuilder sslContextBuilder = diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java index 412d11ec6e..d157c79683 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java @@ -197,7 +197,7 @@ public final class SdsProtocolNegotiators { .findOrCreateClientSslContextProvider(upstreamTlsContext); sslContextProvider.addCallback( - new SslContextProvider.Callback() { + new SslContextProvider.Callback(ctx.executor()) { @Override public void updateSecret(SslContext sslContext) { @@ -220,8 +220,8 @@ public final class SdsProtocolNegotiators { public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } - }, - ctx.executor()); + } + ); } @Override @@ -370,7 +370,7 @@ public final class SdsProtocolNegotiators { } final SslContextProvider sslContextProvider = sslContextProviderTemp; sslContextProvider.addCallback( - new SslContextProvider.Callback() { + new SslContextProvider.Callback(ctx.executor()) { @Override public void updateSecret(SslContext sslContext) { @@ -389,8 +389,8 @@ public final class SdsProtocolNegotiators { public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } - }, - ctx.executor()); + } + ); } } } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsServerSslContextProvider.java index 7d31f8f5c8..27afaa455e 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsServerSslContextProvider.java @@ -75,7 +75,7 @@ final class SdsServerSslContextProvider extends SdsSslContextProvider { } @Override - SslContextBuilder getSslContextBuilder( + protected final SslContextBuilder getSslContextBuilder( CertificateValidationContext localCertValidationContext) throws CertificateException, IOException, CertStoreException { SslContextBuilder sslContextBuilder = diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java index 26ff694cf5..b70c2f1e5c 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java @@ -21,27 +21,18 @@ import static com.google.common.base.Preconditions.checkState; import io.envoyproxy.envoy.api.v2.core.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.SdsSecretConfig; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.Secret; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; -import io.grpc.Status; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; -import io.netty.handler.ssl.ApplicationProtocolConfig; -import io.netty.handler.ssl.SslContext; -import io.netty.handler.ssl.SslContextBuilder; -import java.io.IOException; -import java.security.cert.CertStoreException; -import java.security.cert.CertificateException; -import java.util.ArrayList; -import java.util.List; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; /** Base class for SdsClientSslContextProvider and SdsServerSslContextProvider. */ -abstract class SdsSslContextProvider extends SslContextProvider implements SdsClient.SecretWatcher { +abstract class SdsSslContextProvider extends DynamicSslContextProvider implements + SdsClient.SecretWatcher { private static final Logger logger = Logger.getLogger(SdsSslContextProvider.class.getName()); @@ -49,13 +40,10 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl @Nullable private final SdsClient validationContextSdsClient; @Nullable private final SdsSecretConfig certSdsConfig; @Nullable private final SdsSecretConfig validationContextSdsConfig; - @Nullable private final CertificateValidationContext staticCertificateValidationContext; - private final List pendingCallbacks = new ArrayList<>(); @Nullable protected TlsCertificate tlsCertificate; @Nullable private CertificateValidationContext certificateValidationContext; - @Nullable private SslContext sslContext; - SdsSslContextProvider( + protected SdsSslContextProvider( Node node, SdsSecretConfig certSdsConfig, SdsSecretConfig validationContextSdsConfig, @@ -63,10 +51,9 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl Executor watcherExecutor, Executor channelExecutor, BaseTlsContext tlsContext) { - super(tlsContext); + super(tlsContext, staticCertValidationContext); this.certSdsConfig = certSdsConfig; this.validationContextSdsConfig = validationContextSdsConfig; - this.staticCertificateValidationContext = staticCertValidationContext; if (certSdsConfig != null && certSdsConfig.isInitialized()) { certSdsClient = SdsClient.Factory.createSdsClient(certSdsConfig, node, watcherExecutor, channelExecutor); @@ -87,35 +74,7 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl } @Override - public void addCallback(Callback callback, Executor executor) { - checkNotNull(callback, "callback"); - checkNotNull(executor, "executor"); - // if there is a computed sslContext just send it - SslContext sslContextCopy = sslContext; - if (sslContextCopy != null) { - callPerformCallback(callback, executor, sslContextCopy); - } else { - synchronized (pendingCallbacks) { - pendingCallbacks.add(new CallbackPair(callback, executor)); - } - } - } - - private void callPerformCallback( - Callback callback, Executor executor, final SslContext sslContextCopy) { - performCallback( - new SslContextGetter() { - @Override - public SslContext get() { - return sslContextCopy; - } - }, - callback, - executor); - } - - @Override - public synchronized void onSecretChanged(Secret secretUpdate) { + public final synchronized void onSecretChanged(Secret secretUpdate) { checkNotNull(secretUpdate); if (secretUpdate.hasTlsCertificate()) { checkState( @@ -143,35 +102,8 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl } } - /** Gets a server or client side SslContextBuilder. */ - abstract SslContextBuilder getSslContextBuilder( - CertificateValidationContext localCertValidationContext) - throws CertificateException, IOException, CertStoreException; - - // this gets called only when requested secrets are ready... - private void updateSslContext() { - try { - CertificateValidationContext localCertValidationContext = mergeStaticAndDynamicCertContexts(); - SslContextBuilder sslContextBuilder = getSslContextBuilder(localCertValidationContext); - CommonTlsContext commonTlsContext = getCommonTlsContext(); - if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) { - List alpnList = commonTlsContext.getAlpnProtocolsList(); - ApplicationProtocolConfig apn = new ApplicationProtocolConfig( - ApplicationProtocolConfig.Protocol.ALPN, - ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, - ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, - alpnList); - sslContextBuilder.applicationProtocolConfig(apn); - } - SslContext sslContextCopy = sslContextBuilder.build(); - sslContext = sslContextCopy; - makePendingCallbacks(sslContextCopy); - } catch (CertificateException | IOException | CertStoreException e) { - logger.log(Level.SEVERE, "exception in updateSslContext", e); - } - } - - private CertificateValidationContext mergeStaticAndDynamicCertContexts() { + @Override + protected final CertificateValidationContext generateCertificateValidationContext() { if (staticCertificateValidationContext == null) { return certificateValidationContext; } @@ -183,27 +115,8 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl return localCertContextBuilder.mergeFrom(staticCertificateValidationContext).build(); } - private void makePendingCallbacks(SslContext sslContextCopy) { - synchronized (pendingCallbacks) { - for (CallbackPair pair : pendingCallbacks) { - callPerformCallback(pair.callback, pair.executor, sslContextCopy); - } - pendingCallbacks.clear(); - } - } - @Override - public void onError(Status error) { - synchronized (pendingCallbacks) { - for (CallbackPair callbackPair : pendingCallbacks) { - callbackPair.callback.onException(error.asException()); - } - pendingCallbacks.clear(); - } - } - - @Override - public void close() { + public final void close() { if (certSdsClient != null) { certSdsClient.cancelSecretWatch(this); certSdsClient.shutdown(); @@ -213,14 +126,4 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl validationContextSdsClient.shutdown(); } } - - private static class CallbackPair { - private final Callback callback; - private final Executor executor; - - private CallbackPair(Callback callback, Executor executor) { - this.callback = callback; - this.executor = executor; - } - } } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java index f2ea0546c8..fc3befd5f1 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeClientSslContextProvider.java @@ -34,7 +34,6 @@ import java.io.File; import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; -import java.util.concurrent.Executor; import javax.annotation.Nullable; /** A client SslContext provider that uses file-based secrets (secret volume). */ @@ -92,9 +91,8 @@ final class SecretVolumeClientSslContextProvider extends SslContextProvider { } @Override - public void addCallback(final Callback callback, Executor executor) { + public void addCallback(final Callback callback) { checkNotNull(callback, "callback"); - checkNotNull(executor, "executor"); // as per the contract we will read the current secrets on disk // this involves I/O which can potentially block the executor performCallback( @@ -104,8 +102,8 @@ final class SecretVolumeClientSslContextProvider extends SslContextProvider { return buildSslContextFromSecrets(); } }, - callback, - executor); + callback + ); } @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java index 05afb1f4cb..3282fc555c 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeServerSslContextProvider.java @@ -33,7 +33,6 @@ import java.io.File; import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; -import java.util.concurrent.Executor; import javax.annotation.Nullable; /** A server SslContext provider that uses file-based secrets (secret volume). */ @@ -85,9 +84,8 @@ final class SecretVolumeServerSslContextProvider extends SslContextProvider { } @Override - public void addCallback(final Callback callback, Executor executor) { + public void addCallback(final Callback callback) { checkNotNull(callback, "callback"); - checkNotNull(executor, "executor"); // as per the contract we will read the current secrets on disk // this involves I/O which can potentially block the executor performCallback( @@ -97,8 +95,8 @@ final class SecretVolumeServerSslContextProvider extends SslContextProvider { return buildSslContextFromSecrets(); } }, - callback, - executor); + callback + ); } @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java index 6c73b7c737..08e93d3a4e 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java @@ -47,19 +47,25 @@ public abstract class SslContextProvider implements Closeable { protected final BaseTlsContext tlsContext; - public interface Callback { + abstract static class Callback { + private final Executor executor; + + protected Callback(Executor executor) { + this.executor = executor; + } + /** Informs callee of new/updated SslContext. */ - void updateSecret(SslContext sslContext); + abstract void updateSecret(SslContext sslContext); /** Informs callee of an exception that was generated. */ - void onException(Throwable throwable); + abstract void onException(Throwable throwable); } - SslContextProvider(BaseTlsContext tlsContext) { + protected SslContextProvider(BaseTlsContext tlsContext) { this.tlsContext = checkNotNull(tlsContext, "tlsContext"); } - CommonTlsContext getCommonTlsContext() { + protected CommonTlsContext getCommonTlsContext() { return tlsContext.getCommonTlsContext(); } @@ -100,14 +106,13 @@ public abstract class SslContextProvider implements Closeable { * Registers a callback on the given executor. The callback will run when SslContext becomes * available or immediately if the result is already available. */ - public abstract void addCallback(Callback callback, Executor executor); + public abstract void addCallback(Callback callback); - final void performCallback( - final SslContextGetter sslContextGetter, final Callback callback, Executor executor) { + protected final void performCallback( + final SslContextGetter sslContextGetter, final Callback callback) { checkNotNull(sslContextGetter, "sslContextGetter"); checkNotNull(callback, "callback"); - checkNotNull(executor, "executor"); - executor.execute( + callback.executor.execute( new Runnable() { @Override public void run() { diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java index 6ff63c0e5f..479569f159 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java @@ -16,7 +16,7 @@ package io.grpc.xds.internal.sds.trust; -import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; @@ -53,9 +53,29 @@ public final class SdsTrustManagerFactory extends SimpleTrustManagerFactory { /** Constructor constructs from a {@link CertificateValidationContext}. */ public SdsTrustManagerFactory(CertificateValidationContext certificateValidationContext) throws CertificateException, IOException, CertStoreException { - checkNotNull(certificateValidationContext, "certificateValidationContext"); - sdsX509TrustManager = createSdsX509TrustManager( - getTrustedCaFromCertContext(certificateValidationContext), certificateValidationContext); + this( + getTrustedCaFromCertContext(certificateValidationContext), + certificateValidationContext, + false); + } + + public SdsTrustManagerFactory( + X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext) + throws CertStoreException { + this(certs, staticCertificateValidationContext, true); + } + + private SdsTrustManagerFactory( + X509Certificate[] certs, + CertificateValidationContext certificateValidationContext, + boolean validationContextIsStatic) + throws CertStoreException { + if (validationContextIsStatic) { + checkArgument( + certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(), + "only static certificateValidationContext expected"); + } + sdsX509TrustManager = createSdsX509TrustManager(certs, certificateValidationContext); } private static X509Certificate[] getTrustedCaFromCertContext( diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java index bb23a59bb7..6b7324ddf7 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java @@ -59,7 +59,8 @@ final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509 } // Copied from OkHostnameVerifier.verifyHostName(). - private static boolean verifyDnsNameInPattern(String pattern, String sanToVerify) { + private static boolean verifyDnsNameInPattern(String pattern, StringMatcher sanToVerifyMatcher) { + String sanToVerify = sanToVerifyMatcher.getExact(); // Basic sanity checks // Check length == 0 instead of .isEmpty() to support Java 5. if (sanToVerify == null @@ -150,9 +151,9 @@ final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509 // sanToVerify matches pattern } - private static boolean verifyDnsNameInSanList(String altNameFromCert, - List verifySanList) { - for (String verifySan : verifySanList) { + private static boolean verifyDnsNameInSanList( + String altNameFromCert, List verifySanList) { + for (StringMatcher verifySan : verifySanList) { if (verifyDnsNameInPattern(altNameFromCert, verifySan)) { return true; } @@ -168,16 +169,17 @@ final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509 * @param verifySanList list of SANs from certificate context * @return true if there is a match */ - private static boolean verifyStringInSanList(String stringFromCert, List verifySanList) { - for (String sanToVerify : verifySanList) { - if (Ascii.equalsIgnoreCase(sanToVerify, stringFromCert)) { + private static boolean verifyStringInSanList( + String stringFromCert, List verifySanList) { + for (StringMatcher sanToVerify : verifySanList) { + if (Ascii.equalsIgnoreCase(sanToVerify.getExact(), stringFromCert)) { return true; } } return false; } - private static boolean verifyOneSanInList(List entry, List verifySanList) + private static boolean verifyOneSanInList(List entry, List verifySanList) throws CertificateParsingException { // from OkHostnameVerifier.getSubjectAltNames if (entry == null || entry.size() < 2) { @@ -200,9 +202,8 @@ final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509 } // logic from Envoy::Extensions::TransportSockets::Tls::ContextImpl::verifySubjectAltName - @SuppressWarnings("UnusedMethod") // TODO(#7166): support StringMatcher list. - private static void verifySubjectAltNameInLeaf(X509Certificate cert, List verifyList) - throws CertificateException { + private static void verifySubjectAltNameInLeaf( + X509Certificate cert, List verifyList) throws CertificateException { Collection> names = cert.getSubjectAlternativeNames(); if (names == null || names.isEmpty()) { throw new CertificateException("Peer certificate SAN check failed"); @@ -233,9 +234,7 @@ final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509 throw new CertificateException("Peer certificate(s) missing"); } // verify SANs only in the top cert (leaf cert) - // v2 version: verifySubjectAltNameInLeaf(peerCertChain[0], verifyList); - // TODO(#7166): Implement v3 version. - throw new UnsupportedOperationException(); + verifySubjectAltNameInLeaf(peerCertChain[0], verifyList); } @Override diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java new file mode 100644 index 0000000000..6e9ae15ee5 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java @@ -0,0 +1,293 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils.getCertFromResourceName; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.doChecksOnSslContext; +import static org.junit.Assert.fail; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.MoreExecutors; +import io.envoyproxy.envoy.config.core.v3.DataSource; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.grpc.xds.Bootstrapper; +import io.grpc.xds.EnvoyServerProtoData; +import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link CertProviderClientSslContextProvider}. */ +@RunWith(JUnit4.class) +public class CertProviderClientSslContextProviderTest { + private static final Logger logger = + Logger.getLogger(CertProviderClientSslContextProviderTest.class.getName()); + + CertificateProviderRegistry certificateProviderRegistry; + CertificateProviderStore certificateProviderStore; + private CertProviderClientSslContextProvider.Factory certProviderClientSslContextProviderFactory; + + @Before + public void setUp() throws Exception { + certificateProviderRegistry = new CertificateProviderRegistry(); + certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry); + certProviderClientSslContextProviderFactory = + new CertProviderClientSslContextProvider.Factory(certificateProviderStore); + } + + /** Helper method to build CertProviderClientSslContextProvider. */ + private CertProviderClientSslContextProvider getSslContextProvider( + String certInstanceName, + String rootInstanceName, + Bootstrapper.BootstrapInfo bootstrapInfo, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + return certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNode(), + bootstrapInfo.getCertProviders()); + } + + @Test + public void testProviderForClient_mtls() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + "gcp_id", + "gcp_id", + CommonCertProviderTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNotNull(); + assertThat(provider.savedCertChain).isNotNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate root cert update + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // just do root cert update: sslContext should still be the same + watcherCaptor[0].updateTrustedRoots( + ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e.different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + + @Test + public void testProviderForClient_queueExecutor() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + "gcp_id", + "gcp_id", + CommonCertProviderTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + QueuedExecutor queuedExecutor = new QueuedExecutor(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider, queuedExecutor); + assertThat(queuedExecutor.runQueue).isEmpty(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(queuedExecutor.runQueue).isEmpty(); // still empty + + // now generate root cert update + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(queuedExecutor.runQueue).hasSize(1); + queuedExecutor.drain(); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + } + + @Test + public void testProviderForClient_tls() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + /* certInstanceName= */ null, + "gcp_id", + CommonCertProviderTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate root cert update + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + } + + @Test + public void testProviderForClient_sslContextException_onError() throws Exception { + CertificateValidationContext staticCertValidationContext = + CertificateValidationContext.newBuilder() + .setTrustedCa(DataSource.newBuilder().setInlineString("foo")) + .build(); + + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + /* certInstanceName= */ null, + "gcp_id", + CommonCertProviderTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */null, + staticCertValidationContext); + + TestCallback testCallback = new TestCallback(MoreExecutors.directExecutor()); + provider.addCallback(testCallback); + try { + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + fail("exception expected"); + } catch (RuntimeException expected) { + assertThat(expected) + .hasMessageThat() + .contains("only static certificateValidationContext expected"); + } + assertThat(testCallback.updatedThrowable).isNotNull(); + assertThat(testCallback.updatedThrowable) + .hasCauseThat() + .hasMessageThat() + .contains("only static certificateValidationContext expected"); + } + + @Test + public void testProviderForClient_rootInstanceNull_expectError() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + try { + getSslContextProvider( + /* certInstanceName= */ null, + /* rootInstanceName= */ null, + CommonCertProviderTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + fail("exception expected"); + } catch (NullPointerException expected) { + assertThat(expected).hasMessageThat().contains("Client SSL requires rootCertInstance"); + } + } + + static class QueuedExecutor implements Executor { + /** A list of Runnables to be run in order. */ + private final Queue runQueue = new ConcurrentLinkedQueue<>(); + + @Override + public synchronized void execute(Runnable r) { + runQueue.add(checkNotNull(r, "'r' must not be null.")); + } + + public synchronized void drain() { + Runnable r; + while ((r = runQueue.poll()) != null) { + try { + r.run(); + } catch (RuntimeException e) { + // Log it and keep going. + logger.log(Level.SEVERE, "Exception while executing runnable " + r, e); + } + } + } + } + +} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java index 569d72bf43..53144c2d48 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java @@ -49,36 +49,6 @@ public class CertificateProviderStoreTest { private CertificateProviderStore certificateProviderStore; private boolean throwExceptionForCertUpdates; - private class TestCertificateProvider extends CertificateProvider { - Object config; - CertificateProviderProvider certProviderProvider; - int closeCalled = 0; - int startCalled = 0; - - protected TestCertificateProvider( - CertificateProvider.DistributorWatcher watcher, - boolean notifyCertUpdates, - Object config, - CertificateProviderProvider certificateProviderProvider) { - super(watcher, notifyCertUpdates); - if (throwExceptionForCertUpdates && notifyCertUpdates) { - throw new UnsupportedOperationException("Provider does not support Certificate Updates."); - } - this.config = config; - this.certProviderProvider = certificateProviderProvider; - } - - @Override - public void close() { - closeCalled++; - } - - @Override - public void start() { - startCalled++; - } - } - @Before public void setUp() { certificateProviderRegistry = new CertificateProviderRegistry(); @@ -94,7 +64,7 @@ public class CertificateProviderStoreTest { "cert-name1", "plugin1", "config", mockWatcher, true); fail("exception expected"); } catch (IllegalArgumentException expected) { - assertThat(expected).hasMessageThat().isEqualTo("Provider not found."); + assertThat(expected).hasMessageThat().isEqualTo("Provider not found for plugin1"); } } @@ -111,7 +81,7 @@ public class CertificateProviderStoreTest { "cert-name1", "plugin1", "config", mockWatcher, true); fail("exception expected"); } catch (IllegalArgumentException expected) { - assertThat(expected).hasMessageThat().isEqualTo("Provider not found."); + assertThat(expected).hasMessageThat().isEqualTo("Provider not found for plugin1"); } } @@ -369,7 +339,8 @@ public class CertificateProviderStoreTest { (CertificateProvider.DistributorWatcher) args[1]; boolean notifyCertUpdates = (Boolean) args[2]; return new TestCertificateProvider( - watcher, notifyCertUpdates, config, certProviderProvider); + watcher, notifyCertUpdates, config, certProviderProvider, + throwExceptionForCertUpdates); } }); certificateProviderRegistry.register(certProviderProvider); diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CommonCertProviderTestUtils.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CommonCertProviderTestUtils.java new file mode 100644 index 0000000000..3a056c209f --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CommonCertProviderTestUtils.java @@ -0,0 +1,177 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.io.CharStreams; +import io.grpc.internal.testing.TestUtils; +import io.grpc.xds.Bootstrapper; +import io.grpc.xds.internal.sds.trust.CertificateUtils; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; +import io.netty.util.CharsetUtil; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.Reader; +import java.security.KeyException; +import java.security.KeyFactory; +import java.security.PrivateKey; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class CommonCertProviderTestUtils { + private static final Logger logger = + Logger.getLogger(CommonCertProviderTestUtils.class.getName()); + + private static final Pattern KEY_PATTERN = Pattern.compile( + "-+BEGIN\\s+.*PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+" + // Header + "([a-z0-9+/=\\r\\n]+)" + // Base64 text + "-+END\\s+.*PRIVATE\\s+KEY[^-]*-+", // Footer + Pattern.CASE_INSENSITIVE); + + static Bootstrapper.BootstrapInfo getTestBootstrapInfo() throws IOException { + String rawData = + "{\n" + + " \"xds_servers\": [],\n" + + " \"certificate_providers\": {\n" + + " \"gcp_id\": {\n" + + " \"plugin_name\": \"testca\",\n" + + " \"config\": {\n" + + " \"server\": {\n" + + " \"api_type\": \"GRPC\",\n" + + " \"grpc_services\": [{\n" + + " \"google_grpc\": {\n" + + " \"target_uri\": \"meshca.com\",\n" + + " \"channel_credentials\": {\"google_default\": {}},\n" + + " \"call_credentials\": [{\n" + + " \"sts_service\": {\n" + + " \"token_exchange_service\": \"securetoken.googleapis.com\",\n" + + " \"subject_token_path\": \"/etc/secret/sajwt.token\"\n" + + " }\n" + + " }]\n" // end call_credentials + + " },\n" // end google_grpc + + " \"time_out\": {\"seconds\": 10}\n" + + " }]\n" // end grpc_services + + " },\n" // end server + + " \"certificate_lifetime\": {\"seconds\": 86400},\n" + + " \"renewal_grace_period\": {\"seconds\": 3600},\n" + + " \"key_type\": \"RSA\",\n" + + " \"key_size\": 2048,\n" + + " \"location\": \"https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3\"\n" + + " }\n" // end config + + " },\n" // end gcp_id + + " \"file_provider\": {\n" + + " \"plugin_name\": \"file_watcher\",\n" + + " \"config\": {\"path\": \"/etc/secret/certs\"}\n" + + " }\n" + + " }\n" + + "}"; + return Bootstrapper.parseConfig(rawData); + } + + static PrivateKey getPrivateKey(String resourceName) + throws Exception { + InputStream inputStream = TestUtils.class.getResourceAsStream("/certs/" + resourceName); + ByteBuf encodedKeyBuf = readPrivateKey(inputStream); + + byte[] encodedKey = new byte[encodedKeyBuf.readableBytes()]; + encodedKeyBuf.readBytes(encodedKey).release(); + PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(encodedKey); + try { + return KeyFactory.getInstance("RSA").generatePrivate(spec); + } catch (InvalidKeySpecException ignore) { + try { + return KeyFactory.getInstance("DSA").generatePrivate(spec); + } catch (InvalidKeySpecException ignore2) { + try { + return KeyFactory.getInstance("EC").generatePrivate(spec); + } catch (InvalidKeySpecException e) { + throw new InvalidKeySpecException("Neither RSA, DSA nor EC worked", e); + } + } + } + } + + static ByteBuf readPrivateKey(InputStream in) throws KeyException { + String content; + try { + content = readContent(in); + } catch (IOException e) { + throw new KeyException("failed to read key input stream", e); + } + Matcher m = KEY_PATTERN.matcher(content); + if (!m.find()) { + throw new KeyException("could not find a PKCS #8 private key in input stream"); + } + ByteBuf base64 = Unpooled.copiedBuffer(m.group(1), CharsetUtil.US_ASCII); + ByteBuf der = Base64.decode(base64); + base64.release(); + return der; + } + + private static String readContent(InputStream in) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try { + byte[] buf = new byte[8192]; + for (; ; ) { + int ret = in.read(buf); + if (ret < 0) { + break; + } + out.write(buf, 0, ret); + } + return out.toString(CharsetUtil.US_ASCII.name()); + } finally { + safeClose(out); + } + } + + private static void safeClose(OutputStream out) { + try { + out.close(); + } catch (IOException e) { + logger.log(Level.WARNING, "Failed to close a stream.", e); + } + } + + static X509Certificate getCertFromResourceName(String resourceName) + throws IOException, CertificateException { + return CertificateUtils.toX509Certificate( + new ByteArrayInputStream(getResourceContents(resourceName).getBytes(UTF_8))); + } + + private static String getResourceContents(String resourceName) throws IOException { + InputStream inputStream = TestUtils.class.getResourceAsStream("/certs/" + resourceName); + String text = null; + try (Reader reader = new InputStreamReader(inputStream, UTF_8)) { + text = CharStreams.toString(reader); + } + return text; + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/TestCertificateProvider.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/TestCertificateProvider.java new file mode 100644 index 0000000000..406ae4f0bb --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/TestCertificateProvider.java @@ -0,0 +1,84 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +public class TestCertificateProvider extends CertificateProvider { + Object config; + CertificateProviderProvider certProviderProvider; + int closeCalled = 0; + int startCalled = 0; + + TestCertificateProvider( + DistributorWatcher watcher, + boolean notifyCertUpdates, + Object config, + CertificateProviderProvider certificateProviderProvider, + boolean throwExceptionForCertUpdates) { + super(watcher, notifyCertUpdates); + if (throwExceptionForCertUpdates && notifyCertUpdates) { + throw new UnsupportedOperationException("Provider does not support Certificate Updates."); + } + this.config = config; + this.certProviderProvider = certificateProviderProvider; + } + + @Override + public void close() { + closeCalled++; + } + + @Override + public void start() { + startCalled++; + } + + static void createAndRegisterProviderProvider( + CertificateProviderRegistry certificateProviderRegistry, + final CertificateProvider.DistributorWatcher[] watcherCaptor, + String testca, + final int index) { + final CertificateProviderProvider mockProviderProviderTestCa = + new TestCertificateProviderProvider(testca, watcherCaptor, index); + certificateProviderRegistry.register(mockProviderProviderTestCa); + } + + private static class TestCertificateProviderProvider implements CertificateProviderProvider { + + private final String testCa; + private final CertificateProvider.DistributorWatcher[] watcherCaptor; + private final int index; + + TestCertificateProviderProvider( + String testCa, CertificateProvider.DistributorWatcher[] watcherCaptor, int index) { + this.testCa = testCa; + this.watcherCaptor = watcherCaptor; + this.index = index; + } + + @Override + public String getName() { + return testCa; + } + + @Override + public CertificateProvider createCertificateProvider( + Object config, DistributorWatcher watcher, boolean notifyCertUpdates) { + watcherCaptor[index] = watcher; + return new TestCertificateProvider(watcher, true, config, this, false); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java index afa57f96c4..5f769ae950 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java @@ -16,10 +16,12 @@ package io.grpc.xds.internal.sds; +import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.Strings; import com.google.common.io.CharStreams; +import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.BoolValue; import com.google.protobuf.Struct; import com.google.protobuf.Value; @@ -40,6 +42,7 @@ import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.internal.testing.TestUtils; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.internal.sds.trust.CertificateUtils; +import io.netty.handler.ssl.SslContext; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; @@ -48,6 +51,8 @@ import java.io.Reader; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Arrays; +import java.util.List; +import java.util.concurrent.Executor; import javax.annotation.Nullable; /** Utility class for client and server ssl provider tests. */ @@ -461,4 +466,122 @@ public class CommonTlsContextTestsUtil { } return text; } + + private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( + String certInstanceName, + String certName, + String rootInstanceName, + String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); + if (certInstanceName != null) { + builder = + builder.setTlsCertificateCertificateProviderInstance( + CommonTlsContext.CertificateProviderInstance.newBuilder() + .setInstanceName(certInstanceName) + .setCertificateName(certName)); + } + builder = + addCertificateValidationContext( + builder, rootInstanceName, rootCertName, staticCertValidationContext); + if (alpnProtocols != null) { + builder.addAllAlpnProtocols(alpnProtocols); + } + return builder.build(); + } + + private static CommonTlsContext.Builder addCertificateValidationContext( + CommonTlsContext.Builder builder, + String rootInstanceName, + String rootCertName, + CertificateValidationContext staticCertValidationContext) { + if (rootInstanceName != null) { + CommonTlsContext.CertificateProviderInstance.Builder providerInstanceBuilder = + CommonTlsContext.CertificateProviderInstance.newBuilder() + .setInstanceName(rootInstanceName) + .setCertificateName(rootCertName); + if (staticCertValidationContext != null) { + CombinedCertificateValidationContext combined = + CombinedCertificateValidationContext.newBuilder() + .setDefaultValidationContext(staticCertValidationContext) + .setValidationContextCertificateProviderInstance(providerInstanceBuilder) + .build(); + return builder.setCombinedValidationContext(combined); + } + builder = builder.setValidationContextCertificateProviderInstance(providerInstanceBuilder); + } + return builder; + } + + /** Helper method to build UpstreamTlsContext for CertProvider tests. */ + public static EnvoyServerProtoData.UpstreamTlsContext + buildUpstreamTlsContextForCertProviderInstance( + @Nullable String certInstanceName, + @Nullable String certName, + @Nullable String rootInstanceName, + @Nullable String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + return buildUpstreamTlsContext( + buildCommonTlsContextForCertProviderInstance( + certInstanceName, + certName, + rootInstanceName, + rootCertName, + alpnProtocols, + staticCertValidationContext)); + } + + /** Perform some simple checks on sslContext. */ + public static void doChecksOnSslContext(boolean server, SslContext sslContext, + List expectedApnProtos) { + if (server) { + assertThat(sslContext.isServer()).isTrue(); + } else { + assertThat(sslContext.isClient()).isTrue(); + } + List apnProtos = sslContext.applicationProtocolNegotiator().protocols(); + assertThat(apnProtos).isNotNull(); + if (expectedApnProtos != null) { + assertThat(apnProtos).isEqualTo(expectedApnProtos); + } else { + assertThat(apnProtos).contains("h2"); + } + } + + /** + * Helper method to get the value thru directExecutor callback. Because of directExecutor this is + * a synchronous callback - so need to provide a listener. + */ + public static TestCallback getValueThruCallback(SslContextProvider provider) { + return getValueThruCallback(provider, MoreExecutors.directExecutor()); + } + + /** Helper method to get the value thru callback with a user passed executor. */ + public static TestCallback getValueThruCallback(SslContextProvider provider, Executor executor) { + TestCallback testCallback = new TestCallback(executor); + provider.addCallback(testCallback); + return testCallback; + } + + public static class TestCallback extends SslContextProvider.Callback { + + public SslContext updatedSslContext; + public Throwable updatedThrowable; + + public TestCallback(Executor executor) { + super(executor); + } + + @Override + public void updateSecret(SslContext sslContext) { + updatedSslContext = sslContext; + } + + @Override + public void onException(Throwable throwable) { + updatedThrowable = throwable; + } + } } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java index 62779ec0a9..6ecb533e28 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java @@ -22,9 +22,10 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.doChecksOnSslContext; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.getValueThruCallback; import static io.grpc.xds.internal.sds.SdsClientTest.getOneCertificateValidationContextSecret; import static io.grpc.xds.internal.sds.SdsClientTest.getOneTlsCertSecret; -import static io.grpc.xds.internal.sds.SecretVolumeSslContextProviderTest.doChecksOnSslContext; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -33,6 +34,7 @@ import io.envoyproxy.envoy.api.v2.core.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.Status.Code; +import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback; import java.io.IOException; import java.util.Arrays; import org.junit.After; @@ -123,8 +125,7 @@ public class SdsSslContextProviderTest { SdsServerSslContextProvider provider = getSdsServerSslContextProvider("cert1", "valid1", null, null); - SecretVolumeSslContextProviderTest.TestCallback testCallback = - SecretVolumeSslContextProviderTest.getValueThruCallback(provider); + TestCallback testCallback = getValueThruCallback(provider); doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null); } @@ -142,8 +143,7 @@ public class SdsSslContextProviderTest { /* validationContextName= */ "valid1", /* matchSubjectAltNames= */ null, /* alpnProtocols= */ null); - SecretVolumeSslContextProviderTest.TestCallback testCallback = - SecretVolumeSslContextProviderTest.getValueThruCallback(provider); + TestCallback testCallback = getValueThruCallback(provider); doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); } @@ -159,8 +159,7 @@ public class SdsSslContextProviderTest { /* validationContextName= */ null, /* matchSubjectAltNames= */ null, /* alpnProtocols= */ null); - SecretVolumeSslContextProviderTest.TestCallback testCallback = - SecretVolumeSslContextProviderTest.getValueThruCallback(provider); + TestCallback testCallback = getValueThruCallback(provider); doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null); } @@ -176,8 +175,7 @@ public class SdsSslContextProviderTest { /* validationContextName= */ "valid1", /* matchSubjectAltNames= */ null, null); - SecretVolumeSslContextProviderTest.TestCallback testCallback = - SecretVolumeSslContextProviderTest.getValueThruCallback(provider); + TestCallback testCallback = getValueThruCallback(provider); doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); } @@ -193,8 +191,7 @@ public class SdsSslContextProviderTest { /* validationContextName= */ "valid1", /* matchSubjectAltNames= */ null, /* alpnProtocols= */ null); - SecretVolumeSslContextProviderTest.TestCallback testCallback = - SecretVolumeSslContextProviderTest.getValueThruCallback(provider); + TestCallback testCallback = getValueThruCallback(provider); assertThat(server.lastNack).isNotNull(); assertThat(server.lastNack.getVersionInfo()).isEmpty(); @@ -222,8 +219,7 @@ public class SdsSslContextProviderTest { .build()), /* alpnProtocols= */ null); - SecretVolumeSslContextProviderTest.TestCallback testCallback = - SecretVolumeSslContextProviderTest.getValueThruCallback(provider); + TestCallback testCallback = getValueThruCallback(provider); doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); } @@ -240,8 +236,7 @@ public class SdsSslContextProviderTest { /* validationContextName= */ "valid1", /* matchSubjectAltNames= */ null, /* alpnProtocols= */ Arrays.asList("managed-mtls", "h2")); - SecretVolumeSslContextProviderTest.TestCallback testCallback = - SecretVolumeSslContextProviderTest.getValueThruCallback(provider); + TestCallback testCallback = getValueThruCallback(provider); doChecksOnSslContext( false, testCallback.updatedSslContext, Arrays.asList("managed-mtls", "h2")); @@ -260,8 +255,7 @@ public class SdsSslContextProviderTest { /* validationContextName= */ "valid1", /* matchSubjectAltNames= */ null, /* alpnProtocols= */ Arrays.asList("managed-mtls", "h2")); - SecretVolumeSslContextProviderTest.TestCallback testCallback = - SecretVolumeSslContextProviderTest.getValueThruCallback(provider); + TestCallback testCallback = getValueThruCallback(provider); doChecksOnSslContext( true, testCallback.updatedSslContext, Arrays.asList("managed-mtls", "h2")); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java index 551b8a7f6d..44a5c461dc 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java @@ -22,16 +22,17 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.doChecksOnSslContext; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.getValueThruCallback; -import com.google.common.util.concurrent.MoreExecutors; import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; +import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback; import io.netty.handler.ssl.SslContext; import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; -import java.util.List; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -371,22 +372,6 @@ public class SecretVolumeSslContextProviderTest { doChecksOnSslContext(server, sslContext, /* expectedApnProtos= */ null); } - static void doChecksOnSslContext(boolean server, SslContext sslContext, - List expectedApnProtos) { - if (server) { - assertThat(sslContext.isServer()).isTrue(); - } else { - assertThat(sslContext.isClient()).isTrue(); - } - List apnProtos = sslContext.applicationProtocolNegotiator().protocols(); - assertThat(apnProtos).isNotNull(); - if (expectedApnProtos != null) { - assertThat(apnProtos).isEqualTo(expectedApnProtos); - } else { - assertThat(apnProtos).contains("h2"); - } - } - @Test public void getProviderForServer() throws IOException, CertificateException, CertStoreException { sslContextForEitherWithBothCertAndTrust( @@ -421,32 +406,6 @@ public class SecretVolumeSslContextProviderTest { } } - static class TestCallback implements SslContextProvider.Callback { - - SslContext updatedSslContext; - Throwable updatedThrowable; - - @Override - public void updateSecret(SslContext sslContext) { - updatedSslContext = sslContext; - } - - @Override - public void onException(Throwable throwable) { - updatedThrowable = throwable; - } - } - - /** - * Helper method to get the value thru directExecutor callback. Because of directExecutor this is - * a synchronous callback - so need to provide a listener. - */ - static TestCallback getValueThruCallback(SslContextProvider provider) { - TestCallback testCallback = new TestCallback(); - provider.addCallback(testCallback, MoreExecutors.directExecutor()); - return testCallback; - } - @Test public void getProviderForServer_both_callsback() throws IOException { SecretVolumeServerSslContextProvider provider = diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactoryTest.java index 53ab963eb1..47ac9e6bb4 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactoryTest.java @@ -26,6 +26,7 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FI import com.google.protobuf.ByteString; import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.internal.testing.TestUtils; import java.io.IOException; import java.security.cert.CertStoreException; @@ -80,6 +81,100 @@ public class SdsTrustManagerFactoryTest { .isEqualTo(CertificateUtils.toX509Certificates(TestUtils.loadCert(CA_PEM_FILE))[0]); } + @Test + public void constructor_fromRootCert() + throws CertificateException, IOException, CertStoreException { + X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); + CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", + "san2"); + SdsTrustManagerFactory factory = + new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + assertThat(factory).isNotNull(); + TrustManager[] tms = factory.getTrustManagers(); + assertThat(tms).isNotNull(); + assertThat(tms).hasLength(1); + TrustManager myTm = tms[0]; + assertThat(myTm).isInstanceOf(SdsX509TrustManager.class); + SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) myTm; + X509Certificate[] acceptedIssuers = sdsX509TrustManager.getAcceptedIssuers(); + assertThat(acceptedIssuers).isNotNull(); + assertThat(acceptedIssuers).hasLength(1); + X509Certificate caCert = acceptedIssuers[0]; + assertThat(caCert) + .isEqualTo(CertificateUtils.toX509Certificates(TestUtils.loadCert(CA_PEM_FILE))[0]); + } + + @Test + public void constructorRootCert_checkServerTrusted() + throws CertificateException, IOException, CertStoreException { + X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); + CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", + "waterzooi.test.google.be"); + SdsTrustManagerFactory factory = + new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0]; + X509Certificate[] serverChain = + CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); + sdsX509TrustManager.checkServerTrusted(serverChain, "RSA"); + } + + @Test + public void constructorRootCert_nonStaticContext_throwsException() + throws CertificateException, IOException, CertStoreException { + X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); + try { + new SdsTrustManagerFactory( + new X509Certificate[] {x509Cert}, getCertContextFromPath(CA_PEM_FILE)); + Assert.fail("no exception thrown"); + } catch (IllegalArgumentException expected) { + assertThat(expected) + .hasMessageThat() + .contains("only static certificateValidationContext expected"); + } + } + + @Test + public void constructorRootCert_checkServerTrusted_throwsException() + throws CertificateException, IOException, CertStoreException { + X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); + CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", + "san2"); + SdsTrustManagerFactory factory = + new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0]; + X509Certificate[] serverChain = + CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); + try { + sdsX509TrustManager.checkServerTrusted(serverChain, "RSA"); + Assert.fail("no exception thrown"); + } catch (CertificateException expected) { + assertThat(expected) + .hasMessageThat() + .contains("Peer certificate SAN check failed"); + } + } + + @Test + public void constructorRootCert_checkClientTrusted_throwsException() + throws CertificateException, IOException, CertStoreException { + X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); + CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", + "san2"); + SdsTrustManagerFactory factory = + new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0]; + X509Certificate[] clientChain = + CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); + try { + sdsX509TrustManager.checkClientTrusted(clientChain, "RSA"); + Assert.fail("no exception thrown"); + } catch (CertificateException expected) { + assertThat(expected) + .hasMessageThat() + .contains("Peer certificate SAN check failed"); + } + } + @Test public void checkServerTrusted_goodCert() throws CertificateException, IOException, CertStoreException { @@ -156,4 +251,13 @@ public class SdsTrustManagerFactoryTest { DataSource.newBuilder().setInlineBytes(ByteString.copyFrom(x509Cert.getEncoded()))) .build(); } + + private static final CertificateValidationContext buildStaticValidationContext( + String... verifySans) { + CertificateValidationContext.Builder builder = CertificateValidationContext.newBuilder(); + for (String san : verifySans) { + builder.addMatchSubjectAltNames(StringMatcher.newBuilder().setExact(san)); + } + return builder.build(); + } }