xds: remove MeshCaCertificateProvider and DynamicReloadingCertificate{Provider (#8214)

This commit is contained in:
sanjaypujare 2021-05-26 19:35:51 -07:00 committed by GitHub
parent 328071bbce
commit bfcba82dd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 77 additions and 2690 deletions

View File

@ -40,8 +40,6 @@ public final class CertificateProviderRegistry {
instance = new CertificateProviderRegistry(); instance = new CertificateProviderRegistry();
// TODO(sanjaypujare): replace with Java's SPI mechanism and META-INF resource // TODO(sanjaypujare): replace with Java's SPI mechanism and META-INF resource
instance.register(new FileWatcherCertificateProviderProvider()); instance.register(new FileWatcherCertificateProviderProvider());
instance.register(new DynamicReloadingCertificateProviderProvider());
instance.register(new MeshCaCertificateProviderProvider());
} }
return instance; return instance;
} }

View File

@ -1,225 +0,0 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.InternalLogId;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.sds.trust.CertificateUtils;
import java.io.File;
import java.io.FileInputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
/** Implementation of {@link CertificateProvider} for dynamic reloading cert provider. */
final class DynamicReloadingCertificateProvider extends CertificateProvider {
private static final Logger logger =
Logger.getLogger(DynamicReloadingCertificateProvider.class.getName());
private final SynchronizationContext syncContext;
private final ScheduledExecutorService scheduledExecutorService;
private final TimeProvider timeProvider;
private final Path directory;
private final String certFile;
private final String privateKeyFile;
private final String trustFile;
private final long refreshIntervalInSeconds;
@VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle;
private Path lastModifiedTarget;
DynamicReloadingCertificateProvider(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String directory,
String certFile,
String privateKeyFile,
String trustFile,
long refreshIntervalInSeconds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider) {
super(watcher, notifyCertUpdates);
this.scheduledExecutorService =
checkNotNull(scheduledExecutorService, "scheduledExecutorService");
this.timeProvider = checkNotNull(timeProvider, "timeProvider");
this.directory = Paths.get(checkNotNull(directory, "diretory"));
this.certFile = checkNotNull(certFile, "certFile");
this.privateKeyFile = checkNotNull(privateKeyFile, "privateKeyFile");
this.trustFile = checkNotNull(trustFile, "trustFile");
this.refreshIntervalInSeconds = refreshIntervalInSeconds;
this.syncContext = createSynchronizationContext(directory);
}
private SynchronizationContext createSynchronizationContext(String details) {
final InternalLogId logId =
InternalLogId.allocate("DynamicReloadingCertificateProvider", details);
return new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
private boolean panicMode;
@Override
public void uncaughtException(Thread t, Throwable e) {
logger.log(
Level.SEVERE,
"[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!",
e);
panic(e);
}
void panic(final Throwable t) {
if (panicMode) {
// Preserve the first panic information
return;
}
panicMode = true;
close();
}
});
}
@Override
public void start() {
scheduleNextRefreshCertificate(/* delayInSeconds= */0);
}
@Override
public void close() {
if (scheduledHandle != null) {
scheduledHandle.cancel();
scheduledHandle = null;
}
getWatcher().close();
}
private void scheduleNextRefreshCertificate(long delayInSeconds) {
RefreshCertificateTask runnable = new RefreshCertificateTask();
scheduledHandle =
syncContext.schedule(runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService);
}
@VisibleForTesting
void checkAndReloadCertificates() {
try {
Path targetPath = Files.readSymbolicLink(directory);
if (targetPath.equals(lastModifiedTarget)) {
return;
}
try (FileInputStream privateKeyStream =
new FileInputStream(new File(targetPath.toFile(), privateKeyFile));
FileInputStream certsStream =
new FileInputStream(new File(targetPath.toFile(), certFile));
FileInputStream caCertsStream =
new FileInputStream(new File(targetPath.toFile(), trustFile))) {
PrivateKey privateKey = CertificateUtils.getPrivateKey(privateKeyStream);
X509Certificate[] certs = CertificateUtils.toX509Certificates(certsStream);
X509Certificate[] caCerts = CertificateUtils.toX509Certificates(caCertsStream);
getWatcher().updateCertificate(privateKey, Arrays.asList(certs));
getWatcher().updateTrustedRoots(Arrays.asList(caCerts));
}
lastModifiedTarget = targetPath;
} catch (Throwable t) {
generateErrorIfCurrentCertExpired(t);
} finally {
scheduleNextRefreshCertificate(refreshIntervalInSeconds);
}
}
private void generateErrorIfCurrentCertExpired(Throwable t) {
X509Certificate currentCert = getWatcher().getLastIdentityCert();
if (currentCert != null) {
long delaySeconds = computeDelaySecondsToCertExpiry(currentCert);
if (delaySeconds > refreshIntervalInSeconds) {
logger.log(Level.FINER, "reload certificate error", t);
return;
}
// The current cert is going to expire in less than {@link refreshIntervalInSeconds}
// Clear the current cert and notify our watchers thru {@code onError}
getWatcher().clearValues();
}
getWatcher().onError(Status.fromThrowable(t));
}
@SuppressWarnings("JdkObsolete")
private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) {
checkNotNull(lastCert, "lastCert");
return TimeUnit.NANOSECONDS.toSeconds(
TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime())
- timeProvider.currentTimeNanos());
}
@VisibleForTesting
class RefreshCertificateTask implements Runnable {
@Override
public void run() {
checkAndReloadCertificates();
}
}
abstract static class Factory {
private static final Factory DEFAULT_INSTANCE =
new Factory() {
@Override
DynamicReloadingCertificateProvider create(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String directory,
String certFile,
String privateKeyFile,
String trustFile,
long refreshIntervalInSeconds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider) {
return new DynamicReloadingCertificateProvider(
watcher,
notifyCertUpdates,
directory,
certFile,
privateKeyFile,
trustFile,
refreshIntervalInSeconds,
scheduledExecutorService,
timeProvider);
}
};
static Factory getInstance() {
return DEFAULT_INSTANCE;
}
abstract DynamicReloadingCertificateProvider create(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String directory,
String certFile,
String privateKeyFile,
String trustFile,
long refreshIntervalInSeconds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider);
}
}

View File

@ -1,136 +0,0 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.internal.JsonUtil;
import io.grpc.internal.TimeProvider;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
/**
* Provider of {@link DynamicReloadingCertificateProvider}s.
*/
final class DynamicReloadingCertificateProviderProvider implements CertificateProviderProvider {
private static final String DIRECTORY_KEY = "directory";
private static final String CERT_FILE_KEY = "certificate-file";
private static final String KEY_FILE_KEY = "private-key-file";
private static final String ROOT_FILE_KEY = "ca-certificate-file";
private static final String REFRESH_INTERVAL_KEY = "refresh-interval";
@VisibleForTesting static final long REFRESH_INTERVAL_DEFAULT = 600L;
static final String DYNAMIC_RELOADING_PROVIDER_NAME = "gke-cas-certs";
final DynamicReloadingCertificateProvider.Factory dynamicReloadingCertificateProviderFactory;
private final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory;
private final TimeProvider timeProvider;
DynamicReloadingCertificateProviderProvider() {
this(
DynamicReloadingCertificateProvider.Factory.getInstance(),
ScheduledExecutorServiceFactory.DEFAULT_INSTANCE,
TimeProvider.SYSTEM_TIME_PROVIDER);
}
@VisibleForTesting
DynamicReloadingCertificateProviderProvider(
DynamicReloadingCertificateProvider.Factory dynamicReloadingCertificateProviderFactory,
ScheduledExecutorServiceFactory scheduledExecutorServiceFactory,
TimeProvider timeProvider) {
this.dynamicReloadingCertificateProviderFactory = dynamicReloadingCertificateProviderFactory;
this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory;
this.timeProvider = timeProvider;
}
@Override
public String getName() {
return DYNAMIC_RELOADING_PROVIDER_NAME;
}
@Override
public CertificateProvider createCertificateProvider(
Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) {
Config configObj = validateAndTranslateConfig(config);
return dynamicReloadingCertificateProviderFactory.create(
watcher,
notifyCertUpdates,
configObj.directory,
configObj.certFile,
configObj.keyFile,
configObj.rootFile,
configObj.refrehInterval,
scheduledExecutorServiceFactory.create(),
timeProvider);
}
private static String checkForNullAndGet(Map<String, ?> map, String key) {
return checkNotNull(JsonUtil.getString(map, key), "'" + key + "' is required in the config");
}
private static Config validateAndTranslateConfig(Object config) {
checkArgument(config instanceof Map, "Only Map supported for config");
@SuppressWarnings("unchecked") Map<String, ?> map = (Map<String, ?>)config;
Config configObj = new Config();
configObj.directory = checkForNullAndGet(map, DIRECTORY_KEY);
configObj.certFile = checkForNullAndGet(map, CERT_FILE_KEY);
configObj.keyFile = checkForNullAndGet(map, KEY_FILE_KEY);
configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY);
configObj.refrehInterval = JsonUtil.getNumberAsLong(map, REFRESH_INTERVAL_KEY);
if (configObj.refrehInterval == null) {
configObj.refrehInterval = REFRESH_INTERVAL_DEFAULT;
}
return configObj;
}
abstract static class ScheduledExecutorServiceFactory {
private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE =
new ScheduledExecutorServiceFactory() {
@Override
ScheduledExecutorService create() {
return Executors.newSingleThreadScheduledExecutor(
new ThreadFactoryBuilder()
.setNameFormat("dynamicReloading" + "-%d")
.setDaemon(true)
.build());
}
};
abstract ScheduledExecutorService create();
}
/** POJO class for storing various config values. */
@VisibleForTesting
static class Config {
String directory;
String certFile;
String keyFile;
String rootFile;
Long refrehInterval;
}
}

View File

