xds: add CertProviderSslContextProvider support (#7309)

This commit is contained in:
sanjaypujare 2020-08-17 09:45:13 -07:00 committed by GitHub
parent 1c269e4289
commit 39c49b0408
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1301 additions and 250 deletions

View File

@ -76,9 +76,10 @@ public abstract class Bootstrapper {
*/
public abstract BootstrapInfo readBootstrap() throws IOException;
/** Parses a raw string into {@link BootstrapInfo}. */
@VisibleForTesting
@SuppressWarnings("deprecation")
static BootstrapInfo parseConfig(String rawData) throws IOException {
public static BootstrapInfo parseConfig(String rawData) throws IOException {
XdsLogger logger = XdsLogger.withPrefix(LOG_PREFIX);
logger.log(XdsLogLevel.INFO, "Reading bootstrap information");
@SuppressWarnings("unchecked")
@ -264,11 +265,11 @@ public abstract class Bootstrapper {
this.config = checkNotNull(config, "config");
}
String getPluginName() {
public String getPluginName() {
return pluginName;
}
Map<String, ?> getConfig() {
public Map<String, ?> getConfig() {
return config;
}
}

View File

@ -325,7 +325,8 @@ final class EnvoyProtoData {
return listeningAddresses;
}
io.envoyproxy.envoy.config.core.v3.Node toEnvoyProtoNode() {
@VisibleForTesting
public io.envoyproxy.envoy.config.core.v3.Node toEnvoyProtoNode() {
io.envoyproxy.envoy.config.core.v3.Node.Builder builder =
io.envoyproxy.envoy.config.core.v3.Node.newBuilder();
builder.setId(id);

View File

@ -0,0 +1,123 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.config.core.v3.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.Bootstrapper.CertificateProviderInfo;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory;
import io.netty.handler.ssl.SslContextBuilder;
import java.security.cert.CertStoreException;
import java.security.cert.X509Certificate;
import java.util.Map;
/** A client SslContext provider using CertificateProviderInstance to fetch secrets. */
final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider {
private CertProviderClientSslContextProvider(
Node node,
Map<String, CertificateProviderInfo> certProviders,
CommonTlsContext.CertificateProviderInstance certInstance,
CommonTlsContext.CertificateProviderInstance rootCertInstance,
CertificateValidationContext staticCertValidationContext,
UpstreamTlsContext upstreamTlsContext,
CertificateProviderStore certificateProviderStore) {
super(
node,
certProviders,
certInstance,
checkNotNull(rootCertInstance, "Client SSL requires rootCertInstance"),
staticCertValidationContext,
upstreamTlsContext,
certificateProviderStore);
}
@Override
protected final SslContextBuilder getSslContextBuilder(
CertificateValidationContext certificateValidationContextdationContext)
throws CertStoreException {
SslContextBuilder sslContextBuilder =
GrpcSslContexts.forClient()
.trustManager(
new SdsTrustManagerFactory(
savedTrustedRoots.toArray(new X509Certificate[0]),
certificateValidationContextdationContext));
if (isMtls()) {
sslContextBuilder.keyManager(savedKey, savedCertChain);
}
return sslContextBuilder;
}
/** Creates CertProviderClientSslContextProvider. */
static final class Factory {
private static final Factory DEFAULT_INSTANCE =
new Factory(CertificateProviderStore.getInstance());
private final CertificateProviderStore certificateProviderStore;
@VisibleForTesting Factory(CertificateProviderStore certificateProviderStore) {
this.certificateProviderStore = certificateProviderStore;
}
static Factory getInstance() {
return DEFAULT_INSTANCE;
}
CertProviderClientSslContextProvider getProvider(
UpstreamTlsContext upstreamTlsContext,
Node node,
Map<String, CertificateProviderInfo> certProviders) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
CommonTlsContext.CertificateProviderInstance rootCertInstance = null;
CertificateValidationContext staticCertValidationContext = null;
if (commonTlsContext.hasCombinedValidationContext()) {
CombinedCertificateValidationContext combinedValidationContext =
commonTlsContext.getCombinedValidationContext();
if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) {
rootCertInstance =
combinedValidationContext.getValidationContextCertificateProviderInstance();
}
if (combinedValidationContext.hasDefaultValidationContext()) {
staticCertValidationContext = combinedValidationContext.getDefaultValidationContext();
}
} else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) {
rootCertInstance = commonTlsContext.getValidationContextCertificateProviderInstance();
} else if (commonTlsContext.hasValidationContext()) {
staticCertValidationContext = commonTlsContext.getValidationContext();
}
CommonTlsContext.CertificateProviderInstance certInstance = null;
if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) {
certInstance = commonTlsContext.getTlsCertificateCertificateProviderInstance();
}
return new CertProviderClientSslContextProvider(
node,
certProviders,
certInstance,
rootCertInstance,
staticCertValidationContext,
upstreamTlsContext,
certificateProviderStore);
}
}
}

View File

@ -0,0 +1,155 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import io.envoyproxy.envoy.config.core.v3.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance;
import io.grpc.xds.Bootstrapper.CertificateProviderInfo;
import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext;
import io.grpc.xds.internal.sds.DynamicSslContextProvider;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
/** Base class for {@link CertProviderClientSslContextProvider}. */
abstract class CertProviderSslContextProvider extends DynamicSslContextProvider implements
CertificateProvider.Watcher {
@Nullable private final CertificateProviderStore.Handle certHandle;
@Nullable private final CertificateProviderStore.Handle rootCertHandle;
@Nullable private final CertificateProviderInstance certInstance;
@Nullable private final CertificateProviderInstance rootCertInstance;
@Nullable protected PrivateKey savedKey;
@Nullable protected List<X509Certificate> savedCertChain;
@Nullable protected List<X509Certificate> savedTrustedRoots;
protected CertProviderSslContextProvider(
Node node,
Map<String, CertificateProviderInfo> certProviders,
CertificateProviderInstance certInstance,
CertificateProviderInstance rootCertInstance,
CertificateValidationContext staticCertValidationContext,
BaseTlsContext tlsContext,
CertificateProviderStore certificateProviderStore) {
super(tlsContext, staticCertValidationContext);
this.certInstance = certInstance;
this.rootCertInstance = rootCertInstance;
String certInstanceName = null;
if (certInstance != null && certInstance.isInitialized()) {
certInstanceName = certInstance.getInstanceName();
CertificateProviderInfo certProviderInstanceConfig =
getCertProviderConfig(certProviders, certInstanceName);
certHandle =
certificateProviderStore.createOrGetProvider(
certInstance.getCertificateName(),
certProviderInstanceConfig.getPluginName(),
certProviderInstanceConfig.getConfig(),
this,
true);
} else {
certHandle = null;
}
if (rootCertInstance != null
&& rootCertInstance.isInitialized()
&& !rootCertInstance.getInstanceName().equals(certInstanceName)) {
CertificateProviderInfo certProviderInstanceConfig =
getCertProviderConfig(certProviders, rootCertInstance.getInstanceName());
rootCertHandle =
certificateProviderStore.createOrGetProvider(
rootCertInstance.getCertificateName(),
certProviderInstanceConfig.getPluginName(),
certProviderInstanceConfig.getConfig(),
this,
true);
} else {
rootCertHandle = null;
}
}
private CertificateProviderInfo getCertProviderConfig(
Map<String, CertificateProviderInfo> certProviders, String pluginInstanceName) {
return certProviders.get(pluginInstanceName);
}
@Override
public final void updateCertificate(PrivateKey key, List<X509Certificate> certChain) {
savedKey = key;
savedCertChain = certChain;
updateSslContextWhenReady();
}
@Override
public final void updateTrustedRoots(List<X509Certificate> trustedRoots) {
savedTrustedRoots = trustedRoots;
updateSslContextWhenReady();
}
private void updateSslContextWhenReady() {
if (isMtls()) {
if (savedKey != null && savedTrustedRoots != null) {
updateSslContext();
clearKeysAndCerts();
}
} else if (isClientSideTls()) {
if (savedTrustedRoots != null) {
updateSslContext();
clearKeysAndCerts();
}
} else if (isServerSideTls()) {
if (savedKey != null) {
updateSslContext();
clearKeysAndCerts();
}
}
}
private void clearKeysAndCerts() {
savedKey = null;
savedTrustedRoots = null;
savedCertChain = null;
}
protected final boolean isMtls() {
return certInstance != null && rootCertInstance != null;
}
protected final boolean isClientSideTls() {
return rootCertInstance != null && certInstance == null;
}
protected final boolean isServerSideTls() {
return certInstance != null && rootCertInstance == null;
}
@Override
protected final CertificateValidationContext generateCertificateValidationContext() {
return staticCertificateValidationContext;
}
@Override
public final void close() {
if (certHandle != null) {
certHandle.close();
}
if (rootCertHandle != null) {
rootCertHandle.close();
}
}
}

View File

@ -129,7 +129,7 @@ public final class CertificateProviderStore {
CertificateProviderProvider certProviderProvider =
certificateProviderRegistry.getProvider(key.pluginName);
if (certProviderProvider == null) {
throw new IllegalArgumentException("Provider not found.");
throw new IllegalArgumentException("Provider not found for " + key.pluginName);
}
CertificateProvider certProvider = certProviderProvider.createCertificateProvider(
key.config, new CertificateProvider.DistributorWatcher(), key.notifyCertUpdates);

View File

@ -0,0 +1,143 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.collect.ImmutableList;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.grpc.Status;
import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext;
import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;
/** Base class for dynamic {@link SslContextProvider}s. */
public abstract class DynamicSslContextProvider extends SslContextProvider {
protected final List<Callback> pendingCallbacks = new ArrayList<>();
@Nullable protected final CertificateValidationContext staticCertificateValidationContext;
@Nullable protected SslContext sslContext;
protected DynamicSslContextProvider(
BaseTlsContext tlsContext, CertificateValidationContext staticCertValidationContext) {
super(tlsContext);
this.staticCertificateValidationContext = staticCertValidationContext;
}
@Nullable
public SslContext getSslContext() {
return sslContext;
}
protected abstract CertificateValidationContext generateCertificateValidationContext();
/** Gets a server or client side SslContextBuilder. */
protected abstract SslContextBuilder getSslContextBuilder(
CertificateValidationContext certificateValidationContext)
throws CertificateException, IOException, CertStoreException;
// this gets called only when requested secrets are ready...
protected final void updateSslContext() {
try {
CertificateValidationContext localCertValidationContext =
generateCertificateValidationContext();
SslContextBuilder sslContextBuilder = getSslContextBuilder(localCertValidationContext);
CommonTlsContext commonTlsContext = getCommonTlsContext();
if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) {
List<String> alpnList = commonTlsContext.getAlpnProtocolsList();
ApplicationProtocolConfig apn =
new ApplicationProtocolConfig(
ApplicationProtocolConfig.Protocol.ALPN,
ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
alpnList);
sslContextBuilder.applicationProtocolConfig(apn);
}
List<Callback> pendingCallbacksCopy = null;
SslContext sslContextCopy = null;
synchronized (pendingCallbacks) {
sslContext = sslContextBuilder.build();
sslContextCopy = sslContext;
pendingCallbacksCopy = clonePendingCallbacksAndClear();
}
makePendingCallbacks(sslContextCopy, pendingCallbacksCopy);
} catch (Exception e) {
onError(Status.fromThrowable(e));
throw new RuntimeException(e);
}
}
protected final void callPerformCallback(
Callback callback, final SslContext sslContextCopy) {
performCallback(
new SslContextGetter() {
@Override
public SslContext get() {
return sslContextCopy;
}
},
callback
);
}
@Override
public final void addCallback(Callback callback) {
checkNotNull(callback, "callback");
// if there is a computed sslContext just send it
SslContext sslContextCopy = null;
synchronized (pendingCallbacks) {
if (sslContext != null) {
sslContextCopy = sslContext;
} else {
pendingCallbacks.add(callback);
}
}
if (sslContextCopy != null) {
callPerformCallback(callback, sslContextCopy);
}
}
private final void makePendingCallbacks(
SslContext sslContextCopy, List<Callback> pendingCallbacksCopy) {
for (Callback callback : pendingCallbacksCopy) {
callPerformCallback(callback, sslContextCopy);
}
}
/** Propagates error to all the callback receivers. */
public final void onError(Status error) {
for (Callback callback : clonePendingCallbacksAndClear()) {
callback.onException(error.asException());
}
}
private List<Callback> clonePendingCallbacksAndClear() {
synchronized (pendingCallbacks) {
List<Callback> copy = ImmutableList.copyOf(pendingCallbacks);
pendingCallbacks.clear();
return copy;
}
}
}

View File

@ -90,7 +90,7 @@ final class SdsClientSslContextProvider extends SdsSslContextProvider {
}
@Override
SslContextBuilder getSslContextBuilder(
protected final SslContextBuilder getSslContextBuilder(
CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException {
SslContextBuilder sslContextBuilder =

View File

@ -197,7 +197,7 @@ public final class SdsProtocolNegotiators {
.findOrCreateClientSslContextProvider(upstreamTlsContext);
sslContextProvider.addCallback(
new SslContextProvider.Callback() {
new SslContextProvider.Callback(ctx.executor()) {
@Override
public void updateSecret(SslContext sslContext) {
@ -220,8 +220,8 @@ public final class SdsProtocolNegotiators {
public void onException(Throwable throwable) {
ctx.fireExceptionCaught(throwable);
}
},
ctx.executor());
}
);
}
@Override
@ -370,7 +370,7 @@ public final class SdsProtocolNegotiators {
}
final SslContextProvider sslContextProvider = sslContextProviderTemp;
sslContextProvider.addCallback(
new SslContextProvider.Callback() {
new SslContextProvider.Callback(ctx.executor()) {
@Override
public void updateSecret(SslContext sslContext) {
@ -389,8 +389,8 @@ public final class SdsProtocolNegotiators {
public void onException(Throwable throwable) {
ctx.fireExceptionCaught(throwable);
}
},
ctx.executor());
}
);
}
}
}

View File

@ -75,7 +75,7 @@ final class SdsServerSslContextProvider extends SdsSslContextProvider {
}
@Override
SslContextBuilder getSslContextBuilder(
protected final SslContextBuilder getSslContextBuilder(
CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException {
SslContextBuilder sslContextBuilder =

View File

@ -21,27 +21,18 @@ import static com.google.common.base.Preconditions.checkState;
import io.envoyproxy.envoy.api.v2.core.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.SdsSecretConfig;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.Secret;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate;
import io.grpc.Status;
import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext;
import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
/** Base class for SdsClientSslContextProvider and SdsServerSslContextProvider. */
abstract class SdsSslContextProvider extends SslContextProvider implements SdsClient.SecretWatcher {
abstract class SdsSslContextProvider extends DynamicSslContextProvider implements
SdsClient.SecretWatcher {
private static final Logger logger = Logger.getLogger(SdsSslContextProvider.class.getName());
@ -49,13 +40,10 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl
@Nullable private final SdsClient validationContextSdsClient;
@Nullable private final SdsSecretConfig certSdsConfig;
@Nullable private final SdsSecretConfig validationContextSdsConfig;
@Nullable private final CertificateValidationContext staticCertificateValidationContext;
private final List<CallbackPair> pendingCallbacks = new ArrayList<>();
@Nullable protected TlsCertificate tlsCertificate;
@Nullable private CertificateValidationContext certificateValidationContext;
@Nullable private SslContext sslContext;
SdsSslContextProvider(
protected SdsSslContextProvider(
Node node,
SdsSecretConfig certSdsConfig,
SdsSecretConfig validationContextSdsConfig,
@ -63,10 +51,9 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl
Executor watcherExecutor,
Executor channelExecutor,
BaseTlsContext tlsContext) {
super(tlsContext);
super(tlsContext, staticCertValidationContext);
this.certSdsConfig = certSdsConfig;
this.validationContextSdsConfig = validationContextSdsConfig;
this.staticCertificateValidationContext = staticCertValidationContext;
if (certSdsConfig != null && certSdsConfig.isInitialized()) {
certSdsClient =
SdsClient.Factory.createSdsClient(certSdsConfig, node, watcherExecutor, channelExecutor);
@ -87,35 +74,7 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl
}
@Override
public void addCallback(Callback callback, Executor executor) {
checkNotNull(callback, "callback");
checkNotNull(executor, "executor");
// if there is a computed sslContext just send it
SslContext sslContextCopy = sslContext;
if (sslContextCopy != null) {
callPerformCallback(callback, executor, sslContextCopy);
} else {
synchronized (pendingCallbacks) {
pendingCallbacks.add(new CallbackPair(callback, executor));
}
}
}
private void callPerformCallback(
Callback callback, Executor executor, final SslContext sslContextCopy) {
performCallback(
new SslContextGetter() {
@Override
public SslContext get() {
return sslContextCopy;
}
},
callback,
executor);
}
@Override
public synchronized void onSecretChanged(Secret secretUpdate) {
public final synchronized void onSecretChanged(Secret secretUpdate) {
checkNotNull(secretUpdate);
if (secretUpdate.hasTlsCertificate()) {
checkState(
@ -143,35 +102,8 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl
}
}
/** Gets a server or client side SslContextBuilder. */
abstract SslContextBuilder getSslContextBuilder(
CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException;
// this gets called only when requested secrets are ready...
private void updateSslContext() {
try {
CertificateValidationContext localCertValidationContext = mergeStaticAndDynamicCertContexts();
SslContextBuilder sslContextBuilder = getSslContextBuilder(localCertValidationContext);
CommonTlsContext commonTlsContext = getCommonTlsContext();
if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) {
List<String> alpnList = commonTlsContext.getAlpnProtocolsList();
ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
ApplicationProtocolConfig.Protocol.ALPN,
ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
alpnList);
sslContextBuilder.applicationProtocolConfig(apn);
}
SslContext sslContextCopy = sslContextBuilder.build();
sslContext = sslContextCopy;
makePendingCallbacks(sslContextCopy);
} catch (CertificateException | IOException | CertStoreException e) {
logger.log(Level.SEVERE, "exception in updateSslContext", e);
}
}
private CertificateValidationContext mergeStaticAndDynamicCertContexts() {
@Override
protected final CertificateValidationContext generateCertificateValidationContext() {
if (staticCertificateValidationContext == null) {
return certificateValidationContext;
}
@ -183,27 +115,8 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl
return localCertContextBuilder.mergeFrom(staticCertificateValidationContext).build();
}
private void makePendingCallbacks(SslContext sslContextCopy) {
synchronized (pendingCallbacks) {
for (CallbackPair pair : pendingCallbacks) {
callPerformCallback(pair.callback, pair.executor, sslContextCopy);
}
pendingCallbacks.clear();
}
}
@Override
public void onError(Status error) {
synchronized (pendingCallbacks) {
for (CallbackPair callbackPair : pendingCallbacks) {
callbackPair.callback.onException(error.asException());
}
pendingCallbacks.clear();
}
}
@Override
public void close() {
public final void close() {
if (certSdsClient != null) {
certSdsClient.cancelSecretWatch(this);
certSdsClient.shutdown();
@ -213,14 +126,4 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl
validationContextSdsClient.shutdown();
}
}
private static class CallbackPair {
private final Callback callback;
private final Executor executor;
private CallbackPair(Callback callback, Executor executor) {
this.callback = callback;
this.executor = executor;
}
}
}

View File

@ -34,7 +34,6 @@ import java.io.File;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/** A client SslContext provider that uses file-based secrets (secret volume). */
@ -92,9 +91,8 @@ final class SecretVolumeClientSslContextProvider extends SslContextProvider {
}
@Override
public void addCallback(final Callback callback, Executor executor) {
public void addCallback(final Callback callback) {
checkNotNull(callback, "callback");
checkNotNull(executor, "executor");
// as per the contract we will read the current secrets on disk
// this involves I/O which can potentially block the executor
performCallback(
@ -104,8 +102,8 @@ final class SecretVolumeClientSslContextProvider extends SslContextProvider {
return buildSslContextFromSecrets();
}
},
callback,
executor);
callback
);
}
@Override

View File

@ -33,7 +33,6 @@ import java.io.File;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/** A server SslContext provider that uses file-based secrets (secret volume). */
@ -85,9 +84,8 @@ final class SecretVolumeServerSslContextProvider extends SslContextProvider {
}
@Override
public void addCallback(final Callback callback, Executor executor) {
public void addCallback(final Callback callback) {
checkNotNull(callback, "callback");
checkNotNull(executor, "executor");
// as per the contract we will read the current secrets on disk
// this involves I/O which can potentially block the executor
performCallback(
@ -97,8 +95,8 @@ final class SecretVolumeServerSslContextProvider extends SslContextProvider {
return buildSslContextFromSecrets();
}
},
callback,
executor);
callback
);
}
@Override

View File

@ -47,19 +47,25 @@ public abstract class SslContextProvider implements Closeable {
protected final BaseTlsContext tlsContext;
public interface Callback {
abstract static class Callback {
private final Executor executor;
protected Callback(Executor executor) {
this.executor = executor;
}
/** Informs callee of new/updated SslContext. */
void updateSecret(SslContext sslContext);
abstract void updateSecret(SslContext sslContext);
/** Informs callee of an exception that was generated. */
void onException(Throwable throwable);
abstract void onException(Throwable throwable);
}
SslContextProvider(BaseTlsContext tlsContext) {
protected SslContextProvider(BaseTlsContext tlsContext) {
this.tlsContext = checkNotNull(tlsContext, "tlsContext");
}
CommonTlsContext getCommonTlsContext() {
protected CommonTlsContext getCommonTlsContext() {
return tlsContext.getCommonTlsContext();
}
@ -100,14 +106,13 @@ public abstract class SslContextProvider implements Closeable {
* Registers a callback on the given executor. The callback will run when SslContext becomes
* available or immediately if the result is already available.
*/
public abstract void addCallback(Callback callback, Executor executor);
public abstract void addCallback(Callback callback);
final void performCallback(
final SslContextGetter sslContextGetter, final Callback callback, Executor executor) {
protected final void performCallback(
final SslContextGetter sslContextGetter, final Callback callback) {
checkNotNull(sslContextGetter, "sslContextGetter");
checkNotNull(callback, "callback");
checkNotNull(executor, "executor");
executor.execute(
callback.executor.execute(
new Runnable() {
@Override
public void run() {

View File

@ -16,7 +16,7 @@
package io.grpc.xds.internal.sds.trust;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
@ -53,9 +53,29 @@ public final class SdsTrustManagerFactory extends SimpleTrustManagerFactory {
/** Constructor constructs from a {@link CertificateValidationContext}. */
public SdsTrustManagerFactory(CertificateValidationContext certificateValidationContext)
throws CertificateException, IOException, CertStoreException {
checkNotNull(certificateValidationContext, "certificateValidationContext");
sdsX509TrustManager = createSdsX509TrustManager(
getTrustedCaFromCertContext(certificateValidationContext), certificateValidationContext);
this(
getTrustedCaFromCertContext(certificateValidationContext),
certificateValidationContext,
false);
}
public SdsTrustManagerFactory(
X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext)
throws CertStoreException {
this(certs, staticCertificateValidationContext, true);
}
private SdsTrustManagerFactory(
X509Certificate[] certs,
CertificateValidationContext certificateValidationContext,
boolean validationContextIsStatic)
throws CertStoreException {
if (validationContextIsStatic) {
checkArgument(
certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(),
"only static certificateValidationContext expected");
}
sdsX509TrustManager = createSdsX509TrustManager(certs, certificateValidationContext);
}
private static X509Certificate[] getTrustedCaFromCertContext(

View File

@ -59,7 +59,8 @@ final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509
}
// Copied from OkHostnameVerifier.verifyHostName().
private static boolean verifyDnsNameInPattern(String pattern, String sanToVerify) {
private static boolean verifyDnsNameInPattern(String pattern, StringMatcher sanToVerifyMatcher) {
String sanToVerify = sanToVerifyMatcher.getExact();
// Basic sanity checks
// Check length == 0 instead of .isEmpty() to support Java 5.
if (sanToVerify == null
@ -150,9 +151,9 @@ final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509
// sanToVerify matches pattern
}
private static boolean verifyDnsNameInSanList(String altNameFromCert,
List<String> verifySanList) {
for (String verifySan : verifySanList) {
private static boolean verifyDnsNameInSanList(
String altNameFromCert, List<StringMatcher> verifySanList) {
for (StringMatcher verifySan : verifySanList) {
if (verifyDnsNameInPattern(altNameFromCert, verifySan)) {
return true;
}
@ -168,16 +169,17 @@ final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509
* @param verifySanList list of SANs from certificate context
* @return true if there is a match
*/
private static boolean verifyStringInSanList(String stringFromCert, List<String> verifySanList) {
for (String sanToVerify : verifySanList) {
if (Ascii.equalsIgnoreCase(sanToVerify, stringFromCert)) {
private static boolean verifyStringInSanList(
String stringFromCert, List<StringMatcher> verifySanList) {
for (StringMatcher sanToVerify : verifySanList) {
if (Ascii.equalsIgnoreCase(sanToVerify.getExact(), stringFromCert)) {
return true;
}
}
return false;
}
private static boolean verifyOneSanInList(List<?> entry, List<String> verifySanList)
private static boolean verifyOneSanInList(List<?> entry, List<StringMatcher> verifySanList)
throws CertificateParsingException {
// from OkHostnameVerifier.getSubjectAltNames
if (entry == null || entry.size() < 2) {
@ -200,9 +202,8 @@ final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509
}
// logic from Envoy::Extensions::TransportSockets::Tls::ContextImpl::verifySubjectAltName
@SuppressWarnings("UnusedMethod") // TODO(#7166): support StringMatcher list.
private static void verifySubjectAltNameInLeaf(X509Certificate cert, List<String> verifyList)
throws CertificateException {
private static void verifySubjectAltNameInLeaf(
X509Certificate cert, List<StringMatcher> verifyList) throws CertificateException {
Collection<List<?>> names = cert.getSubjectAlternativeNames();
if (names == null || names.isEmpty()) {
throw new CertificateException("Peer certificate SAN check failed");
@ -233,9 +234,7 @@ final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509
throw new CertificateException("Peer certificate(s) missing");
}
// verify SANs only in the top cert (leaf cert)
// v2 version: verifySubjectAltNameInLeaf(peerCertChain[0], verifyList);
// TODO(#7166): Implement v3 version.
throw new UnsupportedOperationException();
verifySubjectAltNameInLeaf(peerCertChain[0], verifyList);
}
@Override

View File

@ -0,0 +1,293 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils.getCertFromResourceName;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.doChecksOnSslContext;
import static org.junit.Assert.fail;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.MoreExecutors;
import io.envoyproxy.envoy.config.core.v3.DataSource;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link CertProviderClientSslContextProvider}. */
@RunWith(JUnit4.class)
public class CertProviderClientSslContextProviderTest {
private static final Logger logger =
Logger.getLogger(CertProviderClientSslContextProviderTest.class.getName());
CertificateProviderRegistry certificateProviderRegistry;
CertificateProviderStore certificateProviderStore;
private CertProviderClientSslContextProvider.Factory certProviderClientSslContextProviderFactory;
@Before
public void setUp() throws Exception {
certificateProviderRegistry = new CertificateProviderRegistry();
certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry);
certProviderClientSslContextProviderFactory =
new CertProviderClientSslContextProvider.Factory(certificateProviderStore);
}
/** Helper method to build CertProviderClientSslContextProvider. */
private CertProviderClientSslContextProvider getSslContextProvider(
String certInstanceName,
String rootInstanceName,
Bootstrapper.BootstrapInfo bootstrapInfo,
Iterable<String> alpnProtocols,
CertificateValidationContext staticCertValidationContext) {
EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
certInstanceName,
"cert-default",
rootInstanceName,
"root-default",
alpnProtocols,
staticCertValidationContext);
return certProviderClientSslContextProviderFactory.getProvider(
upstreamTlsContext,
bootstrapInfo.getNode().toEnvoyProtoNode(),
bootstrapInfo.getCertProviders());
}
@Test
public void testProviderForClient_mtls() throws Exception {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
TestCertificateProvider.createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "testca", 0);
CertProviderClientSslContextProvider provider =
getSslContextProvider(
"gcp_id",
"gcp_id",
CommonCertProviderTestUtils.getTestBootstrapInfo(),
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
assertThat(provider.savedKey).isNull();
assertThat(provider.savedCertChain).isNull();
assertThat(provider.savedTrustedRoots).isNull();
assertThat(provider.getSslContext()).isNull();
// now generate cert update
watcherCaptor[0].updateCertificate(
CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE),
ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE)));
assertThat(provider.savedKey).isNotNull();
assertThat(provider.savedCertChain).isNotNull();
assertThat(provider.getSslContext()).isNull();
// now generate root cert update
watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE)));
assertThat(provider.getSslContext()).isNotNull();
assertThat(provider.savedKey).isNull();
assertThat(provider.savedCertChain).isNull();
assertThat(provider.savedTrustedRoots).isNull();
TestCallback testCallback =
CommonTlsContextTestsUtil.getValueThruCallback(provider);
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
TestCallback testCallback1 =
CommonTlsContextTestsUtil.getValueThruCallback(provider);
assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext);
// just do root cert update: sslContext should still be the same
watcherCaptor[0].updateTrustedRoots(
ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE)));
assertThat(provider.savedKey).isNull();
assertThat(provider.savedCertChain).isNull();
assertThat(provider.savedTrustedRoots).isNotNull();
testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider);
assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext);
// now update id cert: sslContext should be updated i.e.different from the previous one
watcherCaptor[0].updateCertificate(
CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE),
ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE)));
assertThat(provider.savedKey).isNull();
assertThat(provider.savedCertChain).isNull();
assertThat(provider.savedTrustedRoots).isNull();
assertThat(provider.getSslContext()).isNotNull();
testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider);
assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext);
}
@Test
public void testProviderForClient_queueExecutor() throws Exception {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
TestCertificateProvider.createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "testca", 0);
CertProviderClientSslContextProvider provider =
getSslContextProvider(
"gcp_id",
"gcp_id",
CommonCertProviderTestUtils.getTestBootstrapInfo(),
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
QueuedExecutor queuedExecutor = new QueuedExecutor();
TestCallback testCallback =
CommonTlsContextTestsUtil.getValueThruCallback(provider, queuedExecutor);
assertThat(queuedExecutor.runQueue).isEmpty();
// now generate cert update
watcherCaptor[0].updateCertificate(
CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE),
ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE)));
assertThat(queuedExecutor.runQueue).isEmpty(); // still empty
// now generate root cert update
watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE)));
assertThat(queuedExecutor.runQueue).hasSize(1);
queuedExecutor.drain();
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
public void testProviderForClient_tls() throws Exception {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
TestCertificateProvider.createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "testca", 0);
CertProviderClientSslContextProvider provider =
getSslContextProvider(
/* certInstanceName= */ null,
"gcp_id",
CommonCertProviderTestUtils.getTestBootstrapInfo(),
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
assertThat(provider.savedKey).isNull();
assertThat(provider.savedCertChain).isNull();
assertThat(provider.savedTrustedRoots).isNull();
assertThat(provider.getSslContext()).isNull();
// now generate root cert update
watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE)));
assertThat(provider.getSslContext()).isNotNull();
assertThat(provider.savedKey).isNull();
assertThat(provider.savedCertChain).isNull();
assertThat(provider.savedTrustedRoots).isNull();
TestCallback testCallback =
CommonTlsContextTestsUtil.getValueThruCallback(provider);
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
public void testProviderForClient_sslContextException_onError() throws Exception {
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setInlineString("foo"))
.build();
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
TestCertificateProvider.createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "testca", 0);
CertProviderClientSslContextProvider provider =
getSslContextProvider(
/* certInstanceName= */ null,
"gcp_id",
CommonCertProviderTestUtils.getTestBootstrapInfo(),
/* alpnProtocols= */null,
staticCertValidationContext);
TestCallback testCallback = new TestCallback(MoreExecutors.directExecutor());
provider.addCallback(testCallback);
try {
watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE)));
fail("exception expected");
} catch (RuntimeException expected) {
assertThat(expected)
.hasMessageThat()
.contains("only static certificateValidationContext expected");
}
assertThat(testCallback.updatedThrowable).isNotNull();
assertThat(testCallback.updatedThrowable)
.hasCauseThat()
.hasMessageThat()
.contains("only static certificateValidationContext expected");
}
@Test
public void testProviderForClient_rootInstanceNull_expectError() throws Exception {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
TestCertificateProvider.createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "testca", 0);
try {
getSslContextProvider(
/* certInstanceName= */ null,
/* rootInstanceName= */ null,
CommonCertProviderTestUtils.getTestBootstrapInfo(),
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
fail("exception expected");
} catch (NullPointerException expected) {
assertThat(expected).hasMessageThat().contains("Client SSL requires rootCertInstance");
}
}
static class QueuedExecutor implements Executor {
/** A list of Runnables to be run in order. */
private final Queue<Runnable> runQueue = new ConcurrentLinkedQueue<>();
@Override
public synchronized void execute(Runnable r) {
runQueue.add(checkNotNull(r, "'r' must not be null."));
}
public synchronized void drain() {
Runnable r;
while ((r = runQueue.poll()) != null) {
try {
r.run();
} catch (RuntimeException e) {
// Log it and keep going.
logger.log(Level.SEVERE, "Exception while executing runnable " + r, e);
}
}
}
}
}

View File

@ -49,36 +49,6 @@ public class CertificateProviderStoreTest {
private CertificateProviderStore certificateProviderStore;
private boolean throwExceptionForCertUpdates;
private class TestCertificateProvider extends CertificateProvider {
Object config;
CertificateProviderProvider certProviderProvider;
int closeCalled = 0;
int startCalled = 0;
protected TestCertificateProvider(
CertificateProvider.DistributorWatcher watcher,
boolean notifyCertUpdates,
Object config,
CertificateProviderProvider certificateProviderProvider) {
super(watcher, notifyCertUpdates);
if (throwExceptionForCertUpdates && notifyCertUpdates) {
throw new UnsupportedOperationException("Provider does not support Certificate Updates.");
}
this.config = config;
this.certProviderProvider = certificateProviderProvider;
}
@Override
public void close() {
closeCalled++;
}
@Override
public void start() {
startCalled++;
}
}
@Before
public void setUp() {
certificateProviderRegistry = new CertificateProviderRegistry();
@ -94,7 +64,7 @@ public class CertificateProviderStoreTest {
"cert-name1", "plugin1", "config", mockWatcher, true);
fail("exception expected");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("Provider not found.");
assertThat(expected).hasMessageThat().isEqualTo("Provider not found for plugin1");
}
}
@ -111,7 +81,7 @@ public class CertificateProviderStoreTest {
"cert-name1", "plugin1", "config", mockWatcher, true);
fail("exception expected");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("Provider not found.");
assertThat(expected).hasMessageThat().isEqualTo("Provider not found for plugin1");
}
}
@ -369,7 +339,8 @@ public class CertificateProviderStoreTest {
(CertificateProvider.DistributorWatcher) args[1];
boolean notifyCertUpdates = (Boolean) args[2];
return new TestCertificateProvider(
watcher, notifyCertUpdates, config, certProviderProvider);
watcher, notifyCertUpdates, config, certProviderProvider,
throwExceptionForCertUpdates);
}
});
certificateProviderRegistry.register(certProviderProvider);

View File

@ -0,0 +1,177 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.io.CharStreams;
import io.grpc.internal.testing.TestUtils;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.internal.sds.trust.CertificateUtils;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.base64.Base64;
import io.netty.util.CharsetUtil;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.Reader;
import java.security.KeyException;
import java.security.KeyFactory;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class CommonCertProviderTestUtils {
private static final Logger logger =
Logger.getLogger(CommonCertProviderTestUtils.class.getName());
private static final Pattern KEY_PATTERN = Pattern.compile(
"-+BEGIN\\s+.*PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+" + // Header
"([a-z0-9+/=\\r\\n]+)" + // Base64 text
"-+END\\s+.*PRIVATE\\s+KEY[^-]*-+", // Footer
Pattern.CASE_INSENSITIVE);
static Bootstrapper.BootstrapInfo getTestBootstrapInfo() throws IOException {
String rawData =
"{\n"
+ " \"xds_servers\": [],\n"
+ " \"certificate_providers\": {\n"
+ " \"gcp_id\": {\n"
+ " \"plugin_name\": \"testca\",\n"
+ " \"config\": {\n"
+ " \"server\": {\n"
+ " \"api_type\": \"GRPC\",\n"
+ " \"grpc_services\": [{\n"
+ " \"google_grpc\": {\n"
+ " \"target_uri\": \"meshca.com\",\n"
+ " \"channel_credentials\": {\"google_default\": {}},\n"
+ " \"call_credentials\": [{\n"
+ " \"sts_service\": {\n"
+ " \"token_exchange_service\": \"securetoken.googleapis.com\",\n"
+ " \"subject_token_path\": \"/etc/secret/sajwt.token\"\n"
+ " }\n"
+ " }]\n" // end call_credentials
+ " },\n" // end google_grpc
+ " \"time_out\": {\"seconds\": 10}\n"
+ " }]\n" // end grpc_services
+ " },\n" // end server
+ " \"certificate_lifetime\": {\"seconds\": 86400},\n"
+ " \"renewal_grace_period\": {\"seconds\": 3600},\n"
+ " \"key_type\": \"RSA\",\n"
+ " \"key_size\": 2048,\n"
+ " \"location\": \"https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3\"\n"
+ " }\n" // end config
+ " },\n" // end gcp_id
+ " \"file_provider\": {\n"
+ " \"plugin_name\": \"file_watcher\",\n"
+ " \"config\": {\"path\": \"/etc/secret/certs\"}\n"
+ " }\n"
+ " }\n"
+ "}";
return Bootstrapper.parseConfig(rawData);
}
static PrivateKey getPrivateKey(String resourceName)
throws Exception {
InputStream inputStream = TestUtils.class.getResourceAsStream("/certs/" + resourceName);
ByteBuf encodedKeyBuf = readPrivateKey(inputStream);
byte[] encodedKey = new byte[encodedKeyBuf.readableBytes()];
encodedKeyBuf.readBytes(encodedKey).release();
PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(encodedKey);
try {
return KeyFactory.getInstance("RSA").generatePrivate(spec);
} catch (InvalidKeySpecException ignore) {
try {
return KeyFactory.getInstance("DSA").generatePrivate(spec);
} catch (InvalidKeySpecException ignore2) {
try {
return KeyFactory.getInstance("EC").generatePrivate(spec);
} catch (InvalidKeySpecException e) {
throw new InvalidKeySpecException("Neither RSA, DSA nor EC worked", e);
}
}
}
}
static ByteBuf readPrivateKey(InputStream in) throws KeyException {
String content;
try {
content = readContent(in);
} catch (IOException e) {
throw new KeyException("failed to read key input stream", e);
}
Matcher m = KEY_PATTERN.matcher(content);
if (!m.find()) {
throw new KeyException("could not find a PKCS #8 private key in input stream");
}
ByteBuf base64 = Unpooled.copiedBuffer(m.group(1), CharsetUtil.US_ASCII);
ByteBuf der = Base64.decode(base64);
base64.release();
return der;
}
private static String readContent(InputStream in) throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
try {
byte[] buf = new byte[8192];
for (; ; ) {
int ret = in.read(buf);
if (ret < 0) {
break;
}
out.write(buf, 0, ret);
}
return out.toString(CharsetUtil.US_ASCII.name());
} finally {
safeClose(out);
}
}
private static void safeClose(OutputStream out) {
try {
out.close();
} catch (IOException e) {
logger.log(Level.WARNING, "Failed to close a stream.", e);
}
}
static X509Certificate getCertFromResourceName(String resourceName)
throws IOException, CertificateException {
return CertificateUtils.toX509Certificate(
new ByteArrayInputStream(getResourceContents(resourceName).getBytes(UTF_8)));
}
private static String getResourceContents(String resourceName) throws IOException {
InputStream inputStream = TestUtils.class.getResourceAsStream("/certs/" + resourceName);
String text = null;
try (Reader reader = new InputStreamReader(inputStream, UTF_8)) {
text = CharStreams.toString(reader);
}
return text;
}
}