@ -1,500 +0,0 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.Status.Code.ABORTED;
import static io.grpc.Status.Code.CANCELLED;
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
import static io.grpc.Status.Code.INTERNAL;
import static io.grpc.Status.Code.RESOURCE_EXHAUSTED;
import static io.grpc.Status.Code.UNAVAILABLE;
import static io.grpc.Status.Code.UNKNOWN;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.Duration;
import com.google.security.meshca.v1.MeshCertificateRequest;
import com.google.security.meshca.v1.MeshCertificateResponse;
import com.google.security.meshca.v1.MeshCertificateServiceGrpc;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall;
import io.grpc.Grpc;
import io.grpc.InternalLogId;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.TlsChannelCredentials;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.sds.trust.CertificateUtils;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.StringWriter;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.security.auth.x500.X500Principal;
import org.bouncycastle.openssl.jcajce.JcaPEMWriter;
import org.bouncycastle.operator.ContentSigner;
import org.bouncycastle.operator.OperatorCreationException;
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
import org.bouncycastle.pkcs.PKCS10CertificationRequest;
import org.bouncycastle.pkcs.PKCS10CertificationRequestBuilder;
import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder;
import org.bouncycastle.util.io.pem.PemObject;
/** Implementation of {@link CertificateProvider} for the Google Mesh CA. */
final class MeshCaCertificateProvider extends CertificateProvider {
private static final Logger logger = Logger.getLogger(MeshCaCertificateProvider.class.getName());
MeshCaCertificateProvider(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String meshCaUrl,
String zone,
long validitySeconds,
int keySize,
String unused, //TODO(sanjaypujare): to remove during refactoring
String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider,
long renewalGracePeriodSeconds,
int maxRetryAttempts,
GoogleCredentials oauth2Creds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider,
long rpcTimeoutMillis) {
super(watcher, notifyCertUpdates);
this.meshCaUrl = checkNotNull(meshCaUrl, "meshCaUrl");
checkArgument(
validitySeconds > INITIAL_DELAY_SECONDS,
"validitySeconds must be greater than " + INITIAL_DELAY_SECONDS);
this.validitySeconds = validitySeconds;
this.keySize = keySize;
this.signatureAlg = checkNotNull(signatureAlg, "signatureAlg");
this.meshCaChannelFactory = checkNotNull(meshCaChannelFactory, "meshCaChannelFactory");
this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
checkArgument(
renewalGracePeriodSeconds > 0L && renewalGracePeriodSeconds < validitySeconds,
"renewalGracePeriodSeconds should be between 0 and " + validitySeconds);
this.renewalGracePeriodSeconds = renewalGracePeriodSeconds;
checkArgument(maxRetryAttempts >= 0, "maxRetryAttempts must be >= 0");
this.maxRetryAttempts = maxRetryAttempts;
this.oauth2Creds = checkNotNull(oauth2Creds, "oauth2Creds");
this.scheduledExecutorService =
checkNotNull(scheduledExecutorService, "scheduledExecutorService");
this.timeProvider = checkNotNull(timeProvider, "timeProvider");
this.headerInterceptor = new ZoneInfoClientInterceptor(checkNotNull(zone, "zone"));
this.syncContext = createSynchronizationContext(meshCaUrl);
this.rpcTimeoutMillis = rpcTimeoutMillis;
}
private SynchronizationContext createSynchronizationContext(String details) {
final InternalLogId logId = InternalLogId.allocate("MeshCaCertificateProvider", details);
return new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
private boolean panicMode;
@Override
public void uncaughtException(Thread t, Throwable e) {
logger.log(
Level.SEVERE,
"[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!",
e);
panic(e);
}
void panic(final Throwable t) {
if (panicMode) {
// Preserve the first panic information
return;
}
panicMode = true;
close();
}
});
}
@Override
public void start() {
scheduleNextRefreshCertificate(INITIAL_DELAY_SECONDS);
}
@Override
public void close() {
if (scheduledHandle != null) {
scheduledHandle.cancel();
scheduledHandle = null;
}
getWatcher().close();
}
private void scheduleNextRefreshCertificate(long delayInSeconds) {
if (scheduledHandle != null && scheduledHandle.isPending()) {
logger.log(Level.SEVERE, "Pending task found: inconsistent state in scheduledHandle!");
scheduledHandle.cancel();
}
RefreshCertificateTask runnable = new RefreshCertificateTask();
scheduledHandle = syncContext.schedule(
runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService);
}
@VisibleForTesting
void refreshCertificate()
throws NoSuchAlgorithmException, IOException, OperatorCreationException {
long refreshDelaySeconds = computeRefreshSecondsFromCurrentCertExpiry();
ManagedChannel channel = meshCaChannelFactory.createChannel(meshCaUrl);
try {
String uniqueReqIdForAllRetries = UUID.randomUUID().toString();
Duration duration = Duration.newBuilder().setSeconds(validitySeconds).build();
KeyPair keyPair = generateKeyPair();
String csr = generateCsr(keyPair);
MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub =
createStubToMeshCa(channel);
List<X509Certificate> x509Chain = makeRequestWithRetries(stub, uniqueReqIdForAllRetries,
duration, csr);
if (x509Chain != null) {
refreshDelaySeconds =
computeDelaySecondsToCertExpiry(x509Chain.get(0)) - renewalGracePeriodSeconds;
getWatcher().updateCertificate(keyPair.getPrivate(), x509Chain);
getWatcher().updateTrustedRoots(ImmutableList.of(x509Chain.get(x509Chain.size() - 1)));
}
} finally {
shutdownChannel(channel);
scheduleNextRefreshCertificate(refreshDelaySeconds);
}
}
private MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub createStubToMeshCa(
ManagedChannel channel) {
return MeshCertificateServiceGrpc
.newBlockingStub(channel)
.withCallCredentials(MoreCallCredentials.from(oauth2Creds))
.withInterceptors(headerInterceptor);
}
private List<X509Certificate> makeRequestWithRetries(
MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub,
String reqId,
Duration duration,
String csr) {
MeshCertificateRequest request =
MeshCertificateRequest.newBuilder()
.setValidity(duration)
.setCsr(csr)
.setRequestId(reqId)
.build();
BackoffPolicy backoffPolicy = backoffPolicyProvider.get();
Throwable lastException = null;
for (int i = 0; i <= maxRetryAttempts; i++) {
try {
MeshCertificateResponse response =
stub.withDeadlineAfter(rpcTimeoutMillis, TimeUnit.MILLISECONDS)
.createCertificate(request);
return getX509CertificatesFromResponse(response);
} catch (Throwable t) {
if (!retriable(t)) {
generateErrorIfCurrentCertExpired(t);
return null;
}
lastException = t;
sleepForNanos(backoffPolicy.nextBackoffNanos());
}
}
generateErrorIfCurrentCertExpired(lastException);
return null;
}
private void sleepForNanos(long nanos) {
ScheduledFuture<?> future = scheduledExecutorService.schedule(new Runnable() {
@Override
public void run() {
// do nothing
}
}, nanos, TimeUnit.NANOSECONDS);
try {
future.get(nanos, TimeUnit.NANOSECONDS);
} catch (InterruptedException ie) {
logger.log(Level.SEVERE, "Inside sleep", ie);
Thread.currentThread().interrupt();
} catch (ExecutionException | TimeoutException ex) {
logger.log(Level.SEVERE, "Inside sleep", ex);
}
}
private static boolean retriable(Throwable t) {
return RETRIABLE_CODES.contains(Status.fromThrowable(t).getCode());
}
private void generateErrorIfCurrentCertExpired(Throwable t) {
X509Certificate currentCert = getWatcher().getLastIdentityCert();
if (currentCert != null) {
long delaySeconds = computeDelaySecondsToCertExpiry(currentCert);
if (delaySeconds > INITIAL_DELAY_SECONDS) {
return;
}
getWatcher().clearValues();
}
getWatcher().onError(Status.fromThrowable(t));
}
private KeyPair generateKeyPair() throws NoSuchAlgorithmException {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(keySize);
return keyPairGenerator.generateKeyPair();
}
private String generateCsr(KeyPair pair) throws IOException, OperatorCreationException {
PKCS10CertificationRequestBuilder p10Builder =
new JcaPKCS10CertificationRequestBuilder(
new X500Principal("CN=EXAMPLE.COM"), pair.getPublic());
JcaContentSignerBuilder csBuilder = new JcaContentSignerBuilder(signatureAlg);
ContentSigner signer = csBuilder.build(pair.getPrivate());
PKCS10CertificationRequest csr = p10Builder.build(signer);
PemObject pemObject = new PemObject("NEW CERTIFICATE REQUEST", csr.getEncoded());
try (StringWriter str = new StringWriter()) {
try (JcaPEMWriter pemWriter = new JcaPEMWriter(str)) {
pemWriter.writeObject(pemObject);
}
return str.toString();
}
}
/** Compute refresh interval as half of interval to current cert expiry. */
private long computeRefreshSecondsFromCurrentCertExpiry() {
X509Certificate lastCert = getWatcher().getLastIdentityCert();
if (lastCert == null) {
return INITIAL_DELAY_SECONDS;
}
long delayToCertExpirySeconds = computeDelaySecondsToCertExpiry(lastCert) / 2;
return Math.max(delayToCertExpirySeconds, INITIAL_DELAY_SECONDS);
}
@SuppressWarnings("JdkObsolete")
private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) {
checkNotNull(lastCert, "lastCert");
return TimeUnit.NANOSECONDS.toSeconds(
TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime()) - timeProvider
.currentTimeNanos());
}
private static void shutdownChannel(ManagedChannel channel) {
channel.shutdown();
try {
channel.awaitTermination(10, TimeUnit.SECONDS);
} catch (InterruptedException ex) {
logger.log(Level.SEVERE, "awaiting channel Termination", ex);
channel.shutdownNow();
Thread.currentThread().interrupt();
}
}
private List<X509Certificate> getX509CertificatesFromResponse(
MeshCertificateResponse response) throws CertificateException, IOException {
List<String> certChain = response.getCertChainList();
List<X509Certificate> x509Chain = new ArrayList<>(certChain.size());
for (String certString : certChain) {
try (ByteArrayInputStream bais = new ByteArrayInputStream(certString.getBytes(UTF_8))) {
x509Chain.add(CertificateUtils.toX509Certificate(bais));
}
}
return x509Chain;
}
@VisibleForTesting
class RefreshCertificateTask implements Runnable {
@Override
public void run() {
try {
refreshCertificate();
} catch (NoSuchAlgorithmException | OperatorCreationException | IOException ex) {
logger.log(Level.SEVERE, "refreshing certificate", ex);
}
}
}
/** Factory for creating channels to MeshCA sever. */
abstract static class MeshCaChannelFactory {
private static final MeshCaChannelFactory DEFAULT_INSTANCE =
new MeshCaChannelFactory() {
/** Creates a channel to the URL in the given list. */
@Override
ManagedChannel createChannel(String serverUri) {
checkArgument(serverUri != null && !serverUri.isEmpty(), "serverUri is null/empty!");
logger.log(Level.INFO, "Creating channel to {0}", serverUri);
return Grpc.newChannelBuilder(serverUri, TlsChannelCredentials.create())
.keepAliveTime(1, TimeUnit.MINUTES)
.build();
}
};
static MeshCaChannelFactory getInstance() {
return DEFAULT_INSTANCE;
}
/**
* Creates a channel to the server.
*/
abstract ManagedChannel createChannel(String serverUri);
}
/** Factory for creating channels to MeshCA sever. */
abstract static class Factory {
private static final Factory DEFAULT_INSTANCE =
new Factory() {
@Override
MeshCaCertificateProvider create(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String meshCaUrl,
String zone,
long validitySeconds,
int keySize,
String alg,
String signatureAlg,
MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider,
long renewalGracePeriodSeconds,
int maxRetryAttempts,
GoogleCredentials oauth2Creds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider,
long rpcTimeoutMillis) {
return new MeshCaCertificateProvider(
watcher,
notifyCertUpdates,
meshCaUrl,
zone,
validitySeconds,
keySize,
alg,
signatureAlg,
meshCaChannelFactory,
backoffPolicyProvider,
renewalGracePeriodSeconds,
maxRetryAttempts,
oauth2Creds,
scheduledExecutorService,
timeProvider,
rpcTimeoutMillis);
}
};
static Factory getInstance() {
return DEFAULT_INSTANCE;
}
abstract MeshCaCertificateProvider create(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String meshCaUrl,
String zone,
long validitySeconds,
int keySize,
String alg,
String signatureAlg,
MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider,
long renewalGracePeriodSeconds,
int maxRetryAttempts,
GoogleCredentials oauth2Creds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider,
long rpcTimeoutMillis);
}
private class ZoneInfoClientInterceptor implements ClientInterceptor {
private final String zone;
ZoneInfoClientInterceptor(String zone) {
this.zone = zone;
}
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(
next.newCall(method, callOptions)) {
@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
headers.put(KEY_FOR_ZONE_INFO, "location=locations/" + zone);
super.start(responseListener, headers);
}
};
}
}
@VisibleForTesting
static final Metadata.Key<String> KEY_FOR_ZONE_INFO =
Metadata.Key.of("x-goog-request-params", Metadata.ASCII_STRING_MARSHALLER);
@VisibleForTesting
static final long INITIAL_DELAY_SECONDS = 4L;
private static final EnumSet<Status.Code> RETRIABLE_CODES =
EnumSet.of(
CANCELLED,
UNKNOWN,
DEADLINE_EXCEEDED,
RESOURCE_EXHAUSTED,
ABORTED,
INTERNAL,
UNAVAILABLE);
private final SynchronizationContext syncContext;
private final ScheduledExecutorService scheduledExecutorService;
private final int maxRetryAttempts;
private final ZoneInfoClientInterceptor headerInterceptor;
private final BackoffPolicy.Provider backoffPolicyProvider;
private final String meshCaUrl;
private final long validitySeconds;
private final long renewalGracePeriodSeconds;
private final int keySize;
private final String signatureAlg;
private final GoogleCredentials oauth2Creds;
private final TimeProvider timeProvider;
private final MeshCaChannelFactory meshCaChannelFactory;
@VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle;
private final long rpcTimeoutMillis;
}

View File

@ -1,286 +0,0 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static io.grpc.internal.JsonUtil.getObject;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.ExponentialBackoffPolicy;
import io.grpc.internal.JsonUtil;
import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.sts.StsCredentials;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may
* move this out of the internal package and make this an official API in the future.
*/
final class MeshCaCertificateProviderProvider implements CertificateProviderProvider {
private static final String SERVER_CONFIG_KEY = "server";
private static final String MESHCA_URL_KEY = "target_uri";
private static final String RPC_TIMEOUT_SECONDS_KEY = "time_out";
private static final String GKECLUSTER_URL_KEY = "location";
private static final String CERT_VALIDITY_SECONDS_KEY = "certificate_lifetime";
private static final String RENEWAL_GRACE_PERIOD_SECONDS_KEY = "renewal_grace_period";
private static final String KEY_ALGO_KEY = "key_type"; // aka keyType
private static final String KEY_SIZE_KEY = "key_size";
private static final String STS_SERVICE_KEY = "sts_service";
private static final String TOKEN_EXCHANGE_SERVICE_KEY = "token_exchange_service";
private static final String GKE_SA_JWT_LOCATION_KEY = "subject_token_path";
@VisibleForTesting static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
@VisibleForTesting static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
@VisibleForTesting static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L;
@VisibleForTesting static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L;
@VisibleForTesting static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType
@VisibleForTesting static final int KEY_SIZE_DEFAULT = 2048;
@VisibleForTesting static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
@VisibleForTesting static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
@VisibleForTesting
static final String STS_URL_DEFAULT = "https://securetoken.googleapis.com/v1/identitybindingtoken";
@VisibleForTesting
static final long RPC_TIMEOUT_SECONDS = 10L;
private static final Pattern CLUSTER_URL_PATTERN = Pattern
.compile(".*/projects/(.*)/(?:locations|zones)/(.*)/clusters/.*");
private static final String TRUST_DOMAIN_SUFFIX = ".svc.id.goog";
private static final String AUDIENCE_PREFIX = "identitynamespace:";
static final String MESH_CA_NAME = "meshCA";
final StsCredentials.Factory stsCredentialsFactory;
final MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
final BackoffPolicy.Provider backoffPolicyProvider;
final MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory;
final TimeProvider timeProvider;
MeshCaCertificateProviderProvider() {
this(
StsCredentials.Factory.getInstance(),
MeshCaCertificateProvider.MeshCaChannelFactory.getInstance(),
new ExponentialBackoffPolicy.Provider(),
MeshCaCertificateProvider.Factory.getInstance(),
ScheduledExecutorServiceFactory.DEFAULT_INSTANCE,
TimeProvider.SYSTEM_TIME_PROVIDER);
}
@VisibleForTesting
MeshCaCertificateProviderProvider(
StsCredentials.Factory stsCredentialsFactory,
MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider,
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory,
ScheduledExecutorServiceFactory scheduledExecutorServiceFactory,
TimeProvider timeProvider) {
this.stsCredentialsFactory = stsCredentialsFactory;
this.meshCaChannelFactory = meshCaChannelFactory;
this.backoffPolicyProvider = backoffPolicyProvider;
this.meshCaCertificateProviderFactory = meshCaCertificateProviderFactory;
this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory;
this.timeProvider = timeProvider;
}
@Override
public String getName() {
return MESH_CA_NAME;
}
@Override
public CertificateProvider createCertificateProvider(
Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) {
Config configObj = validateAndTranslateConfig(config);
// Construct audience from project and gkeClusterUrl
String audience =
AUDIENCE_PREFIX + configObj.project + TRUST_DOMAIN_SUFFIX + ":" + configObj.gkeClusterUrl;
StsCredentials stsCredentials = stsCredentialsFactory
.create(configObj.stsUrl, audience, configObj.gkeSaJwtLocation);
return meshCaCertificateProviderFactory.create(
watcher,
notifyCertUpdates,
configObj.meshCaUrl,
configObj.zone,
configObj.certValiditySeconds,
configObj.keySize,
configObj.keyAlgo,
configObj.signatureAlgo,
meshCaChannelFactory,
backoffPolicyProvider,
configObj.renewalGracePeriodSeconds,
configObj.maxRetryAttempts,
stsCredentials,
scheduledExecutorServiceFactory.create(configObj.meshCaUrl),
timeProvider,
TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS));
}
private static Config validateAndTranslateConfig(Object config) {
// TODO(sanjaypujare): add support for string, struct proto etc
checkArgument(config instanceof Map, "Only Map supported for config");
@SuppressWarnings("unchecked") Map<String, ?> map = (Map<String, ?>)config;
Config configObj = new Config();
extractMeshCaServerConfig(configObj, getObject(map, SERVER_CONFIG_KEY));
configObj.certValiditySeconds =
getSeconds(
JsonUtil.getObject(map, CERT_VALIDITY_SECONDS_KEY), CERT_VALIDITY_SECONDS_DEFAULT);
configObj.renewalGracePeriodSeconds =
getSeconds(
JsonUtil.getObject(map, RENEWAL_GRACE_PERIOD_SECONDS_KEY),
RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT);
String keyType = JsonUtil.getString(map, KEY_ALGO_KEY);
checkArgument(
keyType == null || keyType.equals(KEY_ALGO_DEFAULT), "key_type can only be null or 'RSA'");
// TODO: remove signatureAlgo, keyType (or keyAlgo), maxRetryAttempts
configObj.maxRetryAttempts = MAX_RETRY_ATTEMPTS_DEFAULT;
configObj.keyAlgo = KEY_ALGO_DEFAULT;
configObj.signatureAlgo = SIGNATURE_ALGO_DEFAULT;
configObj.keySize = JsonUtil.getNumberAsInteger(map, KEY_SIZE_KEY);
if (configObj.keySize == null) {
configObj.keySize = KEY_SIZE_DEFAULT;
}
configObj.gkeClusterUrl =
checkNotNull(JsonUtil.getString(map, GKECLUSTER_URL_KEY),
"'location' is required in the config");
parseProjectAndZone(configObj.gkeClusterUrl, configObj);
return configObj;
}
private static void extractMeshCaServerConfig(Config configObj, Map<String, ?> serverConfig) {
// init with defaults
configObj.meshCaUrl = MESHCA_URL_DEFAULT;
configObj.rpcTimeoutSeconds = RPC_TIMEOUT_SECONDS_DEFAULT;
configObj.stsUrl = STS_URL_DEFAULT;
if (serverConfig != null) {
checkArgument(
"GRPC".equals(JsonUtil.getString(serverConfig, "api_type")),
"Only GRPC api_type supported");
List<Map<String, ?>> grpcServices =
checkNotNull(
JsonUtil.getListOfObjects(serverConfig, "grpc_services"), "grpc_services not found");
for (Map<String, ?> grpcService : grpcServices) {
Map<String, ?> googleGrpcConfig = JsonUtil.getObject(grpcService, "google_grpc");
if (googleGrpcConfig != null) {
String value = JsonUtil.getString(googleGrpcConfig, MESHCA_URL_KEY);
if (value != null) {
configObj.meshCaUrl = value;
}
Map<String, ?> channelCreds =
JsonUtil.getObject(googleGrpcConfig, "channel_credentials");
if (channelCreds != null) {
Map<String, ?> googleDefaultChannelCreds =
checkNotNull(
JsonUtil.getObject(channelCreds, "google_default"),
"channel_credentials need to be google_default!");
checkArgument(
googleDefaultChannelCreds.isEmpty(),
"google_default credentials contain illegal value");
}
List<Map<String, ?>> callCreds =
JsonUtil.getListOfObjects(googleGrpcConfig, "call_credentials");
for (Map<String, ?> callCred : callCreds) {
Map<String, ?> stsCreds = JsonUtil.getObject(callCred, STS_SERVICE_KEY);
if (stsCreds != null) {
value = JsonUtil.getString(stsCreds, TOKEN_EXCHANGE_SERVICE_KEY);
if (value != null) {
configObj.stsUrl = value;
}
configObj.gkeSaJwtLocation = JsonUtil.getString(stsCreds, GKE_SA_JWT_LOCATION_KEY);
}
}
configObj.rpcTimeoutSeconds =
getSeconds(
JsonUtil.getObject(grpcService, RPC_TIMEOUT_SECONDS_KEY),
RPC_TIMEOUT_SECONDS_DEFAULT);
}
}
}
// check required value(s)
checkNotNull(configObj.gkeSaJwtLocation, "'subject_token_path' is required in the config");
}
private static Long getSeconds(Map<String,?> duration, long defaultValue) {
if (duration != null) {
return JsonUtil.getNumberAsLong(duration, "seconds");
}
return defaultValue;
}
private static void parseProjectAndZone(String gkeClusterUrl, Config configObj) {
Matcher matcher = CLUSTER_URL_PATTERN.matcher(gkeClusterUrl);
checkState(matcher.find(), "gkeClusterUrl does not have correct format");
checkState(matcher.groupCount() == 2, "gkeClusterUrl does not have project and location parts");
configObj.project = matcher.group(1);
configObj.zone = matcher.group(2);
}
abstract static class ScheduledExecutorServiceFactory {
private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE =
new ScheduledExecutorServiceFactory() {
@Override
ScheduledExecutorService create(String serverUri) {
return Executors.newSingleThreadScheduledExecutor(
new ThreadFactoryBuilder()
.setNameFormat("meshca-" + serverUri + "-%d")
.setDaemon(true)
.build());
}
};
static ScheduledExecutorServiceFactory getInstance() {
return DEFAULT_INSTANCE;
}
abstract ScheduledExecutorService create(String serverUri);
}
/** POJO class for storing various config values. */
@VisibleForTesting
static class Config {
String meshCaUrl;
Long rpcTimeoutSeconds;
String gkeClusterUrl;
Long certValiditySeconds;
Long renewalGracePeriodSeconds;
String keyAlgo; // aka keyType
Integer keySize;
String signatureAlgo;
Integer maxRetryAttempts;
String stsUrl;
String gkeSaJwtLocation;
String zone;
String project;
}
}