View File

@ -0,0 +1,84 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
public class TestCertificateProvider extends CertificateProvider {
Object config;
CertificateProviderProvider certProviderProvider;
int closeCalled = 0;
int startCalled = 0;
TestCertificateProvider(
DistributorWatcher watcher,
boolean notifyCertUpdates,
Object config,
CertificateProviderProvider certificateProviderProvider,
boolean throwExceptionForCertUpdates) {
super(watcher, notifyCertUpdates);
if (throwExceptionForCertUpdates && notifyCertUpdates) {
throw new UnsupportedOperationException("Provider does not support Certificate Updates.");
}
this.config = config;
this.certProviderProvider = certificateProviderProvider;
}
@Override
public void close() {
closeCalled++;
}
@Override
public void start() {
startCalled++;
}
static void createAndRegisterProviderProvider(
CertificateProviderRegistry certificateProviderRegistry,
final CertificateProvider.DistributorWatcher[] watcherCaptor,
String testca,
final int index) {
final CertificateProviderProvider mockProviderProviderTestCa =
new TestCertificateProviderProvider(testca, watcherCaptor, index);
certificateProviderRegistry.register(mockProviderProviderTestCa);
}
private static class TestCertificateProviderProvider implements CertificateProviderProvider {
private final String testCa;
private final CertificateProvider.DistributorWatcher[] watcherCaptor;
private final int index;
TestCertificateProviderProvider(
String testCa, CertificateProvider.DistributorWatcher[] watcherCaptor, int index) {
this.testCa = testCa;
this.watcherCaptor = watcherCaptor;
this.index = index;
}
@Override
public String getName() {
return testCa;
}
@Override
public CertificateProvider createCertificateProvider(
Object config, DistributorWatcher watcher, boolean notifyCertUpdates) {
watcherCaptor[index] = watcher;
return new TestCertificateProvider(watcher, true, config, this, false);
}
}
}