View File

@ -1,219 +0,0 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.internal.JsonParser;
import io.grpc.internal.TimeProvider;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/** Unit tests for {@link DynamicReloadingCertificateProviderProvider}. */
@RunWith(JUnit4.class)
public class DynamicReloadingCertificateProviderProviderTest {
@Mock DynamicReloadingCertificateProvider.Factory dynamicReloadingCertificateProviderFactory;
@Mock private DynamicReloadingCertificateProviderProvider.ScheduledExecutorServiceFactory
scheduledExecutorServiceFactory;
@Mock private TimeProvider timeProvider;
private DynamicReloadingCertificateProviderProvider provider;
@Before
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);
provider =
new DynamicReloadingCertificateProviderProvider(
dynamicReloadingCertificateProviderFactory,
scheduledExecutorServiceFactory,
timeProvider);
}
@Test
public void providerRegisteredName() {
CertificateProviderProvider certProviderProvider =
CertificateProviderRegistry.getInstance()
.getProvider(
DynamicReloadingCertificateProviderProvider.DYNAMIC_RELOADING_PROVIDER_NAME);
assertThat(certProviderProvider)
.isInstanceOf(DynamicReloadingCertificateProviderProvider.class);
DynamicReloadingCertificateProviderProvider dynamicReloadingCertificateProviderProvider =
(DynamicReloadingCertificateProviderProvider) certProviderProvider;
assertThat(
dynamicReloadingCertificateProviderProvider.dynamicReloadingCertificateProviderFactory)
.isSameInstanceAs(DynamicReloadingCertificateProvider.Factory.getInstance());
}
@Test
public void createProvider_minimalConfig() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MINIMAL_DYNAMIC_RELOADING_CONFIG);
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
when(scheduledExecutorServiceFactory.create()).thenReturn(mockService);
provider.createCertificateProvider(map, distWatcher, true);
verify(dynamicReloadingCertificateProviderFactory, times(1))
.create(
eq(distWatcher),
eq(true),
eq("/var/run/gke-spiffe/certs/..data"),
eq("certificates.pem"),
eq("private_key.pem"),
eq("ca_certificates.pem"),
eq(600L),
eq(mockService),
eq(timeProvider));
}
@Test
public void createProvider_fullConfig() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(FULL_DYNAMIC_RELOADING_CONFIG);
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
when(scheduledExecutorServiceFactory.create()).thenReturn(mockService);
provider.createCertificateProvider(map, distWatcher, true);
verify(dynamicReloadingCertificateProviderFactory, times(1))
.create(
eq(distWatcher),
eq(true),
eq("/var/run/gke-spiffe/certs/..data1"),
eq("certificates2.pem"),
eq("private_key3.pem"),
eq("ca_certificates4.pem"),
eq(7890L),
eq(mockService),
eq(timeProvider));
}
@Test
public void createProvider_missingDir_expectException() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MISSING_DIR_CONFIG);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException npe) {
assertThat(npe).hasMessageThat().isEqualTo("'directory' is required in the config");
}
}
@Test
public void createProvider_missingCert_expectException() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MISSING_CERT_CONFIG);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException npe) {
assertThat(npe).hasMessageThat().isEqualTo("'certificate-file' is required in the config");
}
}
@Test
public void createProvider_missingKey_expectException() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MISSING_KEY_CONFIG);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException npe) {
assertThat(npe).hasMessageThat().isEqualTo("'private-key-file' is required in the config");
}
}
@Test
public void createProvider_missingRoot_expectException() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MISSING_ROOT_CONFIG);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException npe) {
assertThat(npe).hasMessageThat().isEqualTo("'ca-certificate-file' is required in the config");
}
}
private static final String MINIMAL_DYNAMIC_RELOADING_CONFIG =
"{\n"
+ " \"directory\": \"/var/run/gke-spiffe/certs/..data\","
+ " \"certificate-file\": \"certificates.pem\","
+ " \"private-key-file\": \"private_key.pem\","
+ " \"ca-certificate-file\": \"ca_certificates.pem\""
+ " }";
private static final String FULL_DYNAMIC_RELOADING_CONFIG =
"{\n"
+ " \"directory\": \"/var/run/gke-spiffe/certs/..data1\","
+ " \"certificate-file\": \"certificates2.pem\","
+ " \"private-key-file\": \"private_key3.pem\","
+ " \"ca-certificate-file\": \"ca_certificates4.pem\","
+ " \"refresh-interval\": 7890"
+ " }";
private static final String MISSING_DIR_CONFIG =
"{\n"
+ " \"certificate-file\": \"certificates.pem\","
+ " \"private-key-file\": \"private_key.pem\","
+ " \"ca-certificate-file\": \"ca_certificates.pem\""
+ " }";
private static final String MISSING_CERT_CONFIG =
"{\n"
+ " \"directory\": \"/var/run/gke-spiffe/certs/..data\","
+ " \"private-key-file\": \"private_key.pem\","
+ " \"ca-certificate-file\": \"ca_certificates.pem\""
+ " }";
private static final String MISSING_KEY_CONFIG =
"{\n"
+ " \"directory\": \"/var/run/gke-spiffe/certs/..data\","
+ " \"certificate-file\": \"certificates.pem\","
+ " \"ca-certificate-file\": \"ca_certificates.pem\""
+ " }";
private static final String MISSING_ROOT_CONFIG =
"{\n"
+ " \"directory\": \"/var/run/gke-spiffe/certs/..data\","
+ " \"certificate-file\": \"certificates.pem\","
+ " \"private-key-file\": \"private_key.pem\""
+ " }";
}

View File

@ -1,303 +0,0 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.Status;
import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.NoSuchFileException;
import java.nio.file.Paths;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/** Unit tests for {@link DynamicReloadingCertificateProvider}. */
@RunWith(JUnit4.class)
public class DynamicReloadingCertificateProviderTest {
private static final String CERT_FILE = "cert.pem";
private static final String KEY_FILE = "key.pem";
private static final String ROOT_FILE = "root.pem";
@Mock private CertificateProvider.Watcher mockWatcher;
@Mock private ScheduledExecutorService timeService;
@Mock private TimeProvider timeProvider;
@Rule public TemporaryFolder tempFolder = new TemporaryFolder();
private String symlink;
private DynamicReloadingCertificateProvider provider;
@Before
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);
DistributorWatcher watcher = new DistributorWatcher();
watcher.addWatcher(mockWatcher);
symlink = new File(tempFolder.getRoot(), "..data").getAbsolutePath();
provider =
new DynamicReloadingCertificateProvider(
watcher,
true,
symlink,
CERT_FILE,
KEY_FILE,
ROOT_FILE,
600L,
timeService,
timeProvider);
}
private void populateTarget(
String certFile, String keyFile, String rootFile, boolean deleteExisting, boolean createNew)
throws IOException {
String target = tempFolder.newFolder().getAbsolutePath();
if (certFile != null) {
certFile = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(certFile);
Files.copy(Paths.get(certFile), Paths.get(target, CERT_FILE));
}
if (keyFile != null) {
keyFile = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(keyFile);
Files.copy(Paths.get(keyFile), Paths.get(target, KEY_FILE));
}
if (rootFile != null) {
rootFile = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(rootFile);
Files.copy(Paths.get(rootFile), Paths.get(target, ROOT_FILE));
}
if (deleteExisting) {
Files.delete(Paths.get(symlink));
}
if (createNew) {
Files.createSymbolicLink(Paths.get(symlink), Paths.get(target));
}
}
@Test
public void getCertificateAndCheckUpdates() throws IOException, CertificateException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, true);
provider.checkAndReloadCertificates();
verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE);
verifyTimeServiceAndScheduledHandle();
reset(mockWatcher, timeService);
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.scheduledHandle.cancel();
provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates(null, null, (String[]) null);
verifyTimeServiceAndScheduledHandle();
reset(mockWatcher, timeService);
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.scheduledHandle.cancel();
populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, true, true);
provider.checkAndReloadCertificates();
verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE);
verifyTimeServiceAndScheduledHandle();
}
@Test
public void getCertificate_initialMissingCertFile() throws IOException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, false, true);
when(timeProvider.currentTimeNanos())
.thenReturn(TimeProvider.SYSTEM_TIME_PROVIDER.currentTimeNanos());
provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates(Status.Code.UNKNOWN, java.io.FileNotFoundException.class, "cert.pem");
}
@Test
public void getCertificate_missingSymlink() throws IOException {
commonErrorTest(null, null, null, true, false, NoSuchFileException.class, "..data");
}
@Test
public void getCertificate_missingCertFile() throws IOException {
commonErrorTest(
null,
CLIENT_KEY_FILE,
CA_PEM_FILE,
true,
true,
java.io.FileNotFoundException.class,
"cert.pem");
}
@Test
public void getCertificate_missingKeyFile() throws IOException {
commonErrorTest(
CLIENT_PEM_FILE,
null,
CA_PEM_FILE,
true,
true,
java.io.FileNotFoundException.class,
"key.pem");
}
@Test
public void getCertificate_badKeyFile() throws IOException {
commonErrorTest(
CLIENT_PEM_FILE,
SERVER_0_PEM_FILE,
CA_PEM_FILE,
true,
true,
java.security.KeyException.class,
"could not find a PKCS #8 private key in input stream");
}
@Test
public void getCertificate_missingRootFile() throws IOException {
commonErrorTest(
CLIENT_PEM_FILE,
CLIENT_KEY_FILE,
null,
true,
true,
java.io.FileNotFoundException.class,
"root.pem");
}
private void commonErrorTest(
String certFile,
String keyFile,
String rootFile,
boolean deleteExisting,
boolean createNew,
Class<?> throwableType,
String... causeMessages)
throws IOException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, true);
provider.checkAndReloadCertificates();
reset(mockWatcher);
populateTarget(certFile, keyFile, rootFile, deleteExisting, createNew);
when(timeProvider.currentTimeNanos())
.thenReturn(
TimeUnit.MILLISECONDS.toNanos(
MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L));
provider.scheduledHandle.cancel();
provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates(null, null, (String[]) null);
reset(mockWatcher, timeProvider);
when(timeProvider.currentTimeNanos())
.thenReturn(
TimeUnit.MILLISECONDS.toNanos(
MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 590_000L));
provider.scheduledHandle.cancel();
provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates(Status.Code.UNKNOWN, throwableType, causeMessages);
}
private void verifyWatcherErrorUpdates(
Status.Code code, Class<?> throwableType, String... causeMessages) {
verify(mockWatcher, never())
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
if (code == null && throwableType == null && causeMessages == null) {
verify(mockWatcher, never()).onError(any(Status.class));
} else {
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(null);
verify(mockWatcher, times(1)).onError(statusCaptor.capture());
Status status = statusCaptor.getValue();
assertThat(status.getCode()).isEqualTo(code);
Throwable cause = status.getCause();
assertThat(cause).isInstanceOf(throwableType);
for (String causeMessage : causeMessages) {
assertThat(cause).hasMessageThat().contains(causeMessage);
cause = cause.getCause();
}
}
}
private void verifyTimeServiceAndScheduledHandle() {
verify(timeService, times(1)).schedule(any(Runnable.class), eq(600L), eq(TimeUnit.SECONDS));
assertThat(provider.scheduledHandle).isNotNull();
assertThat(provider.scheduledHandle.isPending()).isTrue();
}
private void verifyWatcherUpdates(String certPemFile, String rootPemFile)
throws IOException, CertificateException {
ArgumentCaptor<List<X509Certificate>> certChainCaptor = ArgumentCaptor.forClass(null);
verify(mockWatcher, times(1))
.updateCertificate(any(PrivateKey.class), certChainCaptor.capture());
List<X509Certificate> certChain = certChainCaptor.getValue();
assertThat(certChain).hasSize(1);
assertThat(certChain.get(0))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(certPemFile));
ArgumentCaptor<List<X509Certificate>> rootsCaptor = ArgumentCaptor.forClass(null);
verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture());
List<X509Certificate> roots = rootsCaptor.getValue();
assertThat(roots).hasSize(1);
assertThat(roots.get(0))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(rootPemFile));
verify(mockWatcher, never()).onError(any(Status.class));
}
}

View File