View File

@ -16,10 +16,12 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.Strings;
import com.google.common.io.CharStreams;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.BoolValue;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
@ -40,6 +42,7 @@ import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.internal.testing.TestUtils;
import io.grpc.xds.EnvoyServerProtoData;
import io.grpc.xds.internal.sds.trust.CertificateUtils;
import io.netty.handler.ssl.SslContext;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
@ -48,6 +51,8 @@ import java.io.Reader;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/** Utility class for client and server ssl provider tests. */
@ -461,4 +466,122 @@ public class CommonTlsContextTestsUtil {
}
return text;
}
private static CommonTlsContext buildCommonTlsContextForCertProviderInstance(
String certInstanceName,
String certName,
String rootInstanceName,
String rootCertName,
Iterable<String> alpnProtocols,
CertificateValidationContext staticCertValidationContext) {
CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
if (certInstanceName != null) {
builder =
builder.setTlsCertificateCertificateProviderInstance(
CommonTlsContext.CertificateProviderInstance.newBuilder()
.setInstanceName(certInstanceName)
.setCertificateName(certName));
}
builder =
addCertificateValidationContext(
builder, rootInstanceName, rootCertName, staticCertValidationContext);
if (alpnProtocols != null) {
builder.addAllAlpnProtocols(alpnProtocols);
}
return builder.build();
}
private static CommonTlsContext.Builder addCertificateValidationContext(
CommonTlsContext.Builder builder,
String rootInstanceName,
String rootCertName,
CertificateValidationContext staticCertValidationContext) {
if (rootInstanceName != null) {
CommonTlsContext.CertificateProviderInstance.Builder providerInstanceBuilder =
CommonTlsContext.CertificateProviderInstance.newBuilder()
.setInstanceName(rootInstanceName)
.setCertificateName(rootCertName);
if (staticCertValidationContext != null) {
CombinedCertificateValidationContext combined =
CombinedCertificateValidationContext.newBuilder()
.setDefaultValidationContext(staticCertValidationContext)
.setValidationContextCertificateProviderInstance(providerInstanceBuilder)
.build();
return builder.setCombinedValidationContext(combined);
}
builder = builder.setValidationContextCertificateProviderInstance(providerInstanceBuilder);
}
return builder;
}
/** Helper method to build UpstreamTlsContext for CertProvider tests. */
public static EnvoyServerProtoData.UpstreamTlsContext
buildUpstreamTlsContextForCertProviderInstance(
@Nullable String certInstanceName,
@Nullable String certName,
@Nullable String rootInstanceName,
@Nullable String rootCertName,
Iterable<String> alpnProtocols,
CertificateValidationContext staticCertValidationContext) {
return buildUpstreamTlsContext(
buildCommonTlsContextForCertProviderInstance(
certInstanceName,
certName,
rootInstanceName,
rootCertName,
alpnProtocols,
staticCertValidationContext));
}
/** Perform some simple checks on sslContext. */
public static void doChecksOnSslContext(boolean server, SslContext sslContext,
List<String> expectedApnProtos) {
if (server) {
assertThat(sslContext.isServer()).isTrue();
} else {
assertThat(sslContext.isClient()).isTrue();
}
List<String> apnProtos = sslContext.applicationProtocolNegotiator().protocols();
assertThat(apnProtos).isNotNull();
if (expectedApnProtos != null) {
assertThat(apnProtos).isEqualTo(expectedApnProtos);
} else {
assertThat(apnProtos).contains("h2");
}
}
/**
* Helper method to get the value thru directExecutor callback. Because of directExecutor this is
* a synchronous callback - so need to provide a listener.
*/
public static TestCallback getValueThruCallback(SslContextProvider provider) {
return getValueThruCallback(provider, MoreExecutors.directExecutor());
}
/** Helper method to get the value thru callback with a user passed executor. */
public static TestCallback getValueThruCallback(SslContextProvider provider, Executor executor) {
TestCallback testCallback = new TestCallback(executor);
provider.addCallback(testCallback);
return testCallback;
}
public static class TestCallback extends SslContextProvider.Callback {
public SslContext updatedSslContext;
public Throwable updatedThrowable;
public TestCallback(Executor executor) {
super(executor);
}
@Override
public void updateSecret(SslContext sslContext) {
updatedSslContext = sslContext;
}
@Override
public void onException(Throwable throwable) {
updatedThrowable = throwable;
}
}
}