@ -45,8 +45,11 @@ import java.nio.file.Paths;
import java.security.PrivateKey; import java.security.PrivateKey;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.Delayed;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
@ -62,6 +65,10 @@ import org.mockito.MockitoAnnotations;
/** Unit tests for {@link FileWatcherCertificateProvider}. */ /** Unit tests for {@link FileWatcherCertificateProvider}. */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class FileWatcherCertificateProviderTest { public class FileWatcherCertificateProviderTest {
/**
* Expire time of cert SERVER_0_PEM_FILE.
*/
static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L;
private static final String CERT_FILE = "cert.pem"; private static final String CERT_FILE = "cert.pem";
private static final String KEY_FILE = "key.pem"; private static final String KEY_FILE = "key.pem";
private static final String ROOT_FILE = "root.pem"; private static final String ROOT_FILE = "root.pem";
@ -126,8 +133,8 @@ public class FileWatcherCertificateProviderTest {
@Test @Test
public void getCertificateAndCheckUpdates() throws IOException, CertificateException { public void getCertificateAndCheckUpdates() throws IOException, CertificateException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture = TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>(); new TestScheduledFuture<>();
doReturn(scheduledFuture) doReturn(scheduledFuture)
.when(timeService) .when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
@ -147,8 +154,8 @@ public class FileWatcherCertificateProviderTest {
@Test @Test
public void allUpdateSecondTime() throws IOException, CertificateException, InterruptedException { public void allUpdateSecondTime() throws IOException, CertificateException, InterruptedException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture = TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>(); new TestScheduledFuture<>();
doReturn(scheduledFuture) doReturn(scheduledFuture)
.when(timeService) .when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
@ -168,8 +175,8 @@ public class FileWatcherCertificateProviderTest {
@Test @Test
public void closeDoesNotScheduleNext() throws IOException, CertificateException { public void closeDoesNotScheduleNext() throws IOException, CertificateException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture = TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>(); new TestScheduledFuture<>();
doReturn(scheduledFuture) doReturn(scheduledFuture)
.when(timeService) .when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
@ -186,8 +193,8 @@ public class FileWatcherCertificateProviderTest {
@Test @Test
public void rootFileUpdateOnly() throws IOException, CertificateException, InterruptedException { public void rootFileUpdateOnly() throws IOException, CertificateException, InterruptedException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture = TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>(); new TestScheduledFuture<>();
doReturn(scheduledFuture) doReturn(scheduledFuture)
.when(timeService) .when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
@ -208,8 +215,8 @@ public class FileWatcherCertificateProviderTest {
@Test @Test
public void certAndKeyFileUpdateOnly() public void certAndKeyFileUpdateOnly()
throws IOException, CertificateException, InterruptedException { throws IOException, CertificateException, InterruptedException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture = TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>(); new TestScheduledFuture<>();
doReturn(scheduledFuture) doReturn(scheduledFuture)
.when(timeService) .when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
@ -229,8 +236,8 @@ public class FileWatcherCertificateProviderTest {
@Test @Test
public void getCertificate_initialMissingCertFile() throws IOException { public void getCertificate_initialMissingCertFile() throws IOException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture = TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>(); new TestScheduledFuture<>();
doReturn(scheduledFuture) doReturn(scheduledFuture)
.when(timeService) .when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
@ -269,8 +276,8 @@ public class FileWatcherCertificateProviderTest {
@Test @Test
public void getCertificate_missingRootFile() throws IOException, InterruptedException { public void getCertificate_missingRootFile() throws IOException, InterruptedException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture = TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>(); new TestScheduledFuture<>();
doReturn(scheduledFuture) doReturn(scheduledFuture)
.when(timeService) .when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
@ -283,7 +290,7 @@ public class FileWatcherCertificateProviderTest {
when(timeProvider.currentTimeNanos()) when(timeProvider.currentTimeNanos())
.thenReturn( .thenReturn(
TimeUnit.MILLISECONDS.toNanos( TimeUnit.MILLISECONDS.toNanos(
MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L)); CERT0_EXPIRY_TIME_MILLIS - 610_000L));
provider.checkAndReloadCertificates(); provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 1, 0, "root.pem"); verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 1, 0, "root.pem");
} }
@ -299,8 +306,8 @@ public class FileWatcherCertificateProviderTest {
int secondUpdateRootCount, int secondUpdateRootCount,
String... causeMessages) String... causeMessages)
throws IOException, InterruptedException { throws IOException, InterruptedException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture = TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>(); new TestScheduledFuture<>();
doReturn(scheduledFuture) doReturn(scheduledFuture)
.when(timeService) .when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
@ -314,7 +321,7 @@ public class FileWatcherCertificateProviderTest {
when(timeProvider.currentTimeNanos()) when(timeProvider.currentTimeNanos())
.thenReturn( .thenReturn(
TimeUnit.MILLISECONDS.toNanos( TimeUnit.MILLISECONDS.toNanos(
MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L)); CERT0_EXPIRY_TIME_MILLIS - 610_000L));
provider.checkAndReloadCertificates(); provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates( verifyWatcherErrorUpdates(
null, null, firstUpdateCertCount, firstUpdateRootCount, (String[]) null); null, null, firstUpdateCertCount, firstUpdateRootCount, (String[]) null);
@ -323,7 +330,7 @@ public class FileWatcherCertificateProviderTest {
when(timeProvider.currentTimeNanos()) when(timeProvider.currentTimeNanos())
.thenReturn( .thenReturn(
TimeUnit.MILLISECONDS.toNanos( TimeUnit.MILLISECONDS.toNanos(
MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 590_000L)); CERT0_EXPIRY_TIME_MILLIS - 590_000L));
provider.checkAndReloadCertificates(); provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates( verifyWatcherErrorUpdates(
Status.Code.UNKNOWN, Status.Code.UNKNOWN,
@ -392,4 +399,55 @@ public class FileWatcherCertificateProviderTest {
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList()); verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
} }
} }
static class TestScheduledFuture<V> implements ScheduledFuture<V> {
static class Record {
long timeout;
TimeUnit unit;
Record(long timeout, TimeUnit unit) {
this.timeout = timeout;
this.unit = unit;
}
}
ArrayList<Record> calls = new ArrayList<>();
@Override
public long getDelay(TimeUnit unit) {
return 0;
}
@Override
public int compareTo(Delayed o) {
return 0;
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
return false;
}
@Override
public boolean isCancelled() {
return false;
}
@Override
public boolean isDone() {
return false;
}
@Override
public V get() {
return null;
}
@Override
public V get(long timeout, TimeUnit unit) {
calls.add(new Record(timeout, unit));
return null;
}
}
} }

View File

@ -1,409 +0,0 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.certprovider.MeshCaCertificateProviderProvider.RPC_TIMEOUT_SECONDS;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.auth.oauth2.GoogleCredentials;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.ExponentialBackoffPolicy;
import io.grpc.internal.JsonParser;
import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.sts.StsCredentials;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/** Unit tests for {@link MeshCaCertificateProviderProvider}. */
@RunWith(JUnit4.class)
public class MeshCaCertificateProviderProviderTest {
public static final String EXPECTED_AUDIENCE =
"identitynamespace:test-project1.svc.id.goog:https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3";
public static final String EXPECTED_AUDIENCE_V1BETA1_ZONE =
"identitynamespace:test-project1.svc.id.goog:https://container.googleapis.com/v1beta1/projects/test-project1/zones/test-zone2/clusters/test-cluster3";
public static final String TMP_PATH_4 = "/tmp/path4";
public static final String NON_DEFAULT_MESH_CA_URL = "nonDefaultMeshCaUrl";
@Mock
StsCredentials.Factory stsCredentialsFactory;
@Mock
MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
@Mock
BackoffPolicy.Provider backoffPolicyProvider;
@Mock
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
@Mock
private MeshCaCertificateProviderProvider.ScheduledExecutorServiceFactory
scheduledExecutorServiceFactory;
@Mock
private TimeProvider timeProvider;
private MeshCaCertificateProviderProvider provider;
@Before
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);
provider =
new MeshCaCertificateProviderProvider(
stsCredentialsFactory,
meshCaChannelFactory,
backoffPolicyProvider,
meshCaCertificateProviderFactory,
scheduledExecutorServiceFactory,
timeProvider);
}
@Test
public void providerRegisteredName() {
CertificateProviderProvider certProviderProvider = CertificateProviderRegistry.getInstance()
.getProvider(MeshCaCertificateProviderProvider.MESH_CA_NAME);
assertThat(certProviderProvider).isInstanceOf(MeshCaCertificateProviderProvider.class);
MeshCaCertificateProviderProvider meshCaCertificateProviderProvider =
(MeshCaCertificateProviderProvider) certProviderProvider;
assertThat(meshCaCertificateProviderProvider.stsCredentialsFactory)
.isSameInstanceAs(StsCredentials.Factory.getInstance());
assertThat(meshCaCertificateProviderProvider.meshCaChannelFactory)
.isSameInstanceAs(MeshCaCertificateProvider.MeshCaChannelFactory.getInstance());
assertThat(meshCaCertificateProviderProvider.backoffPolicyProvider)
.isInstanceOf(ExponentialBackoffPolicy.Provider.class);
assertThat(meshCaCertificateProviderProvider.meshCaCertificateProviderFactory)
.isSameInstanceAs(MeshCaCertificateProvider.Factory.getInstance());
}
@Test
public void createProvider_minimalConfig() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MINIMAL_MESHCA_CONFIG);
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
when(scheduledExecutorServiceFactory.create(
eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT)))
.thenReturn(mockService);
provider.createCertificateProvider(map, distWatcher, true);
verify(stsCredentialsFactory, times(1))
.create(
eq(MeshCaCertificateProviderProvider.STS_URL_DEFAULT),
eq(EXPECTED_AUDIENCE),
eq("/tmp/path5"));
verify(meshCaCertificateProviderFactory, times(1))
.create(
eq(distWatcher),
eq(true),
eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT),
eq("test-zone2"),
eq(MeshCaCertificateProviderProvider.CERT_VALIDITY_SECONDS_DEFAULT),
eq(MeshCaCertificateProviderProvider.KEY_SIZE_DEFAULT),
eq(MeshCaCertificateProviderProvider.KEY_ALGO_DEFAULT),
eq(MeshCaCertificateProviderProvider.SIGNATURE_ALGO_DEFAULT),
eq(meshCaChannelFactory),
eq(backoffPolicyProvider),
eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT),
eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT),
(GoogleCredentials) isNull(),
eq(mockService),
eq(timeProvider),
eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)));
}
@Test
public void createProvider_minimalConfig_v1beta1AndZone()
throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(V1BETA1_ZONE_MESHCA_CONFIG);
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
when(scheduledExecutorServiceFactory.create(
eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT)))
.thenReturn(mockService);
provider.createCertificateProvider(map, distWatcher, true);
verify(stsCredentialsFactory, times(1))
.create(
eq(MeshCaCertificateProviderProvider.STS_URL_DEFAULT),
eq(EXPECTED_AUDIENCE_V1BETA1_ZONE),
eq("/tmp/path5"));
verify(meshCaCertificateProviderFactory, times(1))
.create(
eq(distWatcher),
eq(true),
eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT),
eq("test-zone2"),
eq(MeshCaCertificateProviderProvider.CERT_VALIDITY_SECONDS_DEFAULT),
eq(MeshCaCertificateProviderProvider.KEY_SIZE_DEFAULT),
eq(MeshCaCertificateProviderProvider.KEY_ALGO_DEFAULT),
eq(MeshCaCertificateProviderProvider.SIGNATURE_ALGO_DEFAULT),
eq(meshCaChannelFactory),
eq(backoffPolicyProvider),
eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT),
eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT),
(GoogleCredentials) isNull(),
eq(mockService),
eq(timeProvider),
eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)));
}
@Test
public void createProvider_missingGkeUrl_expectException()
throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MISSING_GKE_CLUSTER_URL_MESHCA_CONFIG);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException npe) {
assertThat(npe).hasMessageThat().isEqualTo("'location' is required in the config");
}
}
@Test
public void createProvider_missingSaJwtLocation_expectException()
throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MISSING_SAJWT_MESHCA_CONFIG);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException npe) {
assertThat(npe).hasMessageThat().isEqualTo("'subject_token_path' is required in the config");
}
}
@Test
public void createProvider_missingProject_expectException()
throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MINIMAL_BAD_CLUSTER_URL_MESHCA_CONFIG);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (IllegalStateException ex) {
assertThat(ex).hasMessageThat().isEqualTo("gkeClusterUrl does not have correct format");
}
}
@Test
public void createProvider_badChannelCreds_expectException()
throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(BAD_CHANNEL_CREDS_MESHCA_CONFIG);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException ex) {
assertThat(ex).hasMessageThat().isEqualTo("channel_credentials need to be google_default!");
}
}
@Test
public void createProvider_nonDefaultFullConfig() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(NONDEFAULT_MESHCA_CONFIG);
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
when(scheduledExecutorServiceFactory.create(eq(NON_DEFAULT_MESH_CA_URL)))
.thenReturn(mockService);
provider.createCertificateProvider(map, distWatcher, true);
verify(stsCredentialsFactory, times(1))
.create(
eq("test.sts.com"),
eq(EXPECTED_AUDIENCE),
eq(TMP_PATH_4));
verify(meshCaCertificateProviderFactory, times(1))
.create(
eq(distWatcher),
eq(true),
eq(NON_DEFAULT_MESH_CA_URL),
eq("test-zone2"),
eq(234567L),
eq(512),
eq("RSA"),
eq("SHA256withRSA"),
eq(meshCaChannelFactory),
eq(backoffPolicyProvider),
eq(4321L),
eq(3),
(GoogleCredentials) isNull(),
eq(mockService),
eq(timeProvider),
eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)));
}
private static final String NONDEFAULT_MESHCA_CONFIG =
"{\n"
+ " \"server\": {\n"
+ " \"api_type\": \"GRPC\",\n"
+ " \"grpc_services\": [{\n"
+ " \"google_grpc\": {\n"
+ " \"target_uri\": \"nonDefaultMeshCaUrl\",\n"
+ " \"channel_credentials\": {\"google_default\": {}},\n"
+ " \"call_credentials\": [{\n"
+ " \"sts_service\": {\n"
+ " \"token_exchange_service\": \"test.sts.com\",\n"
+ " \"subject_token_path\": \"/tmp/path4\"\n"
+ " }\n"
+ " }]\n" // end call_credentials
+ " },\n" // end google_grpc
+ " \"time_out\": {\"seconds\": 12}\n"
+ " }]\n" // end grpc_services
+ " },\n" // end server
+ " \"certificate_lifetime\": {\"seconds\": 234567},\n"
+ " \"renewal_grace_period\": {\"seconds\": 4321},\n"
+ " \"key_type\": \"RSA\",\n"
+ " \"key_size\": 512,\n"
+ " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n"
+ " }";
private static final String MINIMAL_MESHCA_CONFIG =
"{\n"
+ " \"server\": {\n"
+ " \"api_type\": \"GRPC\",\n"
+ " \"grpc_services\": [{\n"
+ " \"google_grpc\": {\n"
+ " \"call_credentials\": [{\n"
+ " \"sts_service\": {\n"
+ " \"subject_token_path\": \"/tmp/path5\"\n"
+ " }\n"
+ " }]\n" // end call_credentials
+ " }\n" // end google_grpc
+ " }]\n" // end grpc_services
+ " },\n" // end server
+ " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n"
+ " }";
private static final String V1BETA1_ZONE_MESHCA_CONFIG =
"{\n"
+ " \"server\": {\n"
+ " \"api_type\": \"GRPC\",\n"
+ " \"grpc_services\": [{\n"
+ " \"google_grpc\": {\n"
+ " \"call_credentials\": [{\n"
+ " \"sts_service\": {\n"
+ " \"subject_token_path\": \"/tmp/path5\"\n"
+ " }\n"
+ " }]\n" // end call_credentials
+ " }\n" // end google_grpc
+ " }]\n" // end grpc_services
+ " },\n" // end server
+ " \"location\": \"https://container.googleapis.com/v1beta1/projects/test-project1/zones/test-zone2/clusters/test-cluster3\"\n"
+ " }";
private static final String MINIMAL_BAD_CLUSTER_URL_MESHCA_CONFIG =
"{\n"
+ " \"server\": {\n"
+ " \"api_type\": \"GRPC\",\n"
+ " \"grpc_services\": [{\n"
+ " \"google_grpc\": {\n"
+ " \"call_credentials\": [{\n"
+ " \"sts_service\": {\n"
+ " \"subject_token_path\": \"/tmp/path5\"\n"
+ " }\n"
+ " }]\n" // end call_credentials
+ " }\n" // end google_grpc
+ " }]\n" // end grpc_services
+ " },\n" // end server
+ " \"location\": \"https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3\"\n"
+ " }";
private static final String MISSING_SAJWT_MESHCA_CONFIG =
"{\n"
+ " \"server\": {\n"
+ " \"api_type\": \"GRPC\",\n"
+ " \"grpc_services\": [{\n"
+ " \"google_grpc\": {\n"
+ " \"call_credentials\": [{\n"
+ " \"sts_service\": {\n"
+ " }\n"
+ " }]\n" // end call_credentials
+ " }\n" // end google_grpc
+ " }]\n" // end grpc_services
+ " },\n" // end server
+ " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n"
+ " }";
private static final String MISSING_GKE_CLUSTER_URL_MESHCA_CONFIG =
"{\n"
+ " \"server\": {\n"
+ " \"api_type\": \"GRPC\",\n"
+ " \"grpc_services\": [{\n"
+ " \"google_grpc\": {\n"
+ " \"target_uri\": \"meshca.com\",\n"
+ " \"channel_credentials\": {\"google_default\": {}},\n"
+ " \"call_credentials\": [{\n"
+ " \"sts_service\": {\n"
+ " \"token_exchange_service\": \"securetoken.googleapis.com\",\n"
+ " \"subject_token_path\": \"/etc/secret/sajwt.token\"\n"
+ " }\n"
+ " }]\n" // end call_credentials
+ " },\n" // end google_grpc
+ " \"time_out\": {\"seconds\": 10}\n"
+ " }]\n" // end grpc_services
+ " },\n" // end server
+ " \"certificate_lifetime\": {\"seconds\": 86400},\n"
+ " \"renewal_grace_period\": {\"seconds\": 3600},\n"
+ " \"key_type\": \"RSA\",\n"
+ " \"key_size\": 2048\n"
+ " }";
private static final String BAD_CHANNEL_CREDS_MESHCA_CONFIG =
"{\n"
+ " \"server\": {\n"
+ " \"api_type\": \"GRPC\",\n"
+ " \"grpc_services\": [{\n"
+ " \"google_grpc\": {\n"
+ " \"channel_credentials\": {\"mtls\": \"true\"},\n"
+ " \"call_credentials\": [{\n"
+ " \"sts_service\": {\n"
+ " \"subject_token_path\": \"/tmp/path5\"\n"
+ " }\n"
+ " }]\n" // end call_credentials
+ " }\n" // end google_grpc
+ " }]\n" // end grpc_services
+ " },\n" // end server
+ " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n"
+ " }";
}

View File

@ -1,591 +0,0 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.auth.http.AuthHttpConstants;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.security.meshca.v1.MeshCertificateRequest;
import com.google.security.meshca.v1.MeshCertificateResponse;
import com.google.security.meshca.v1.MeshCertificateServiceGrpc;
import io.grpc.Context;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.SynchronizationContext;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.TimeProvider;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.Delayed;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.bouncycastle.operator.OperatorCreationException;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.Spy;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** Unit tests for {@link MeshCaCertificateProvider}. */
@RunWith(JUnit4.class)
public class MeshCaCertificateProviderTest {
private static final String TEST_STS_TOKEN = "test-stsToken";
private static final long RENEWAL_GRACE_PERIOD_SECONDS = TimeUnit.HOURS.toSeconds(1L);
private static final Metadata.Key<String> KEY_FOR_AUTHORIZATION =
Metadata.Key.of(AuthHttpConstants.AUTHORIZATION, Metadata.ASCII_STRING_MARSHALLER);
private static final String ZONE = "us-west2-a";
private static final long START_DELAY = 200_000_000L; // 0.2 seconds
private static final long[] DELAY_VALUES = {START_DELAY, START_DELAY * 2, START_DELAY * 4};
private static final long RPC_TIMEOUT_MILLIS = 1000L;
/**
* Expire time of cert SERVER_0_PEM_FILE.
*/
static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L;
/**
* Cert validity of 12 hours for the above cert.
*/
private static final long CERT0_VALIDITY_MILLIS = TimeUnit.MILLISECONDS
.convert(12, TimeUnit.HOURS);
/**
* Compute current time based on cert expiry and cert validity.
*/
private static final long CURRENT_TIME_NANOS =
TimeUnit.MILLISECONDS.toNanos(CERT0_EXPIRY_TIME_MILLIS - CERT0_VALIDITY_MILLIS);
@Rule
public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
private static class ResponseToSend {
Throwable getThrowable() {
throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName());
}
List<String> getList() {
throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName());
}
}
private static class ResponseThrowable extends ResponseToSend {
final Throwable throwableToSend;
ResponseThrowable(Throwable throwable) {
throwableToSend = throwable;
}
@Override
Throwable getThrowable() {
return throwableToSend;
}
}
private static class ResponseList extends ResponseToSend {
final List<String> listToSend;
ResponseList(List<String> list) {
listToSend = list;
}
@Override
List<String> getList() {
return listToSend;
}
}
private final Queue<MeshCertificateRequest> receivedRequests = new ArrayDeque<>();
private final Queue<String> receivedStsCreds = new ArrayDeque<>();
private final Queue<String> receivedZoneValues = new ArrayDeque<>();
private final Queue<ResponseToSend> responsesToSend = new ArrayDeque<>();
private final Queue<String> oauth2Tokens = new ArrayDeque<>();
private final AtomicBoolean callEnded = new AtomicBoolean(true);
@Mock private MeshCertificateServiceGrpc.MeshCertificateServiceImplBase mockedMeshCaService;
@Mock private CertificateProvider.Watcher mockWatcher;
@Mock private BackoffPolicy.Provider backoffPolicyProvider;
@Mock private BackoffPolicy backoffPolicy;
@Spy private GoogleCredentials oauth2Creds;
@Mock private ScheduledExecutorService timeService;
@Mock private TimeProvider timeProvider;
private ManagedChannel channel;
private MeshCaCertificateProvider provider;
@Before
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);
when(backoffPolicyProvider.get()).thenReturn(backoffPolicy);
when(backoffPolicy.nextBackoffNanos())
.thenReturn(DELAY_VALUES[0], DELAY_VALUES[1], DELAY_VALUES[2]);
doAnswer(
new Answer<AccessToken>() {
@Override
public AccessToken answer(InvocationOnMock invocation) throws Throwable {
return new AccessToken(
oauth2Tokens.poll(), new Date(System.currentTimeMillis() + 1000L));
}
})
.when(oauth2Creds)
.refreshAccessToken();
final String meshCaUri = InProcessServerBuilder.generateName();
MeshCertificateServiceGrpc.MeshCertificateServiceImplBase meshCaServiceImpl =
new MeshCertificateServiceGrpc.MeshCertificateServiceImplBase() {
@Override
public void createCertificate(
MeshCertificateRequest request,
StreamObserver<MeshCertificateResponse> responseObserver) {
assertThat(callEnded.get()).isTrue(); // ensure previous call was ended
callEnded.set(false);
Context.current()
.addListener(
new Context.CancellationListener() {
@Override
public void cancelled(Context context) {
callEnded.set(true);
}
},
MoreExecutors.directExecutor());
receivedRequests.offer(request);
ResponseToSend response = responsesToSend.poll();
if (response instanceof ResponseThrowable) {
responseObserver.onError(response.getThrowable());
} else if (response instanceof ResponseList) {
List<String> certChainInResponse = response.getList();
MeshCertificateResponse responseToSend =
MeshCertificateResponse.newBuilder()
.addAllCertChain(certChainInResponse)
.build();
responseObserver.onNext(responseToSend);
responseObserver.onCompleted();
} else {
callEnded.set(true);
}
}
};
mockedMeshCaService =
mock(
MeshCertificateServiceGrpc.MeshCertificateServiceImplBase.class,
delegatesTo(meshCaServiceImpl));
ServerInterceptor interceptor =
new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
receivedStsCreds.offer(headers.get(KEY_FOR_AUTHORIZATION));
receivedZoneValues.offer(headers.get(MeshCaCertificateProvider.KEY_FOR_ZONE_INFO));
return next.startCall(call, headers);
}
};
cleanupRule.register(
InProcessServerBuilder.forName(meshCaUri)
.addService(mockedMeshCaService)
.intercept(interceptor)
.directExecutor()
.build()
.start());
channel =
cleanupRule.register(InProcessChannelBuilder.forName(meshCaUri).directExecutor().build());
MeshCaCertificateProvider.MeshCaChannelFactory channelFactory =
new MeshCaCertificateProvider.MeshCaChannelFactory() {
@Override
ManagedChannel createChannel(String serverUri) {
assertThat(serverUri).isEqualTo(meshCaUri);
return channel;
}
};
CertificateProvider.DistributorWatcher watcher = new CertificateProvider.DistributorWatcher();
watcher.addWatcher(mockWatcher); //
provider =
new MeshCaCertificateProvider(
watcher,
true,
meshCaUri,
ZONE,
TimeUnit.HOURS.toSeconds(9L),
2048,
"RSA",
"SHA256withRSA",
channelFactory,
backoffPolicyProvider,
RENEWAL_GRACE_PERIOD_SECONDS,
MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT,
oauth2Creds,
timeService,
timeProvider,
RPC_TIMEOUT_MILLIS);
}
@Test
public void startAndClose() {
TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.start();
SynchronizationContext.ScheduledHandle savedScheduledHandle = provider.scheduledHandle;
assertThat(savedScheduledHandle).isNotNull();
assertThat(savedScheduledHandle.isPending()).isTrue();
verify(timeService, times(1))
.schedule(
any(Runnable.class),
eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
eq(TimeUnit.SECONDS));
DistributorWatcher distWatcher = provider.getWatcher();
assertThat(distWatcher.downstreamWatchers).hasSize(1);
PrivateKey mockKey = mock(PrivateKey.class);
X509Certificate mockCert = mock(X509Certificate.class);
distWatcher.updateCertificate(mockKey, ImmutableList.of(mockCert));
distWatcher.updateTrustedRoots(ImmutableList.of(mockCert));
provider.close();
assertThat(provider.scheduledHandle).isNull();
assertThat(savedScheduledHandle.isPending()).isFalse();
assertThat(distWatcher.downstreamWatchers).isEmpty();
assertThat(distWatcher.getLastIdentityCert()).isNull();
}
@Test
public void startTwice_noException() {
TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.start();
SynchronizationContext.ScheduledHandle savedScheduledHandle1 = provider.scheduledHandle;
provider.start();
SynchronizationContext.ScheduledHandle savedScheduledHandle2 = provider.scheduledHandle;
assertThat(savedScheduledHandle2).isNotSameInstanceAs(savedScheduledHandle1);
assertThat(savedScheduledHandle2.isPending()).isTrue();
}
@Test
public void getCertificate()
throws IOException, CertificateException, OperatorCreationException,
NoSuchAlgorithmException {
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
responsesToSend.offer(
new ResponseList(ImmutableList.of(
CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.refreshCertificate();
MeshCertificateRequest receivedReq = receivedRequests.poll();
assertThat(receivedReq.getValidity().getSeconds()).isEqualTo(TimeUnit.HOURS.toSeconds(9L));
// cannot decode CSR: just check the PEM format delimiters
String csr = receivedReq.getCsr();
assertThat(csr).startsWith("-----BEGIN NEW CERTIFICATE REQUEST-----");
verifyReceivedMetadataValues(1);
verify(timeService, times(1))
.schedule(
any(Runnable.class),
eq(
TimeUnit.MILLISECONDS.toSeconds(
CERT0_VALIDITY_MILLIS
- TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
eq(TimeUnit.SECONDS));
verifyMockWatcher();
}
@Test
public void getCertificate_withError()
throws IOException, OperatorCreationException, NoSuchAlgorithmException {
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
responsesToSend
.offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
doReturn(scheduledFuture).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.refreshCertificate();
verify(mockWatcher, never())
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION);
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
eq(TimeUnit.SECONDS));
verifyReceivedMetadataValues(1);
}
@Test
public void getCertificate_withError_withExistingCert()
throws IOException, OperatorCreationException, NoSuchAlgorithmException {
PrivateKey mockKey = mock(PrivateKey.class);
X509Certificate mockCert = mock(X509Certificate.class);
// have current cert expire in 3 hours from current time
long threeHoursFromNowMillis = TimeUnit.NANOSECONDS
.toMillis(CURRENT_TIME_NANOS + TimeUnit.HOURS.toNanos(3));
when(mockCert.getNotAfter()).thenReturn(new Date(threeHoursFromNowMillis));
provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert));
reset(mockWatcher);
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
responsesToSend
.offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
doReturn(scheduledFuture).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.refreshCertificate();
verify(mockWatcher, never())
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, never()).onError(any(Status.class));
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(5400L),
eq(TimeUnit.SECONDS));
assertThat(provider.getWatcher().getLastIdentityCert()).isNotNull();
verifyReceivedMetadataValues(1);
}
@Test
public void getCertificate_withError_withExistingExpiredCert()
throws IOException, OperatorCreationException, NoSuchAlgorithmException {
PrivateKey mockKey = mock(PrivateKey.class);
X509Certificate mockCert = mock(X509Certificate.class);
// have current cert expire in 3 seconds from current time
long threeSecondsFromNowMillis = TimeUnit.NANOSECONDS
.toMillis(CURRENT_TIME_NANOS + TimeUnit.SECONDS.toNanos(3));
when(mockCert.getNotAfter()).thenReturn(new Date(threeSecondsFromNowMillis));
provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert));
reset(mockWatcher);
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
responsesToSend
.offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
doReturn(scheduledFuture).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.refreshCertificate();
verify(mockWatcher, never())
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION);
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
eq(TimeUnit.SECONDS));
assertThat(provider.getWatcher().getLastIdentityCert()).isNull();
verifyReceivedMetadataValues(1);
}
@Test
public void getCertificate_retriesWithErrors()
throws IOException, CertificateException, OperatorCreationException,
NoSuchAlgorithmException {
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
oauth2Tokens.offer(TEST_STS_TOKEN + "1");
oauth2Tokens.offer(TEST_STS_TOKEN + "2");
responsesToSend.offer(new ResponseThrowable(new StatusRuntimeException(Status.UNKNOWN)));
responsesToSend.offer(
new ResponseThrowable(
new Exception(new StatusRuntimeException(Status.RESOURCE_EXHAUSTED))));
responsesToSend.offer(new ResponseList(ImmutableList.of(
CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
doReturn(scheduledFuture).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
TestScheduledFuture<?> scheduledFutureSleep = new TestScheduledFuture<>();
doReturn(scheduledFutureSleep).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS));
provider.refreshCertificate();
assertThat(receivedRequests.size()).isEqualTo(3);
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(TimeUnit.MILLISECONDS.toSeconds(
CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
eq(TimeUnit.SECONDS));
verifyRetriesWithBackoff(scheduledFutureSleep, 2);
verifyMockWatcher();
verifyReceivedMetadataValues(3);
}
@Test
public void getCertificate_retriesWithTimeouts()
throws IOException, CertificateException, OperatorCreationException,
NoSuchAlgorithmException {
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
oauth2Tokens.offer(TEST_STS_TOKEN + "1");
oauth2Tokens.offer(TEST_STS_TOKEN + "2");
oauth2Tokens.offer(TEST_STS_TOKEN + "3");
responsesToSend.offer(new ResponseToSend());
responsesToSend.offer(new ResponseToSend());
responsesToSend.offer(new ResponseToSend());
responsesToSend.offer(new ResponseList(ImmutableList.of(
CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
doReturn(scheduledFuture).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
TestScheduledFuture<?> scheduledFutureSleep = new TestScheduledFuture<>();
doReturn(scheduledFutureSleep).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS));
provider.refreshCertificate();
assertThat(receivedRequests.size()).isEqualTo(4);
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(TimeUnit.MILLISECONDS.toSeconds(
CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
eq(TimeUnit.SECONDS));
verifyRetriesWithBackoff(scheduledFutureSleep, 3);
verifyMockWatcher();
verifyReceivedMetadataValues(4);
}
private void verifyRetriesWithBackoff(
TestScheduledFuture<?> scheduledFutureSleep, int numOfRetries) {
for (int i = 0; i < numOfRetries; i++) {
long delayValue = DELAY_VALUES[i];
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(delayValue),
eq(TimeUnit.NANOSECONDS));
assertThat(scheduledFutureSleep.calls.get(i).timeout).isEqualTo(delayValue);
assertThat(scheduledFutureSleep.calls.get(i).unit).isEqualTo(TimeUnit.NANOSECONDS);
}
}
private void verifyMockWatcher() throws IOException, CertificateException {
ArgumentCaptor<List<X509Certificate>> certChainCaptor = ArgumentCaptor.forClass(null);
verify(mockWatcher, times(1))
.updateCertificate(any(PrivateKey.class), certChainCaptor.capture());
List<X509Certificate> certChain = certChainCaptor.getValue();
assertThat(certChain).hasSize(3);
assertThat(certChain.get(0))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_0_PEM_FILE));
assertThat(certChain.get(1))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_1_PEM_FILE));
assertThat(certChain.get(2))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE));
ArgumentCaptor<List<X509Certificate>> rootsCaptor = ArgumentCaptor.forClass(null);
verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture());
List<X509Certificate> roots = rootsCaptor.getValue();
assertThat(roots).hasSize(1);
assertThat(roots.get(0))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE));
verify(mockWatcher, never()).onError(any(Status.class));
}
private void verifyReceivedMetadataValues(int count) {
assertThat(receivedStsCreds).hasSize(count);
assertThat(receivedZoneValues).hasSize(count);
for (int i = 0; i < count; i++) {
assertThat(receivedStsCreds.poll()).isEqualTo("Bearer " + TEST_STS_TOKEN + i);
assertThat(receivedZoneValues.poll()).isEqualTo("location=locations/us-west2-a");
}
}
static class TestScheduledFuture<V> implements ScheduledFuture<V> {
static class Record {
long timeout;
TimeUnit unit;
Record(long timeout, TimeUnit unit) {
this.timeout = timeout;
this.unit = unit;
}
}
ArrayList<Record> calls = new ArrayList<>();
@Override
public long getDelay(TimeUnit unit) {
return 0;
}
@Override
public int compareTo(Delayed o) {
return 0;
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
return false;
}
@Override
public boolean isCancelled() {
return false;
}
@Override
public boolean isDone() {
return false;
}
@Override
public V get() {
return null;
}
@Override
public V get(long timeout, TimeUnit unit) {
calls.add(new Record(timeout, unit));
return null;
}
}
}