View File

@ -22,9 +22,10 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.doChecksOnSslContext;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.getValueThruCallback;
import static io.grpc.xds.internal.sds.SdsClientTest.getOneCertificateValidationContextSecret;
import static io.grpc.xds.internal.sds.SdsClientTest.getOneTlsCertSecret;
import static io.grpc.xds.internal.sds.SecretVolumeSslContextProviderTest.doChecksOnSslContext;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -33,6 +34,7 @@ import io.envoyproxy.envoy.api.v2.core.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.Status.Code;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback;
import java.io.IOException;
import java.util.Arrays;
import org.junit.After;
@ -123,8 +125,7 @@ public class SdsSslContextProviderTest {
SdsServerSslContextProvider provider =
getSdsServerSslContextProvider("cert1", "valid1", null, null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@ -142,8 +143,7 @@ public class SdsSslContextProviderTest {
/* validationContextName= */ "valid1",
/* matchSubjectAltNames= */ null,
/* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@ -159,8 +159,7 @@ public class SdsSslContextProviderTest {
/* validationContextName= */ null,
/* matchSubjectAltNames= */ null,
/* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@ -176,8 +175,7 @@ public class SdsSslContextProviderTest {
/* validationContextName= */ "valid1",
/* matchSubjectAltNames= */ null,
null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@ -193,8 +191,7 @@ public class SdsSslContextProviderTest {
/* validationContextName= */ "valid1",
/* matchSubjectAltNames= */ null,
/* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
TestCallback testCallback = getValueThruCallback(provider);
assertThat(server.lastNack).isNotNull();
assertThat(server.lastNack.getVersionInfo()).isEmpty();
@ -222,8 +219,7 @@ public class SdsSslContextProviderTest {
.build()),
/* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@ -240,8 +236,7 @@ public class SdsSslContextProviderTest {
/* validationContextName= */ "valid1",
/* matchSubjectAltNames= */ null,
/* alpnProtocols= */ Arrays.asList("managed-mtls", "h2"));
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(
false, testCallback.updatedSslContext, Arrays.asList("managed-mtls", "h2"));
@ -260,8 +255,7 @@ public class SdsSslContextProviderTest {
/* validationContextName= */ "valid1",
/* matchSubjectAltNames= */ null,
/* alpnProtocols= */ Arrays.asList("managed-mtls", "h2"));
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(
true, testCallback.updatedSslContext, Arrays.asList("managed-mtls", "h2"));

View File

@ -22,16 +22,17 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.doChecksOnSslContext;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.getValueThruCallback;
import com.google.common.util.concurrent.MoreExecutors;
import io.envoyproxy.envoy.config.core.v3.DataSource;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback;
import io.netty.handler.ssl.SslContext;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.List;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
@ -371,22 +372,6 @@ public class SecretVolumeSslContextProviderTest {
doChecksOnSslContext(server, sslContext, /* expectedApnProtos= */ null);
}
static void doChecksOnSslContext(boolean server, SslContext sslContext,
List<String> expectedApnProtos) {
if (server) {
assertThat(sslContext.isServer()).isTrue();
} else {
assertThat(sslContext.isClient()).isTrue();
}
List<String> apnProtos = sslContext.applicationProtocolNegotiator().protocols();
assertThat(apnProtos).isNotNull();
if (expectedApnProtos != null) {
assertThat(apnProtos).isEqualTo(expectedApnProtos);
} else {
assertThat(apnProtos).contains("h2");
}
}
@Test
public void getProviderForServer() throws IOException, CertificateException, CertStoreException {
sslContextForEitherWithBothCertAndTrust(
@ -421,32 +406,6 @@ public class SecretVolumeSslContextProviderTest {
}
}
static class TestCallback implements SslContextProvider.Callback {
SslContext updatedSslContext;
Throwable updatedThrowable;
@Override
public void updateSecret(SslContext sslContext) {
updatedSslContext = sslContext;
}
@Override
public void onException(Throwable throwable) {
updatedThrowable = throwable;
}
}
/**
* Helper method to get the value thru directExecutor callback. Because of directExecutor this is
* a synchronous callback - so need to provide a listener.
*/
static TestCallback getValueThruCallback(SslContextProvider provider) {
TestCallback testCallback = new TestCallback();
provider.addCallback(testCallback, MoreExecutors.directExecutor());
return testCallback;
}
@Test
public void getProviderForServer_both_callsback() throws IOException {
SecretVolumeServerSslContextProvider provider =

View File

@ -26,6 +26,7 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FI
import com.google.protobuf.ByteString;
import io.envoyproxy.envoy.config.core.v3.DataSource;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.internal.testing.TestUtils;
import java.io.IOException;
import java.security.cert.CertStoreException;
@ -80,6 +81,100 @@ public class SdsTrustManagerFactoryTest {
.isEqualTo(CertificateUtils.toX509Certificates(TestUtils.loadCert(CA_PEM_FILE))[0]);
}
@Test
public void constructor_fromRootCert()
throws CertificateException, IOException, CertStoreException {
X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE);
CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1",
"san2");
SdsTrustManagerFactory factory =
new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext);
assertThat(factory).isNotNull();
TrustManager[] tms = factory.getTrustManagers();
assertThat(tms).isNotNull();
assertThat(tms).hasLength(1);
TrustManager myTm = tms[0];
assertThat(myTm).isInstanceOf(SdsX509TrustManager.class);
SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) myTm;
X509Certificate[] acceptedIssuers = sdsX509TrustManager.getAcceptedIssuers();
assertThat(acceptedIssuers).isNotNull();
assertThat(acceptedIssuers).hasLength(1);
X509Certificate caCert = acceptedIssuers[0];
assertThat(caCert)
.isEqualTo(CertificateUtils.toX509Certificates(TestUtils.loadCert(CA_PEM_FILE))[0]);
}
@Test
public void constructorRootCert_checkServerTrusted()
throws CertificateException, IOException, CertStoreException {
X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE);
CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1",
"waterzooi.test.google.be");
SdsTrustManagerFactory factory =
new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext);
SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0];
X509Certificate[] serverChain =
CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
sdsX509TrustManager.checkServerTrusted(serverChain, "RSA");
}
@Test
public void constructorRootCert_nonStaticContext_throwsException()
throws CertificateException, IOException, CertStoreException {
X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE);
try {
new SdsTrustManagerFactory(
new X509Certificate[] {x509Cert}, getCertContextFromPath(CA_PEM_FILE));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected)
.hasMessageThat()
.contains("only static certificateValidationContext expected");
}
}
@Test
public void constructorRootCert_checkServerTrusted_throwsException()
throws CertificateException, IOException, CertStoreException {
X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE);
CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1",
"san2");
SdsTrustManagerFactory factory =
new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext);
SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0];
X509Certificate[] serverChain =
CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
try {
sdsX509TrustManager.checkServerTrusted(serverChain, "RSA");
Assert.fail("no exception thrown");
} catch (CertificateException expected) {
assertThat(expected)
.hasMessageThat()
.contains("Peer certificate SAN check failed");
}
}
@Test
public void constructorRootCert_checkClientTrusted_throwsException()
throws CertificateException, IOException, CertStoreException {
X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE);
CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1",
"san2");
SdsTrustManagerFactory factory =
new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext);
SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0];
X509Certificate[] clientChain =
CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
try {
sdsX509TrustManager.checkClientTrusted(clientChain, "RSA");
Assert.fail("no exception thrown");
} catch (CertificateException expected) {
assertThat(expected)
.hasMessageThat()
.contains("Peer certificate SAN check failed");
}
}
@Test
public void checkServerTrusted_goodCert()
throws CertificateException, IOException, CertStoreException {
@ -156,4 +251,13 @@ public class SdsTrustManagerFactoryTest {
DataSource.newBuilder().setInlineBytes(ByteString.copyFrom(x509Cert.getEncoded())))
.build();
}
private static final CertificateValidationContext buildStaticValidationContext(
String... verifySans) {
CertificateValidationContext.Builder builder = CertificateValidationContext.newBuilder();
for (String san : verifySans) {
builder.addMatchSubjectAltNames(StringMatcher.newBuilder().setExact(san));
}
return builder.build();
}
}