diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java index ac5ca9711f..12eb6f6573 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java @@ -40,8 +40,6 @@ public final class CertificateProviderRegistry { instance = new CertificateProviderRegistry(); // TODO(sanjaypujare): replace with Java's SPI mechanism and META-INF resource instance.register(new FileWatcherCertificateProviderProvider()); - instance.register(new DynamicReloadingCertificateProviderProvider()); - instance.register(new MeshCaCertificateProviderProvider()); } return instance; } diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProvider.java deleted file mode 100644 index af7324f258..0000000000 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProvider.java +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.common.annotations.VisibleForTesting; -import io.grpc.InternalLogId; -import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.sds.trust.CertificateUtils; -import java.io.File; -import java.io.FileInputStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.security.PrivateKey; -import java.security.cert.X509Certificate; -import java.util.Arrays; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import java.util.logging.Level; -import java.util.logging.Logger; - -/** Implementation of {@link CertificateProvider} for dynamic reloading cert provider. */ -final class DynamicReloadingCertificateProvider extends CertificateProvider { - private static final Logger logger = - Logger.getLogger(DynamicReloadingCertificateProvider.class.getName()); - - private final SynchronizationContext syncContext; - private final ScheduledExecutorService scheduledExecutorService; - private final TimeProvider timeProvider; - private final Path directory; - private final String certFile; - private final String privateKeyFile; - private final String trustFile; - private final long refreshIntervalInSeconds; - @VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle; - private Path lastModifiedTarget; - - DynamicReloadingCertificateProvider( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String directory, - String certFile, - String privateKeyFile, - String trustFile, - long refreshIntervalInSeconds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider) { - super(watcher, notifyCertUpdates); - this.scheduledExecutorService = - checkNotNull(scheduledExecutorService, "scheduledExecutorService"); - this.timeProvider = checkNotNull(timeProvider, "timeProvider"); - this.directory = Paths.get(checkNotNull(directory, "diretory")); - this.certFile = checkNotNull(certFile, "certFile"); - this.privateKeyFile = checkNotNull(privateKeyFile, "privateKeyFile"); - this.trustFile = checkNotNull(trustFile, "trustFile"); - this.refreshIntervalInSeconds = refreshIntervalInSeconds; - this.syncContext = createSynchronizationContext(directory); - } - - private SynchronizationContext createSynchronizationContext(String details) { - final InternalLogId logId = - InternalLogId.allocate("DynamicReloadingCertificateProvider", details); - return new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { - private boolean panicMode; - - @Override - public void uncaughtException(Thread t, Throwable e) { - logger.log( - Level.SEVERE, - "[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!", - e); - panic(e); - } - - void panic(final Throwable t) { - if (panicMode) { - // Preserve the first panic information - return; - } - panicMode = true; - close(); - } - }); - } - - @Override - public void start() { - scheduleNextRefreshCertificate(/* delayInSeconds= */0); - } - - @Override - public void close() { - if (scheduledHandle != null) { - scheduledHandle.cancel(); - scheduledHandle = null; - } - getWatcher().close(); - } - - private void scheduleNextRefreshCertificate(long delayInSeconds) { - RefreshCertificateTask runnable = new RefreshCertificateTask(); - scheduledHandle = - syncContext.schedule(runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService); - } - - @VisibleForTesting - void checkAndReloadCertificates() { - try { - Path targetPath = Files.readSymbolicLink(directory); - if (targetPath.equals(lastModifiedTarget)) { - return; - } - try (FileInputStream privateKeyStream = - new FileInputStream(new File(targetPath.toFile(), privateKeyFile)); - FileInputStream certsStream = - new FileInputStream(new File(targetPath.toFile(), certFile)); - FileInputStream caCertsStream = - new FileInputStream(new File(targetPath.toFile(), trustFile))) { - PrivateKey privateKey = CertificateUtils.getPrivateKey(privateKeyStream); - X509Certificate[] certs = CertificateUtils.toX509Certificates(certsStream); - X509Certificate[] caCerts = CertificateUtils.toX509Certificates(caCertsStream); - getWatcher().updateCertificate(privateKey, Arrays.asList(certs)); - getWatcher().updateTrustedRoots(Arrays.asList(caCerts)); - } - lastModifiedTarget = targetPath; - } catch (Throwable t) { - generateErrorIfCurrentCertExpired(t); - } finally { - scheduleNextRefreshCertificate(refreshIntervalInSeconds); - } - } - - private void generateErrorIfCurrentCertExpired(Throwable t) { - X509Certificate currentCert = getWatcher().getLastIdentityCert(); - if (currentCert != null) { - long delaySeconds = computeDelaySecondsToCertExpiry(currentCert); - if (delaySeconds > refreshIntervalInSeconds) { - logger.log(Level.FINER, "reload certificate error", t); - return; - } - // The current cert is going to expire in less than {@link refreshIntervalInSeconds} - // Clear the current cert and notify our watchers thru {@code onError} - getWatcher().clearValues(); - } - getWatcher().onError(Status.fromThrowable(t)); - } - - @SuppressWarnings("JdkObsolete") - private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) { - checkNotNull(lastCert, "lastCert"); - return TimeUnit.NANOSECONDS.toSeconds( - TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime()) - - timeProvider.currentTimeNanos()); - } - - @VisibleForTesting - class RefreshCertificateTask implements Runnable { - @Override - public void run() { - checkAndReloadCertificates(); - } - } - - abstract static class Factory { - private static final Factory DEFAULT_INSTANCE = - new Factory() { - @Override - DynamicReloadingCertificateProvider create( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String directory, - String certFile, - String privateKeyFile, - String trustFile, - long refreshIntervalInSeconds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider) { - return new DynamicReloadingCertificateProvider( - watcher, - notifyCertUpdates, - directory, - certFile, - privateKeyFile, - trustFile, - refreshIntervalInSeconds, - scheduledExecutorService, - timeProvider); - } - }; - - static Factory getInstance() { - return DEFAULT_INSTANCE; - } - - abstract DynamicReloadingCertificateProvider create( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String directory, - String certFile, - String privateKeyFile, - String trustFile, - long refreshIntervalInSeconds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider); - } -} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProvider.java deleted file mode 100644 index 0d1cf50922..0000000000 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProvider.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.util.concurrent.ThreadFactoryBuilder; -import io.grpc.internal.JsonUtil; -import io.grpc.internal.TimeProvider; -import java.util.Map; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; - -/** - * Provider of {@link DynamicReloadingCertificateProvider}s. - */ -final class DynamicReloadingCertificateProviderProvider implements CertificateProviderProvider { - - private static final String DIRECTORY_KEY = "directory"; - private static final String CERT_FILE_KEY = "certificate-file"; - private static final String KEY_FILE_KEY = "private-key-file"; - private static final String ROOT_FILE_KEY = "ca-certificate-file"; - private static final String REFRESH_INTERVAL_KEY = "refresh-interval"; - - @VisibleForTesting static final long REFRESH_INTERVAL_DEFAULT = 600L; - - - static final String DYNAMIC_RELOADING_PROVIDER_NAME = "gke-cas-certs"; - - final DynamicReloadingCertificateProvider.Factory dynamicReloadingCertificateProviderFactory; - private final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory; - private final TimeProvider timeProvider; - - DynamicReloadingCertificateProviderProvider() { - this( - DynamicReloadingCertificateProvider.Factory.getInstance(), - ScheduledExecutorServiceFactory.DEFAULT_INSTANCE, - TimeProvider.SYSTEM_TIME_PROVIDER); - } - - @VisibleForTesting - DynamicReloadingCertificateProviderProvider( - DynamicReloadingCertificateProvider.Factory dynamicReloadingCertificateProviderFactory, - ScheduledExecutorServiceFactory scheduledExecutorServiceFactory, - TimeProvider timeProvider) { - this.dynamicReloadingCertificateProviderFactory = dynamicReloadingCertificateProviderFactory; - this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory; - this.timeProvider = timeProvider; - } - - @Override - public String getName() { - return DYNAMIC_RELOADING_PROVIDER_NAME; - } - - @Override - public CertificateProvider createCertificateProvider( - Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) { - - Config configObj = validateAndTranslateConfig(config); - return dynamicReloadingCertificateProviderFactory.create( - watcher, - notifyCertUpdates, - configObj.directory, - configObj.certFile, - configObj.keyFile, - configObj.rootFile, - configObj.refrehInterval, - scheduledExecutorServiceFactory.create(), - timeProvider); - } - - private static String checkForNullAndGet(Map map, String key) { - return checkNotNull(JsonUtil.getString(map, key), "'" + key + "' is required in the config"); - } - - private static Config validateAndTranslateConfig(Object config) { - checkArgument(config instanceof Map, "Only Map supported for config"); - @SuppressWarnings("unchecked") Map map = (Map)config; - - Config configObj = new Config(); - configObj.directory = checkForNullAndGet(map, DIRECTORY_KEY); - configObj.certFile = checkForNullAndGet(map, CERT_FILE_KEY); - configObj.keyFile = checkForNullAndGet(map, KEY_FILE_KEY); - configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY); - configObj.refrehInterval = JsonUtil.getNumberAsLong(map, REFRESH_INTERVAL_KEY); - if (configObj.refrehInterval == null) { - configObj.refrehInterval = REFRESH_INTERVAL_DEFAULT; - } - return configObj; - } - - abstract static class ScheduledExecutorServiceFactory { - - private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE = - new ScheduledExecutorServiceFactory() { - - @Override - ScheduledExecutorService create() { - return Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder() - .setNameFormat("dynamicReloading" + "-%d") - .setDaemon(true) - .build()); - } - }; - - abstract ScheduledExecutorService create(); - } - - /** POJO class for storing various config values. */ - @VisibleForTesting - static class Config { - String directory; - String certFile; - String keyFile; - String rootFile; - Long refrehInterval; - } -} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java deleted file mode 100644 index dee649a613..0000000000 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java +++ /dev/null @@ -1,500 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.Status.Code.ABORTED; -import static io.grpc.Status.Code.CANCELLED; -import static io.grpc.Status.Code.DEADLINE_EXCEEDED; -import static io.grpc.Status.Code.INTERNAL; -import static io.grpc.Status.Code.RESOURCE_EXHAUSTED; -import static io.grpc.Status.Code.UNAVAILABLE; -import static io.grpc.Status.Code.UNKNOWN; -import static java.nio.charset.StandardCharsets.UTF_8; - -import com.google.auth.oauth2.GoogleCredentials; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import com.google.protobuf.Duration; -import com.google.security.meshca.v1.MeshCertificateRequest; -import com.google.security.meshca.v1.MeshCertificateResponse; -import com.google.security.meshca.v1.MeshCertificateServiceGrpc; -import io.grpc.CallOptions; -import io.grpc.Channel; -import io.grpc.ClientCall; -import io.grpc.ClientInterceptor; -import io.grpc.ForwardingClientCall; -import io.grpc.Grpc; -import io.grpc.InternalLogId; -import io.grpc.ManagedChannel; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.TlsChannelCredentials; -import io.grpc.auth.MoreCallCredentials; -import io.grpc.internal.BackoffPolicy; -import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.sds.trust.CertificateUtils; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.StringWriter; -import java.security.KeyPair; -import java.security.KeyPairGenerator; -import java.security.NoSuchAlgorithmException; -import java.security.cert.CertificateException; -import java.security.cert.X509Certificate; -import java.util.ArrayList; -import java.util.EnumSet; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import java.util.logging.Level; -import java.util.logging.Logger; -import javax.security.auth.x500.X500Principal; -import org.bouncycastle.openssl.jcajce.JcaPEMWriter; -import org.bouncycastle.operator.ContentSigner; -import org.bouncycastle.operator.OperatorCreationException; -import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; -import org.bouncycastle.pkcs.PKCS10CertificationRequest; -import org.bouncycastle.pkcs.PKCS10CertificationRequestBuilder; -import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder; -import org.bouncycastle.util.io.pem.PemObject; - -/** Implementation of {@link CertificateProvider} for the Google Mesh CA. */ -final class MeshCaCertificateProvider extends CertificateProvider { - private static final Logger logger = Logger.getLogger(MeshCaCertificateProvider.class.getName()); - - MeshCaCertificateProvider( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String meshCaUrl, - String zone, - long validitySeconds, - int keySize, - String unused, //TODO(sanjaypujare): to remove during refactoring - String signatureAlg, MeshCaChannelFactory meshCaChannelFactory, - BackoffPolicy.Provider backoffPolicyProvider, - long renewalGracePeriodSeconds, - int maxRetryAttempts, - GoogleCredentials oauth2Creds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider, - long rpcTimeoutMillis) { - super(watcher, notifyCertUpdates); - this.meshCaUrl = checkNotNull(meshCaUrl, "meshCaUrl"); - checkArgument( - validitySeconds > INITIAL_DELAY_SECONDS, - "validitySeconds must be greater than " + INITIAL_DELAY_SECONDS); - this.validitySeconds = validitySeconds; - this.keySize = keySize; - this.signatureAlg = checkNotNull(signatureAlg, "signatureAlg"); - this.meshCaChannelFactory = checkNotNull(meshCaChannelFactory, "meshCaChannelFactory"); - this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); - checkArgument( - renewalGracePeriodSeconds > 0L && renewalGracePeriodSeconds < validitySeconds, - "renewalGracePeriodSeconds should be between 0 and " + validitySeconds); - this.renewalGracePeriodSeconds = renewalGracePeriodSeconds; - checkArgument(maxRetryAttempts >= 0, "maxRetryAttempts must be >= 0"); - this.maxRetryAttempts = maxRetryAttempts; - this.oauth2Creds = checkNotNull(oauth2Creds, "oauth2Creds"); - this.scheduledExecutorService = - checkNotNull(scheduledExecutorService, "scheduledExecutorService"); - this.timeProvider = checkNotNull(timeProvider, "timeProvider"); - this.headerInterceptor = new ZoneInfoClientInterceptor(checkNotNull(zone, "zone")); - this.syncContext = createSynchronizationContext(meshCaUrl); - this.rpcTimeoutMillis = rpcTimeoutMillis; - } - - private SynchronizationContext createSynchronizationContext(String details) { - final InternalLogId logId = InternalLogId.allocate("MeshCaCertificateProvider", details); - return new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { - private boolean panicMode; - - @Override - public void uncaughtException(Thread t, Throwable e) { - logger.log( - Level.SEVERE, - "[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!", - e); - panic(e); - } - - void panic(final Throwable t) { - if (panicMode) { - // Preserve the first panic information - return; - } - panicMode = true; - close(); - } - }); - } - - @Override - public void start() { - scheduleNextRefreshCertificate(INITIAL_DELAY_SECONDS); - } - - @Override - public void close() { - if (scheduledHandle != null) { - scheduledHandle.cancel(); - scheduledHandle = null; - } - getWatcher().close(); - } - - private void scheduleNextRefreshCertificate(long delayInSeconds) { - if (scheduledHandle != null && scheduledHandle.isPending()) { - logger.log(Level.SEVERE, "Pending task found: inconsistent state in scheduledHandle!"); - scheduledHandle.cancel(); - } - RefreshCertificateTask runnable = new RefreshCertificateTask(); - scheduledHandle = syncContext.schedule( - runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService); - } - - @VisibleForTesting - void refreshCertificate() - throws NoSuchAlgorithmException, IOException, OperatorCreationException { - long refreshDelaySeconds = computeRefreshSecondsFromCurrentCertExpiry(); - ManagedChannel channel = meshCaChannelFactory.createChannel(meshCaUrl); - try { - String uniqueReqIdForAllRetries = UUID.randomUUID().toString(); - Duration duration = Duration.newBuilder().setSeconds(validitySeconds).build(); - KeyPair keyPair = generateKeyPair(); - String csr = generateCsr(keyPair); - MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub = - createStubToMeshCa(channel); - List x509Chain = makeRequestWithRetries(stub, uniqueReqIdForAllRetries, - duration, csr); - if (x509Chain != null) { - refreshDelaySeconds = - computeDelaySecondsToCertExpiry(x509Chain.get(0)) - renewalGracePeriodSeconds; - getWatcher().updateCertificate(keyPair.getPrivate(), x509Chain); - getWatcher().updateTrustedRoots(ImmutableList.of(x509Chain.get(x509Chain.size() - 1))); - } - } finally { - shutdownChannel(channel); - scheduleNextRefreshCertificate(refreshDelaySeconds); - } - } - - private MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub createStubToMeshCa( - ManagedChannel channel) { - return MeshCertificateServiceGrpc - .newBlockingStub(channel) - .withCallCredentials(MoreCallCredentials.from(oauth2Creds)) - .withInterceptors(headerInterceptor); - } - - private List makeRequestWithRetries( - MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub, - String reqId, - Duration duration, - String csr) { - MeshCertificateRequest request = - MeshCertificateRequest.newBuilder() - .setValidity(duration) - .setCsr(csr) - .setRequestId(reqId) - .build(); - - BackoffPolicy backoffPolicy = backoffPolicyProvider.get(); - Throwable lastException = null; - for (int i = 0; i <= maxRetryAttempts; i++) { - try { - MeshCertificateResponse response = - stub.withDeadlineAfter(rpcTimeoutMillis, TimeUnit.MILLISECONDS) - .createCertificate(request); - return getX509CertificatesFromResponse(response); - } catch (Throwable t) { - if (!retriable(t)) { - generateErrorIfCurrentCertExpired(t); - return null; - } - lastException = t; - sleepForNanos(backoffPolicy.nextBackoffNanos()); - } - } - generateErrorIfCurrentCertExpired(lastException); - return null; - } - - private void sleepForNanos(long nanos) { - ScheduledFuture future = scheduledExecutorService.schedule(new Runnable() { - @Override - public void run() { - // do nothing - } - }, nanos, TimeUnit.NANOSECONDS); - try { - future.get(nanos, TimeUnit.NANOSECONDS); - } catch (InterruptedException ie) { - logger.log(Level.SEVERE, "Inside sleep", ie); - Thread.currentThread().interrupt(); - } catch (ExecutionException | TimeoutException ex) { - logger.log(Level.SEVERE, "Inside sleep", ex); - } - } - - private static boolean retriable(Throwable t) { - return RETRIABLE_CODES.contains(Status.fromThrowable(t).getCode()); - } - - private void generateErrorIfCurrentCertExpired(Throwable t) { - X509Certificate currentCert = getWatcher().getLastIdentityCert(); - if (currentCert != null) { - long delaySeconds = computeDelaySecondsToCertExpiry(currentCert); - if (delaySeconds > INITIAL_DELAY_SECONDS) { - return; - } - getWatcher().clearValues(); - } - getWatcher().onError(Status.fromThrowable(t)); - } - - private KeyPair generateKeyPair() throws NoSuchAlgorithmException { - KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); - keyPairGenerator.initialize(keySize); - return keyPairGenerator.generateKeyPair(); - } - - private String generateCsr(KeyPair pair) throws IOException, OperatorCreationException { - PKCS10CertificationRequestBuilder p10Builder = - new JcaPKCS10CertificationRequestBuilder( - new X500Principal("CN=EXAMPLE.COM"), pair.getPublic()); - JcaContentSignerBuilder csBuilder = new JcaContentSignerBuilder(signatureAlg); - ContentSigner signer = csBuilder.build(pair.getPrivate()); - PKCS10CertificationRequest csr = p10Builder.build(signer); - PemObject pemObject = new PemObject("NEW CERTIFICATE REQUEST", csr.getEncoded()); - try (StringWriter str = new StringWriter()) { - try (JcaPEMWriter pemWriter = new JcaPEMWriter(str)) { - pemWriter.writeObject(pemObject); - } - return str.toString(); - } - } - - /** Compute refresh interval as half of interval to current cert expiry. */ - private long computeRefreshSecondsFromCurrentCertExpiry() { - X509Certificate lastCert = getWatcher().getLastIdentityCert(); - if (lastCert == null) { - return INITIAL_DELAY_SECONDS; - } - long delayToCertExpirySeconds = computeDelaySecondsToCertExpiry(lastCert) / 2; - return Math.max(delayToCertExpirySeconds, INITIAL_DELAY_SECONDS); - } - - @SuppressWarnings("JdkObsolete") - private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) { - checkNotNull(lastCert, "lastCert"); - return TimeUnit.NANOSECONDS.toSeconds( - TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime()) - timeProvider - .currentTimeNanos()); - } - - private static void shutdownChannel(ManagedChannel channel) { - channel.shutdown(); - try { - channel.awaitTermination(10, TimeUnit.SECONDS); - } catch (InterruptedException ex) { - logger.log(Level.SEVERE, "awaiting channel Termination", ex); - channel.shutdownNow(); - Thread.currentThread().interrupt(); - } - } - - private List getX509CertificatesFromResponse( - MeshCertificateResponse response) throws CertificateException, IOException { - List certChain = response.getCertChainList(); - List x509Chain = new ArrayList<>(certChain.size()); - for (String certString : certChain) { - try (ByteArrayInputStream bais = new ByteArrayInputStream(certString.getBytes(UTF_8))) { - x509Chain.add(CertificateUtils.toX509Certificate(bais)); - } - } - return x509Chain; - } - - @VisibleForTesting - class RefreshCertificateTask implements Runnable { - @Override - public void run() { - try { - refreshCertificate(); - } catch (NoSuchAlgorithmException | OperatorCreationException | IOException ex) { - logger.log(Level.SEVERE, "refreshing certificate", ex); - } - } - } - - /** Factory for creating channels to MeshCA sever. */ - abstract static class MeshCaChannelFactory { - - private static final MeshCaChannelFactory DEFAULT_INSTANCE = - new MeshCaChannelFactory() { - - /** Creates a channel to the URL in the given list. */ - @Override - ManagedChannel createChannel(String serverUri) { - checkArgument(serverUri != null && !serverUri.isEmpty(), "serverUri is null/empty!"); - logger.log(Level.INFO, "Creating channel to {0}", serverUri); - - return Grpc.newChannelBuilder(serverUri, TlsChannelCredentials.create()) - .keepAliveTime(1, TimeUnit.MINUTES) - .build(); - } - }; - - static MeshCaChannelFactory getInstance() { - return DEFAULT_INSTANCE; - } - - /** - * Creates a channel to the server. - */ - abstract ManagedChannel createChannel(String serverUri); - } - - /** Factory for creating channels to MeshCA sever. */ - abstract static class Factory { - private static final Factory DEFAULT_INSTANCE = - new Factory() { - - @Override - MeshCaCertificateProvider create( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String meshCaUrl, - String zone, - long validitySeconds, - int keySize, - String alg, - String signatureAlg, - MeshCaChannelFactory meshCaChannelFactory, - BackoffPolicy.Provider backoffPolicyProvider, - long renewalGracePeriodSeconds, - int maxRetryAttempts, - GoogleCredentials oauth2Creds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider, - long rpcTimeoutMillis) { - return new MeshCaCertificateProvider( - watcher, - notifyCertUpdates, - meshCaUrl, - zone, - validitySeconds, - keySize, - alg, - signatureAlg, - meshCaChannelFactory, - backoffPolicyProvider, - renewalGracePeriodSeconds, - maxRetryAttempts, - oauth2Creds, - scheduledExecutorService, - timeProvider, - rpcTimeoutMillis); - } - }; - - static Factory getInstance() { - return DEFAULT_INSTANCE; - } - - abstract MeshCaCertificateProvider create( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String meshCaUrl, - String zone, - long validitySeconds, - int keySize, - String alg, - String signatureAlg, - MeshCaChannelFactory meshCaChannelFactory, - BackoffPolicy.Provider backoffPolicyProvider, - long renewalGracePeriodSeconds, - int maxRetryAttempts, - GoogleCredentials oauth2Creds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider, - long rpcTimeoutMillis); - } - - private class ZoneInfoClientInterceptor implements ClientInterceptor { - private final String zone; - - ZoneInfoClientInterceptor(String zone) { - this.zone = zone; - } - - @Override - public ClientCall interceptCall( - MethodDescriptor method, CallOptions callOptions, Channel next) { - return new ForwardingClientCall.SimpleForwardingClientCall( - next.newCall(method, callOptions)) { - - @Override - public void start(Listener responseListener, Metadata headers) { - headers.put(KEY_FOR_ZONE_INFO, "location=locations/" + zone); - super.start(responseListener, headers); - } - }; - } - } - - @VisibleForTesting - static final Metadata.Key KEY_FOR_ZONE_INFO = - Metadata.Key.of("x-goog-request-params", Metadata.ASCII_STRING_MARSHALLER); - @VisibleForTesting - static final long INITIAL_DELAY_SECONDS = 4L; - - private static final EnumSet RETRIABLE_CODES = - EnumSet.of( - CANCELLED, - UNKNOWN, - DEADLINE_EXCEEDED, - RESOURCE_EXHAUSTED, - ABORTED, - INTERNAL, - UNAVAILABLE); - - private final SynchronizationContext syncContext; - private final ScheduledExecutorService scheduledExecutorService; - private final int maxRetryAttempts; - private final ZoneInfoClientInterceptor headerInterceptor; - private final BackoffPolicy.Provider backoffPolicyProvider; - private final String meshCaUrl; - private final long validitySeconds; - private final long renewalGracePeriodSeconds; - private final int keySize; - private final String signatureAlg; - private final GoogleCredentials oauth2Creds; - private final TimeProvider timeProvider; - private final MeshCaChannelFactory meshCaChannelFactory; - @VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle; - private final long rpcTimeoutMillis; -} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java deleted file mode 100644 index a605f15ae6..0000000000 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; -import static io.grpc.internal.JsonUtil.getObject; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.util.concurrent.ThreadFactoryBuilder; -import io.grpc.internal.BackoffPolicy; -import io.grpc.internal.ExponentialBackoffPolicy; -import io.grpc.internal.JsonUtil; -import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.sts.StsCredentials; -import java.util.List; -import java.util.Map; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -/** - * Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may - * move this out of the internal package and make this an official API in the future. - */ -final class MeshCaCertificateProviderProvider implements CertificateProviderProvider { - - private static final String SERVER_CONFIG_KEY = "server"; - private static final String MESHCA_URL_KEY = "target_uri"; - private static final String RPC_TIMEOUT_SECONDS_KEY = "time_out"; - private static final String GKECLUSTER_URL_KEY = "location"; - private static final String CERT_VALIDITY_SECONDS_KEY = "certificate_lifetime"; - private static final String RENEWAL_GRACE_PERIOD_SECONDS_KEY = "renewal_grace_period"; - private static final String KEY_ALGO_KEY = "key_type"; // aka keyType - private static final String KEY_SIZE_KEY = "key_size"; - private static final String STS_SERVICE_KEY = "sts_service"; - private static final String TOKEN_EXCHANGE_SERVICE_KEY = "token_exchange_service"; - private static final String GKE_SA_JWT_LOCATION_KEY = "subject_token_path"; - - @VisibleForTesting static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com"; - @VisibleForTesting static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L; - @VisibleForTesting static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L; - @VisibleForTesting static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L; - @VisibleForTesting static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType - @VisibleForTesting static final int KEY_SIZE_DEFAULT = 2048; - @VisibleForTesting static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA"; - @VisibleForTesting static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3; - @VisibleForTesting - static final String STS_URL_DEFAULT = "https://securetoken.googleapis.com/v1/identitybindingtoken"; - - @VisibleForTesting - static final long RPC_TIMEOUT_SECONDS = 10L; - - private static final Pattern CLUSTER_URL_PATTERN = Pattern - .compile(".*/projects/(.*)/(?:locations|zones)/(.*)/clusters/.*"); - - private static final String TRUST_DOMAIN_SUFFIX = ".svc.id.goog"; - private static final String AUDIENCE_PREFIX = "identitynamespace:"; - static final String MESH_CA_NAME = "meshCA"; - - final StsCredentials.Factory stsCredentialsFactory; - final MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory; - final BackoffPolicy.Provider backoffPolicyProvider; - final MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory; - final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory; - final TimeProvider timeProvider; - - MeshCaCertificateProviderProvider() { - this( - StsCredentials.Factory.getInstance(), - MeshCaCertificateProvider.MeshCaChannelFactory.getInstance(), - new ExponentialBackoffPolicy.Provider(), - MeshCaCertificateProvider.Factory.getInstance(), - ScheduledExecutorServiceFactory.DEFAULT_INSTANCE, - TimeProvider.SYSTEM_TIME_PROVIDER); - } - - @VisibleForTesting - MeshCaCertificateProviderProvider( - StsCredentials.Factory stsCredentialsFactory, - MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory, - BackoffPolicy.Provider backoffPolicyProvider, - MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory, - ScheduledExecutorServiceFactory scheduledExecutorServiceFactory, - TimeProvider timeProvider) { - this.stsCredentialsFactory = stsCredentialsFactory; - this.meshCaChannelFactory = meshCaChannelFactory; - this.backoffPolicyProvider = backoffPolicyProvider; - this.meshCaCertificateProviderFactory = meshCaCertificateProviderFactory; - this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory; - this.timeProvider = timeProvider; - } - - @Override - public String getName() { - return MESH_CA_NAME; - } - - @Override - public CertificateProvider createCertificateProvider( - Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) { - - Config configObj = validateAndTranslateConfig(config); - - // Construct audience from project and gkeClusterUrl - String audience = - AUDIENCE_PREFIX + configObj.project + TRUST_DOMAIN_SUFFIX + ":" + configObj.gkeClusterUrl; - StsCredentials stsCredentials = stsCredentialsFactory - .create(configObj.stsUrl, audience, configObj.gkeSaJwtLocation); - - return meshCaCertificateProviderFactory.create( - watcher, - notifyCertUpdates, - configObj.meshCaUrl, - configObj.zone, - configObj.certValiditySeconds, - configObj.keySize, - configObj.keyAlgo, - configObj.signatureAlgo, - meshCaChannelFactory, - backoffPolicyProvider, - configObj.renewalGracePeriodSeconds, - configObj.maxRetryAttempts, - stsCredentials, - scheduledExecutorServiceFactory.create(configObj.meshCaUrl), - timeProvider, - TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)); - } - - private static Config validateAndTranslateConfig(Object config) { - // TODO(sanjaypujare): add support for string, struct proto etc - checkArgument(config instanceof Map, "Only Map supported for config"); - @SuppressWarnings("unchecked") Map map = (Map)config; - - Config configObj = new Config(); - extractMeshCaServerConfig(configObj, getObject(map, SERVER_CONFIG_KEY)); - configObj.certValiditySeconds = - getSeconds( - JsonUtil.getObject(map, CERT_VALIDITY_SECONDS_KEY), CERT_VALIDITY_SECONDS_DEFAULT); - configObj.renewalGracePeriodSeconds = - getSeconds( - JsonUtil.getObject(map, RENEWAL_GRACE_PERIOD_SECONDS_KEY), - RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT); - String keyType = JsonUtil.getString(map, KEY_ALGO_KEY); - checkArgument( - keyType == null || keyType.equals(KEY_ALGO_DEFAULT), "key_type can only be null or 'RSA'"); - // TODO: remove signatureAlgo, keyType (or keyAlgo), maxRetryAttempts - configObj.maxRetryAttempts = MAX_RETRY_ATTEMPTS_DEFAULT; - configObj.keyAlgo = KEY_ALGO_DEFAULT; - configObj.signatureAlgo = SIGNATURE_ALGO_DEFAULT; - configObj.keySize = JsonUtil.getNumberAsInteger(map, KEY_SIZE_KEY); - if (configObj.keySize == null) { - configObj.keySize = KEY_SIZE_DEFAULT; - } - configObj.gkeClusterUrl = - checkNotNull(JsonUtil.getString(map, GKECLUSTER_URL_KEY), - "'location' is required in the config"); - parseProjectAndZone(configObj.gkeClusterUrl, configObj); - return configObj; - } - - private static void extractMeshCaServerConfig(Config configObj, Map serverConfig) { - // init with defaults - configObj.meshCaUrl = MESHCA_URL_DEFAULT; - configObj.rpcTimeoutSeconds = RPC_TIMEOUT_SECONDS_DEFAULT; - configObj.stsUrl = STS_URL_DEFAULT; - if (serverConfig != null) { - checkArgument( - "GRPC".equals(JsonUtil.getString(serverConfig, "api_type")), - "Only GRPC api_type supported"); - List> grpcServices = - checkNotNull( - JsonUtil.getListOfObjects(serverConfig, "grpc_services"), "grpc_services not found"); - for (Map grpcService : grpcServices) { - Map googleGrpcConfig = JsonUtil.getObject(grpcService, "google_grpc"); - if (googleGrpcConfig != null) { - String value = JsonUtil.getString(googleGrpcConfig, MESHCA_URL_KEY); - if (value != null) { - configObj.meshCaUrl = value; - } - Map channelCreds = - JsonUtil.getObject(googleGrpcConfig, "channel_credentials"); - if (channelCreds != null) { - Map googleDefaultChannelCreds = - checkNotNull( - JsonUtil.getObject(channelCreds, "google_default"), - "channel_credentials need to be google_default!"); - checkArgument( - googleDefaultChannelCreds.isEmpty(), - "google_default credentials contain illegal value"); - } - List> callCreds = - JsonUtil.getListOfObjects(googleGrpcConfig, "call_credentials"); - for (Map callCred : callCreds) { - Map stsCreds = JsonUtil.getObject(callCred, STS_SERVICE_KEY); - if (stsCreds != null) { - value = JsonUtil.getString(stsCreds, TOKEN_EXCHANGE_SERVICE_KEY); - if (value != null) { - configObj.stsUrl = value; - } - configObj.gkeSaJwtLocation = JsonUtil.getString(stsCreds, GKE_SA_JWT_LOCATION_KEY); - } - } - configObj.rpcTimeoutSeconds = - getSeconds( - JsonUtil.getObject(grpcService, RPC_TIMEOUT_SECONDS_KEY), - RPC_TIMEOUT_SECONDS_DEFAULT); - } - } - } - // check required value(s) - checkNotNull(configObj.gkeSaJwtLocation, "'subject_token_path' is required in the config"); - } - - private static Long getSeconds(Map duration, long defaultValue) { - if (duration != null) { - return JsonUtil.getNumberAsLong(duration, "seconds"); - } - return defaultValue; - } - - private static void parseProjectAndZone(String gkeClusterUrl, Config configObj) { - Matcher matcher = CLUSTER_URL_PATTERN.matcher(gkeClusterUrl); - checkState(matcher.find(), "gkeClusterUrl does not have correct format"); - checkState(matcher.groupCount() == 2, "gkeClusterUrl does not have project and location parts"); - configObj.project = matcher.group(1); - configObj.zone = matcher.group(2); - } - - abstract static class ScheduledExecutorServiceFactory { - - private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE = - new ScheduledExecutorServiceFactory() { - - @Override - ScheduledExecutorService create(String serverUri) { - return Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder() - .setNameFormat("meshca-" + serverUri + "-%d") - .setDaemon(true) - .build()); - } - }; - - static ScheduledExecutorServiceFactory getInstance() { - return DEFAULT_INSTANCE; - } - - abstract ScheduledExecutorService create(String serverUri); - } - - /** POJO class for storing various config values. */ - @VisibleForTesting - static class Config { - String meshCaUrl; - Long rpcTimeoutSeconds; - String gkeClusterUrl; - Long certValiditySeconds; - Long renewalGracePeriodSeconds; - String keyAlgo; // aka keyType - Integer keySize; - String signatureAlgo; - Integer maxRetryAttempts; - String stsUrl; - String gkeSaJwtLocation; - String zone; - String project; - } -} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProviderTest.java deleted file mode 100644 index 2ee5fb3964..0000000000 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProviderTest.java +++ /dev/null @@ -1,219 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.fail; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.grpc.internal.JsonParser; -import io.grpc.internal.TimeProvider; -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -/** Unit tests for {@link DynamicReloadingCertificateProviderProvider}. */ -@RunWith(JUnit4.class) -public class DynamicReloadingCertificateProviderProviderTest { - - @Mock DynamicReloadingCertificateProvider.Factory dynamicReloadingCertificateProviderFactory; - @Mock private DynamicReloadingCertificateProviderProvider.ScheduledExecutorServiceFactory - scheduledExecutorServiceFactory; - @Mock private TimeProvider timeProvider; - - private DynamicReloadingCertificateProviderProvider provider; - - @Before - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - provider = - new DynamicReloadingCertificateProviderProvider( - dynamicReloadingCertificateProviderFactory, - scheduledExecutorServiceFactory, - timeProvider); - } - - @Test - public void providerRegisteredName() { - CertificateProviderProvider certProviderProvider = - CertificateProviderRegistry.getInstance() - .getProvider( - DynamicReloadingCertificateProviderProvider.DYNAMIC_RELOADING_PROVIDER_NAME); - assertThat(certProviderProvider) - .isInstanceOf(DynamicReloadingCertificateProviderProvider.class); - DynamicReloadingCertificateProviderProvider dynamicReloadingCertificateProviderProvider = - (DynamicReloadingCertificateProviderProvider) certProviderProvider; - assertThat( - dynamicReloadingCertificateProviderProvider.dynamicReloadingCertificateProviderFactory) - .isSameInstanceAs(DynamicReloadingCertificateProvider.Factory.getInstance()); - } - - @Test - public void createProvider_minimalConfig() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MINIMAL_DYNAMIC_RELOADING_CONFIG); - ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); - when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); - provider.createCertificateProvider(map, distWatcher, true); - verify(dynamicReloadingCertificateProviderFactory, times(1)) - .create( - eq(distWatcher), - eq(true), - eq("/var/run/gke-spiffe/certs/..data"), - eq("certificates.pem"), - eq("private_key.pem"), - eq("ca_certificates.pem"), - eq(600L), - eq(mockService), - eq(timeProvider)); - } - - @Test - public void createProvider_fullConfig() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(FULL_DYNAMIC_RELOADING_CONFIG); - ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); - when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); - provider.createCertificateProvider(map, distWatcher, true); - verify(dynamicReloadingCertificateProviderFactory, times(1)) - .create( - eq(distWatcher), - eq(true), - eq("/var/run/gke-spiffe/certs/..data1"), - eq("certificates2.pem"), - eq("private_key3.pem"), - eq("ca_certificates4.pem"), - eq(7890L), - eq(mockService), - eq(timeProvider)); - } - - @Test - public void createProvider_missingDir_expectException() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_DIR_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'directory' is required in the config"); - } - } - - @Test - public void createProvider_missingCert_expectException() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_CERT_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'certificate-file' is required in the config"); - } - } - - @Test - public void createProvider_missingKey_expectException() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_KEY_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'private-key-file' is required in the config"); - } - } - - @Test - public void createProvider_missingRoot_expectException() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_ROOT_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'ca-certificate-file' is required in the config"); - } - } - - private static final String MINIMAL_DYNAMIC_RELOADING_CONFIG = - "{\n" - + " \"directory\": \"/var/run/gke-spiffe/certs/..data\"," - + " \"certificate-file\": \"certificates.pem\"," - + " \"private-key-file\": \"private_key.pem\"," - + " \"ca-certificate-file\": \"ca_certificates.pem\"" - + " }"; - - private static final String FULL_DYNAMIC_RELOADING_CONFIG = - "{\n" - + " \"directory\": \"/var/run/gke-spiffe/certs/..data1\"," - + " \"certificate-file\": \"certificates2.pem\"," - + " \"private-key-file\": \"private_key3.pem\"," - + " \"ca-certificate-file\": \"ca_certificates4.pem\"," - + " \"refresh-interval\": 7890" - + " }"; - - private static final String MISSING_DIR_CONFIG = - "{\n" - + " \"certificate-file\": \"certificates.pem\"," - + " \"private-key-file\": \"private_key.pem\"," - + " \"ca-certificate-file\": \"ca_certificates.pem\"" - + " }"; - - private static final String MISSING_CERT_CONFIG = - "{\n" - + " \"directory\": \"/var/run/gke-spiffe/certs/..data\"," - + " \"private-key-file\": \"private_key.pem\"," - + " \"ca-certificate-file\": \"ca_certificates.pem\"" - + " }"; - - private static final String MISSING_KEY_CONFIG = - "{\n" - + " \"directory\": \"/var/run/gke-spiffe/certs/..data\"," - + " \"certificate-file\": \"certificates.pem\"," - + " \"ca-certificate-file\": \"ca_certificates.pem\"" - + " }"; - - private static final String MISSING_ROOT_CONFIG = - "{\n" - + " \"directory\": \"/var/run/gke-spiffe/certs/..data\"," - + " \"certificate-file\": \"certificates.pem\"," - + " \"private-key-file\": \"private_key.pem\"" - + " }"; -} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderTest.java deleted file mode 100644 index a2f7cb7e32..0000000000 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderTest.java +++ /dev/null @@ -1,303 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.grpc.Status; -import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import java.io.File; -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.NoSuchFileException; -import java.nio.file.Paths; -import java.security.PrivateKey; -import java.security.cert.CertificateException; -import java.security.cert.X509Certificate; -import java.util.List; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatchers; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -/** Unit tests for {@link DynamicReloadingCertificateProvider}. */ -@RunWith(JUnit4.class) -public class DynamicReloadingCertificateProviderTest { - private static final String CERT_FILE = "cert.pem"; - private static final String KEY_FILE = "key.pem"; - private static final String ROOT_FILE = "root.pem"; - - @Mock private CertificateProvider.Watcher mockWatcher; - @Mock private ScheduledExecutorService timeService; - @Mock private TimeProvider timeProvider; - - @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); - private String symlink; - - private DynamicReloadingCertificateProvider provider; - - @Before - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - - DistributorWatcher watcher = new DistributorWatcher(); - watcher.addWatcher(mockWatcher); - - symlink = new File(tempFolder.getRoot(), "..data").getAbsolutePath(); - provider = - new DynamicReloadingCertificateProvider( - watcher, - true, - symlink, - CERT_FILE, - KEY_FILE, - ROOT_FILE, - 600L, - timeService, - timeProvider); - } - - private void populateTarget( - String certFile, String keyFile, String rootFile, boolean deleteExisting, boolean createNew) - throws IOException { - String target = tempFolder.newFolder().getAbsolutePath(); - if (certFile != null) { - certFile = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(certFile); - Files.copy(Paths.get(certFile), Paths.get(target, CERT_FILE)); - } - if (keyFile != null) { - keyFile = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(keyFile); - Files.copy(Paths.get(keyFile), Paths.get(target, KEY_FILE)); - } - if (rootFile != null) { - rootFile = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(rootFile); - Files.copy(Paths.get(rootFile), Paths.get(target, ROOT_FILE)); - } - if (deleteExisting) { - Files.delete(Paths.get(symlink)); - } - if (createNew) { - Files.createSymbolicLink(Paths.get(symlink), Paths.get(target)); - } - } - - @Test - public void getCertificateAndCheckUpdates() throws IOException, CertificateException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, true); - provider.checkAndReloadCertificates(); - verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE); - verifyTimeServiceAndScheduledHandle(); - - reset(mockWatcher, timeService); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.scheduledHandle.cancel(); - provider.checkAndReloadCertificates(); - verifyWatcherErrorUpdates(null, null, (String[]) null); - verifyTimeServiceAndScheduledHandle(); - - reset(mockWatcher, timeService); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.scheduledHandle.cancel(); - populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, true, true); - provider.checkAndReloadCertificates(); - verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE); - verifyTimeServiceAndScheduledHandle(); - } - - @Test - public void getCertificate_initialMissingCertFile() throws IOException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, false, true); - when(timeProvider.currentTimeNanos()) - .thenReturn(TimeProvider.SYSTEM_TIME_PROVIDER.currentTimeNanos()); - provider.checkAndReloadCertificates(); - verifyWatcherErrorUpdates(Status.Code.UNKNOWN, java.io.FileNotFoundException.class, "cert.pem"); - } - - @Test - public void getCertificate_missingSymlink() throws IOException { - commonErrorTest(null, null, null, true, false, NoSuchFileException.class, "..data"); - } - - @Test - public void getCertificate_missingCertFile() throws IOException { - commonErrorTest( - null, - CLIENT_KEY_FILE, - CA_PEM_FILE, - true, - true, - java.io.FileNotFoundException.class, - "cert.pem"); - } - - @Test - public void getCertificate_missingKeyFile() throws IOException { - commonErrorTest( - CLIENT_PEM_FILE, - null, - CA_PEM_FILE, - true, - true, - java.io.FileNotFoundException.class, - "key.pem"); - } - - @Test - public void getCertificate_badKeyFile() throws IOException { - commonErrorTest( - CLIENT_PEM_FILE, - SERVER_0_PEM_FILE, - CA_PEM_FILE, - true, - true, - java.security.KeyException.class, - "could not find a PKCS #8 private key in input stream"); - } - - @Test - public void getCertificate_missingRootFile() throws IOException { - commonErrorTest( - CLIENT_PEM_FILE, - CLIENT_KEY_FILE, - null, - true, - true, - java.io.FileNotFoundException.class, - "root.pem"); - } - - private void commonErrorTest( - String certFile, - String keyFile, - String rootFile, - boolean deleteExisting, - boolean createNew, - Class throwableType, - String... causeMessages) - throws IOException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, true); - provider.checkAndReloadCertificates(); - - reset(mockWatcher); - populateTarget(certFile, keyFile, rootFile, deleteExisting, createNew); - when(timeProvider.currentTimeNanos()) - .thenReturn( - TimeUnit.MILLISECONDS.toNanos( - MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L)); - provider.scheduledHandle.cancel(); - provider.checkAndReloadCertificates(); - verifyWatcherErrorUpdates(null, null, (String[]) null); - - reset(mockWatcher, timeProvider); - when(timeProvider.currentTimeNanos()) - .thenReturn( - TimeUnit.MILLISECONDS.toNanos( - MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 590_000L)); - provider.scheduledHandle.cancel(); - provider.checkAndReloadCertificates(); - verifyWatcherErrorUpdates(Status.Code.UNKNOWN, throwableType, causeMessages); - } - - private void verifyWatcherErrorUpdates( - Status.Code code, Class throwableType, String... causeMessages) { - verify(mockWatcher, never()) - .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); - verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); - if (code == null && throwableType == null && causeMessages == null) { - verify(mockWatcher, never()).onError(any(Status.class)); - } else { - ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(null); - verify(mockWatcher, times(1)).onError(statusCaptor.capture()); - Status status = statusCaptor.getValue(); - assertThat(status.getCode()).isEqualTo(code); - Throwable cause = status.getCause(); - assertThat(cause).isInstanceOf(throwableType); - for (String causeMessage : causeMessages) { - assertThat(cause).hasMessageThat().contains(causeMessage); - cause = cause.getCause(); - } - } - } - - private void verifyTimeServiceAndScheduledHandle() { - verify(timeService, times(1)).schedule(any(Runnable.class), eq(600L), eq(TimeUnit.SECONDS)); - assertThat(provider.scheduledHandle).isNotNull(); - assertThat(provider.scheduledHandle.isPending()).isTrue(); - } - - private void verifyWatcherUpdates(String certPemFile, String rootPemFile) - throws IOException, CertificateException { - ArgumentCaptor> certChainCaptor = ArgumentCaptor.forClass(null); - verify(mockWatcher, times(1)) - .updateCertificate(any(PrivateKey.class), certChainCaptor.capture()); - List certChain = certChainCaptor.getValue(); - assertThat(certChain).hasSize(1); - assertThat(certChain.get(0)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(certPemFile)); - - ArgumentCaptor> rootsCaptor = ArgumentCaptor.forClass(null); - verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture()); - List roots = rootsCaptor.getValue(); - assertThat(roots).hasSize(1); - assertThat(roots.get(0)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(rootPemFile)); - verify(mockWatcher, never()).onError(any(Status.class)); - } -} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java index 7ecc02a99e..4b22cfb4e3 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java @@ -45,8 +45,11 @@ import java.nio.file.Paths; import java.security.PrivateKey; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.ArrayList; import java.util.List; +import java.util.concurrent.Delayed; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Rule; @@ -62,6 +65,10 @@ import org.mockito.MockitoAnnotations; /** Unit tests for {@link FileWatcherCertificateProvider}. */ @RunWith(JUnit4.class) public class FileWatcherCertificateProviderTest { + /** + * Expire time of cert SERVER_0_PEM_FILE. + */ + static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L; private static final String CERT_FILE = "cert.pem"; private static final String KEY_FILE = "key.pem"; private static final String ROOT_FILE = "root.pem"; @@ -126,8 +133,8 @@ public class FileWatcherCertificateProviderTest { @Test public void getCertificateAndCheckUpdates() throws IOException, CertificateException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -147,8 +154,8 @@ public class FileWatcherCertificateProviderTest { @Test public void allUpdateSecondTime() throws IOException, CertificateException, InterruptedException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -168,8 +175,8 @@ public class FileWatcherCertificateProviderTest { @Test public void closeDoesNotScheduleNext() throws IOException, CertificateException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -186,8 +193,8 @@ public class FileWatcherCertificateProviderTest { @Test public void rootFileUpdateOnly() throws IOException, CertificateException, InterruptedException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -208,8 +215,8 @@ public class FileWatcherCertificateProviderTest { @Test public void certAndKeyFileUpdateOnly() throws IOException, CertificateException, InterruptedException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -229,8 +236,8 @@ public class FileWatcherCertificateProviderTest { @Test public void getCertificate_initialMissingCertFile() throws IOException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -269,8 +276,8 @@ public class FileWatcherCertificateProviderTest { @Test public void getCertificate_missingRootFile() throws IOException, InterruptedException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -283,7 +290,7 @@ public class FileWatcherCertificateProviderTest { when(timeProvider.currentTimeNanos()) .thenReturn( TimeUnit.MILLISECONDS.toNanos( - MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L)); + CERT0_EXPIRY_TIME_MILLIS - 610_000L)); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 1, 0, "root.pem"); } @@ -299,8 +306,8 @@ public class FileWatcherCertificateProviderTest { int secondUpdateRootCount, String... causeMessages) throws IOException, InterruptedException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -314,7 +321,7 @@ public class FileWatcherCertificateProviderTest { when(timeProvider.currentTimeNanos()) .thenReturn( TimeUnit.MILLISECONDS.toNanos( - MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L)); + CERT0_EXPIRY_TIME_MILLIS - 610_000L)); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates( null, null, firstUpdateCertCount, firstUpdateRootCount, (String[]) null); @@ -323,7 +330,7 @@ public class FileWatcherCertificateProviderTest { when(timeProvider.currentTimeNanos()) .thenReturn( TimeUnit.MILLISECONDS.toNanos( - MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 590_000L)); + CERT0_EXPIRY_TIME_MILLIS - 590_000L)); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates( Status.Code.UNKNOWN, @@ -392,4 +399,55 @@ public class FileWatcherCertificateProviderTest { verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); } } + + static class TestScheduledFuture implements ScheduledFuture { + + static class Record { + long timeout; + TimeUnit unit; + + Record(long timeout, TimeUnit unit) { + this.timeout = timeout; + this.unit = unit; + } + } + + ArrayList calls = new ArrayList<>(); + + @Override + public long getDelay(TimeUnit unit) { + return 0; + } + + @Override + public int compareTo(Delayed o) { + return 0; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public V get() { + return null; + } + + @Override + public V get(long timeout, TimeUnit unit) { + calls.add(new Record(timeout, unit)); + return null; + } + } } diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java deleted file mode 100644 index 791d5a395c..0000000000 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java +++ /dev/null @@ -1,409 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.certprovider.MeshCaCertificateProviderProvider.RPC_TIMEOUT_SECONDS; -import static org.junit.Assert.fail; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isNull; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.google.auth.oauth2.GoogleCredentials; -import io.grpc.internal.BackoffPolicy; -import io.grpc.internal.ExponentialBackoffPolicy; -import io.grpc.internal.JsonParser; -import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.sts.StsCredentials; -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -/** Unit tests for {@link MeshCaCertificateProviderProvider}. */ -@RunWith(JUnit4.class) -public class MeshCaCertificateProviderProviderTest { - - public static final String EXPECTED_AUDIENCE = - "identitynamespace:test-project1.svc.id.goog:https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3"; - public static final String EXPECTED_AUDIENCE_V1BETA1_ZONE = - "identitynamespace:test-project1.svc.id.goog:https://container.googleapis.com/v1beta1/projects/test-project1/zones/test-zone2/clusters/test-cluster3"; - public static final String TMP_PATH_4 = "/tmp/path4"; - public static final String NON_DEFAULT_MESH_CA_URL = "nonDefaultMeshCaUrl"; - - @Mock - StsCredentials.Factory stsCredentialsFactory; - - @Mock - MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory; - - @Mock - BackoffPolicy.Provider backoffPolicyProvider; - - @Mock - MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory; - - @Mock - private MeshCaCertificateProviderProvider.ScheduledExecutorServiceFactory - scheduledExecutorServiceFactory; - - @Mock - private TimeProvider timeProvider; - - private MeshCaCertificateProviderProvider provider; - - @Before - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - provider = - new MeshCaCertificateProviderProvider( - stsCredentialsFactory, - meshCaChannelFactory, - backoffPolicyProvider, - meshCaCertificateProviderFactory, - scheduledExecutorServiceFactory, - timeProvider); - } - - @Test - public void providerRegisteredName() { - CertificateProviderProvider certProviderProvider = CertificateProviderRegistry.getInstance() - .getProvider(MeshCaCertificateProviderProvider.MESH_CA_NAME); - assertThat(certProviderProvider).isInstanceOf(MeshCaCertificateProviderProvider.class); - MeshCaCertificateProviderProvider meshCaCertificateProviderProvider = - (MeshCaCertificateProviderProvider) certProviderProvider; - assertThat(meshCaCertificateProviderProvider.stsCredentialsFactory) - .isSameInstanceAs(StsCredentials.Factory.getInstance()); - assertThat(meshCaCertificateProviderProvider.meshCaChannelFactory) - .isSameInstanceAs(MeshCaCertificateProvider.MeshCaChannelFactory.getInstance()); - assertThat(meshCaCertificateProviderProvider.backoffPolicyProvider) - .isInstanceOf(ExponentialBackoffPolicy.Provider.class); - assertThat(meshCaCertificateProviderProvider.meshCaCertificateProviderFactory) - .isSameInstanceAs(MeshCaCertificateProvider.Factory.getInstance()); - } - - @Test - public void createProvider_minimalConfig() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MINIMAL_MESHCA_CONFIG); - ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); - when(scheduledExecutorServiceFactory.create( - eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT))) - .thenReturn(mockService); - provider.createCertificateProvider(map, distWatcher, true); - verify(stsCredentialsFactory, times(1)) - .create( - eq(MeshCaCertificateProviderProvider.STS_URL_DEFAULT), - eq(EXPECTED_AUDIENCE), - eq("/tmp/path5")); - verify(meshCaCertificateProviderFactory, times(1)) - .create( - eq(distWatcher), - eq(true), - eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT), - eq("test-zone2"), - eq(MeshCaCertificateProviderProvider.CERT_VALIDITY_SECONDS_DEFAULT), - eq(MeshCaCertificateProviderProvider.KEY_SIZE_DEFAULT), - eq(MeshCaCertificateProviderProvider.KEY_ALGO_DEFAULT), - eq(MeshCaCertificateProviderProvider.SIGNATURE_ALGO_DEFAULT), - eq(meshCaChannelFactory), - eq(backoffPolicyProvider), - eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT), - eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT), - (GoogleCredentials) isNull(), - eq(mockService), - eq(timeProvider), - eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS))); - } - - @Test - public void createProvider_minimalConfig_v1beta1AndZone() - throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(V1BETA1_ZONE_MESHCA_CONFIG); - ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); - when(scheduledExecutorServiceFactory.create( - eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT))) - .thenReturn(mockService); - provider.createCertificateProvider(map, distWatcher, true); - verify(stsCredentialsFactory, times(1)) - .create( - eq(MeshCaCertificateProviderProvider.STS_URL_DEFAULT), - eq(EXPECTED_AUDIENCE_V1BETA1_ZONE), - eq("/tmp/path5")); - verify(meshCaCertificateProviderFactory, times(1)) - .create( - eq(distWatcher), - eq(true), - eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT), - eq("test-zone2"), - eq(MeshCaCertificateProviderProvider.CERT_VALIDITY_SECONDS_DEFAULT), - eq(MeshCaCertificateProviderProvider.KEY_SIZE_DEFAULT), - eq(MeshCaCertificateProviderProvider.KEY_ALGO_DEFAULT), - eq(MeshCaCertificateProviderProvider.SIGNATURE_ALGO_DEFAULT), - eq(meshCaChannelFactory), - eq(backoffPolicyProvider), - eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT), - eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT), - (GoogleCredentials) isNull(), - eq(mockService), - eq(timeProvider), - eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS))); - } - - @Test - public void createProvider_missingGkeUrl_expectException() - throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_GKE_CLUSTER_URL_MESHCA_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'location' is required in the config"); - } - } - - @Test - public void createProvider_missingSaJwtLocation_expectException() - throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_SAJWT_MESHCA_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'subject_token_path' is required in the config"); - } - } - - @Test - public void createProvider_missingProject_expectException() - throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MINIMAL_BAD_CLUSTER_URL_MESHCA_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (IllegalStateException ex) { - assertThat(ex).hasMessageThat().isEqualTo("gkeClusterUrl does not have correct format"); - } - } - - @Test - public void createProvider_badChannelCreds_expectException() - throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(BAD_CHANNEL_CREDS_MESHCA_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException ex) { - assertThat(ex).hasMessageThat().isEqualTo("channel_credentials need to be google_default!"); - } - } - - @Test - public void createProvider_nonDefaultFullConfig() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(NONDEFAULT_MESHCA_CONFIG); - ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); - when(scheduledExecutorServiceFactory.create(eq(NON_DEFAULT_MESH_CA_URL))) - .thenReturn(mockService); - provider.createCertificateProvider(map, distWatcher, true); - verify(stsCredentialsFactory, times(1)) - .create( - eq("test.sts.com"), - eq(EXPECTED_AUDIENCE), - eq(TMP_PATH_4)); - verify(meshCaCertificateProviderFactory, times(1)) - .create( - eq(distWatcher), - eq(true), - eq(NON_DEFAULT_MESH_CA_URL), - eq("test-zone2"), - eq(234567L), - eq(512), - eq("RSA"), - eq("SHA256withRSA"), - eq(meshCaChannelFactory), - eq(backoffPolicyProvider), - eq(4321L), - eq(3), - (GoogleCredentials) isNull(), - eq(mockService), - eq(timeProvider), - eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS))); - } - - private static final String NONDEFAULT_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"target_uri\": \"nonDefaultMeshCaUrl\",\n" - + " \"channel_credentials\": {\"google_default\": {}},\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"token_exchange_service\": \"test.sts.com\",\n" - + " \"subject_token_path\": \"/tmp/path4\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " },\n" // end google_grpc - + " \"time_out\": {\"seconds\": 12}\n" - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"certificate_lifetime\": {\"seconds\": 234567},\n" - + " \"renewal_grace_period\": {\"seconds\": 4321},\n" - + " \"key_type\": \"RSA\",\n" - + " \"key_size\": 512,\n" - + " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n" - + " }"; - - private static final String MINIMAL_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"subject_token_path\": \"/tmp/path5\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " }\n" // end google_grpc - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n" - + " }"; - - private static final String V1BETA1_ZONE_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"subject_token_path\": \"/tmp/path5\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " }\n" // end google_grpc - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"location\": \"https://container.googleapis.com/v1beta1/projects/test-project1/zones/test-zone2/clusters/test-cluster3\"\n" - + " }"; - - private static final String MINIMAL_BAD_CLUSTER_URL_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"subject_token_path\": \"/tmp/path5\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " }\n" // end google_grpc - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"location\": \"https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3\"\n" - + " }"; - - private static final String MISSING_SAJWT_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " }\n" - + " }]\n" // end call_credentials - + " }\n" // end google_grpc - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n" - + " }"; - - private static final String MISSING_GKE_CLUSTER_URL_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"target_uri\": \"meshca.com\",\n" - + " \"channel_credentials\": {\"google_default\": {}},\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"token_exchange_service\": \"securetoken.googleapis.com\",\n" - + " \"subject_token_path\": \"/etc/secret/sajwt.token\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " },\n" // end google_grpc - + " \"time_out\": {\"seconds\": 10}\n" - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"certificate_lifetime\": {\"seconds\": 86400},\n" - + " \"renewal_grace_period\": {\"seconds\": 3600},\n" - + " \"key_type\": \"RSA\",\n" - + " \"key_size\": 2048\n" - + " }"; - - private static final String BAD_CHANNEL_CREDS_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"channel_credentials\": {\"mtls\": \"true\"},\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"subject_token_path\": \"/tmp/path5\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " }\n" // end google_grpc - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n" - + " }"; -} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java deleted file mode 100644 index 4b4791dd0c..0000000000 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java +++ /dev/null @@ -1,591 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; -import static org.mockito.AdditionalAnswers.delegatesTo; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.google.auth.http.AuthHttpConstants; -import com.google.auth.oauth2.AccessToken; -import com.google.auth.oauth2.GoogleCredentials; -import com.google.common.collect.ImmutableList; -import com.google.common.util.concurrent.MoreExecutors; -import com.google.security.meshca.v1.MeshCertificateRequest; -import com.google.security.meshca.v1.MeshCertificateResponse; -import com.google.security.meshca.v1.MeshCertificateServiceGrpc; -import io.grpc.Context; -import io.grpc.ManagedChannel; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import io.grpc.Status; -import io.grpc.StatusRuntimeException; -import io.grpc.SynchronizationContext; -import io.grpc.inprocess.InProcessChannelBuilder; -import io.grpc.inprocess.InProcessServerBuilder; -import io.grpc.internal.BackoffPolicy; -import io.grpc.internal.TimeProvider; -import io.grpc.stub.StreamObserver; -import io.grpc.testing.GrpcCleanupRule; -import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import java.io.IOException; -import java.security.NoSuchAlgorithmException; -import java.security.PrivateKey; -import java.security.cert.CertificateException; -import java.security.cert.X509Certificate; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Date; -import java.util.List; -import java.util.Queue; -import java.util.concurrent.Delayed; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import org.bouncycastle.operator.OperatorCreationException; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatchers; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.mockito.Spy; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -/** Unit tests for {@link MeshCaCertificateProvider}. */ -@RunWith(JUnit4.class) -public class MeshCaCertificateProviderTest { - - private static final String TEST_STS_TOKEN = "test-stsToken"; - private static final long RENEWAL_GRACE_PERIOD_SECONDS = TimeUnit.HOURS.toSeconds(1L); - private static final Metadata.Key KEY_FOR_AUTHORIZATION = - Metadata.Key.of(AuthHttpConstants.AUTHORIZATION, Metadata.ASCII_STRING_MARSHALLER); - private static final String ZONE = "us-west2-a"; - private static final long START_DELAY = 200_000_000L; // 0.2 seconds - private static final long[] DELAY_VALUES = {START_DELAY, START_DELAY * 2, START_DELAY * 4}; - private static final long RPC_TIMEOUT_MILLIS = 1000L; - /** - * Expire time of cert SERVER_0_PEM_FILE. - */ - static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L; - /** - * Cert validity of 12 hours for the above cert. - */ - private static final long CERT0_VALIDITY_MILLIS = TimeUnit.MILLISECONDS - .convert(12, TimeUnit.HOURS); - /** - * Compute current time based on cert expiry and cert validity. - */ - private static final long CURRENT_TIME_NANOS = - TimeUnit.MILLISECONDS.toNanos(CERT0_EXPIRY_TIME_MILLIS - CERT0_VALIDITY_MILLIS); - @Rule - public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); - - private static class ResponseToSend { - Throwable getThrowable() { - throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName()); - } - - List getList() { - throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName()); - } - } - - private static class ResponseThrowable extends ResponseToSend { - final Throwable throwableToSend; - - ResponseThrowable(Throwable throwable) { - throwableToSend = throwable; - } - - @Override - Throwable getThrowable() { - return throwableToSend; - } - } - - private static class ResponseList extends ResponseToSend { - final List listToSend; - - ResponseList(List list) { - listToSend = list; - } - - @Override - List getList() { - return listToSend; - } - } - - private final Queue receivedRequests = new ArrayDeque<>(); - private final Queue receivedStsCreds = new ArrayDeque<>(); - private final Queue receivedZoneValues = new ArrayDeque<>(); - private final Queue responsesToSend = new ArrayDeque<>(); - private final Queue oauth2Tokens = new ArrayDeque<>(); - private final AtomicBoolean callEnded = new AtomicBoolean(true); - - @Mock private MeshCertificateServiceGrpc.MeshCertificateServiceImplBase mockedMeshCaService; - @Mock private CertificateProvider.Watcher mockWatcher; - @Mock private BackoffPolicy.Provider backoffPolicyProvider; - @Mock private BackoffPolicy backoffPolicy; - @Spy private GoogleCredentials oauth2Creds; - @Mock private ScheduledExecutorService timeService; - @Mock private TimeProvider timeProvider; - - private ManagedChannel channel; - private MeshCaCertificateProvider provider; - - @Before - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - when(backoffPolicyProvider.get()).thenReturn(backoffPolicy); - when(backoffPolicy.nextBackoffNanos()) - .thenReturn(DELAY_VALUES[0], DELAY_VALUES[1], DELAY_VALUES[2]); - doAnswer( - new Answer() { - @Override - public AccessToken answer(InvocationOnMock invocation) throws Throwable { - return new AccessToken( - oauth2Tokens.poll(), new Date(System.currentTimeMillis() + 1000L)); - } - }) - .when(oauth2Creds) - .refreshAccessToken(); - final String meshCaUri = InProcessServerBuilder.generateName(); - MeshCertificateServiceGrpc.MeshCertificateServiceImplBase meshCaServiceImpl = - new MeshCertificateServiceGrpc.MeshCertificateServiceImplBase() { - - @Override - public void createCertificate( - MeshCertificateRequest request, - StreamObserver responseObserver) { - assertThat(callEnded.get()).isTrue(); // ensure previous call was ended - callEnded.set(false); - Context.current() - .addListener( - new Context.CancellationListener() { - @Override - public void cancelled(Context context) { - callEnded.set(true); - } - }, - MoreExecutors.directExecutor()); - receivedRequests.offer(request); - ResponseToSend response = responsesToSend.poll(); - if (response instanceof ResponseThrowable) { - responseObserver.onError(response.getThrowable()); - } else if (response instanceof ResponseList) { - List certChainInResponse = response.getList(); - MeshCertificateResponse responseToSend = - MeshCertificateResponse.newBuilder() - .addAllCertChain(certChainInResponse) - .build(); - responseObserver.onNext(responseToSend); - responseObserver.onCompleted(); - } else { - callEnded.set(true); - } - } - }; - mockedMeshCaService = - mock( - MeshCertificateServiceGrpc.MeshCertificateServiceImplBase.class, - delegatesTo(meshCaServiceImpl)); - ServerInterceptor interceptor = - new ServerInterceptor() { - @Override - public ServerCall.Listener interceptCall( - ServerCall call, Metadata headers, ServerCallHandler next) { - receivedStsCreds.offer(headers.get(KEY_FOR_AUTHORIZATION)); - receivedZoneValues.offer(headers.get(MeshCaCertificateProvider.KEY_FOR_ZONE_INFO)); - return next.startCall(call, headers); - } - }; - cleanupRule.register( - InProcessServerBuilder.forName(meshCaUri) - .addService(mockedMeshCaService) - .intercept(interceptor) - .directExecutor() - .build() - .start()); - channel = - cleanupRule.register(InProcessChannelBuilder.forName(meshCaUri).directExecutor().build()); - MeshCaCertificateProvider.MeshCaChannelFactory channelFactory = - new MeshCaCertificateProvider.MeshCaChannelFactory() { - @Override - ManagedChannel createChannel(String serverUri) { - assertThat(serverUri).isEqualTo(meshCaUri); - return channel; - } - }; - CertificateProvider.DistributorWatcher watcher = new CertificateProvider.DistributorWatcher(); - watcher.addWatcher(mockWatcher); // - provider = - new MeshCaCertificateProvider( - watcher, - true, - meshCaUri, - ZONE, - TimeUnit.HOURS.toSeconds(9L), - 2048, - "RSA", - "SHA256withRSA", - channelFactory, - backoffPolicyProvider, - RENEWAL_GRACE_PERIOD_SECONDS, - MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT, - oauth2Creds, - timeService, - timeProvider, - RPC_TIMEOUT_MILLIS); - } - - @Test - public void startAndClose() { - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.start(); - SynchronizationContext.ScheduledHandle savedScheduledHandle = provider.scheduledHandle; - assertThat(savedScheduledHandle).isNotNull(); - assertThat(savedScheduledHandle.isPending()).isTrue(); - verify(timeService, times(1)) - .schedule( - any(Runnable.class), - eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS), - eq(TimeUnit.SECONDS)); - DistributorWatcher distWatcher = provider.getWatcher(); - assertThat(distWatcher.downstreamWatchers).hasSize(1); - PrivateKey mockKey = mock(PrivateKey.class); - X509Certificate mockCert = mock(X509Certificate.class); - distWatcher.updateCertificate(mockKey, ImmutableList.of(mockCert)); - distWatcher.updateTrustedRoots(ImmutableList.of(mockCert)); - provider.close(); - assertThat(provider.scheduledHandle).isNull(); - assertThat(savedScheduledHandle.isPending()).isFalse(); - assertThat(distWatcher.downstreamWatchers).isEmpty(); - assertThat(distWatcher.getLastIdentityCert()).isNull(); - } - - @Test - public void startTwice_noException() { - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.start(); - SynchronizationContext.ScheduledHandle savedScheduledHandle1 = provider.scheduledHandle; - provider.start(); - SynchronizationContext.ScheduledHandle savedScheduledHandle2 = provider.scheduledHandle; - assertThat(savedScheduledHandle2).isNotSameInstanceAs(savedScheduledHandle1); - assertThat(savedScheduledHandle2.isPending()).isTrue(); - } - - @Test - public void getCertificate() - throws IOException, CertificateException, OperatorCreationException, - NoSuchAlgorithmException { - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - responsesToSend.offer( - new ResponseList(ImmutableList.of( - CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE)))); - when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.refreshCertificate(); - MeshCertificateRequest receivedReq = receivedRequests.poll(); - assertThat(receivedReq.getValidity().getSeconds()).isEqualTo(TimeUnit.HOURS.toSeconds(9L)); - // cannot decode CSR: just check the PEM format delimiters - String csr = receivedReq.getCsr(); - assertThat(csr).startsWith("-----BEGIN NEW CERTIFICATE REQUEST-----"); - verifyReceivedMetadataValues(1); - verify(timeService, times(1)) - .schedule( - any(Runnable.class), - eq( - TimeUnit.MILLISECONDS.toSeconds( - CERT0_VALIDITY_MILLIS - - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))), - eq(TimeUnit.SECONDS)); - verifyMockWatcher(); - } - - @Test - public void getCertificate_withError() - throws IOException, OperatorCreationException, NoSuchAlgorithmException { - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - responsesToSend - .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION))); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.refreshCertificate(); - verify(mockWatcher, never()) - .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); - verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); - verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION); - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS), - eq(TimeUnit.SECONDS)); - verifyReceivedMetadataValues(1); - } - - @Test - public void getCertificate_withError_withExistingCert() - throws IOException, OperatorCreationException, NoSuchAlgorithmException { - PrivateKey mockKey = mock(PrivateKey.class); - X509Certificate mockCert = mock(X509Certificate.class); - // have current cert expire in 3 hours from current time - long threeHoursFromNowMillis = TimeUnit.NANOSECONDS - .toMillis(CURRENT_TIME_NANOS + TimeUnit.HOURS.toNanos(3)); - when(mockCert.getNotAfter()).thenReturn(new Date(threeHoursFromNowMillis)); - provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert)); - reset(mockWatcher); - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - responsesToSend - .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION))); - when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.refreshCertificate(); - verify(mockWatcher, never()) - .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); - verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); - verify(mockWatcher, never()).onError(any(Status.class)); - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(5400L), - eq(TimeUnit.SECONDS)); - assertThat(provider.getWatcher().getLastIdentityCert()).isNotNull(); - verifyReceivedMetadataValues(1); - } - - @Test - public void getCertificate_withError_withExistingExpiredCert() - throws IOException, OperatorCreationException, NoSuchAlgorithmException { - PrivateKey mockKey = mock(PrivateKey.class); - X509Certificate mockCert = mock(X509Certificate.class); - // have current cert expire in 3 seconds from current time - long threeSecondsFromNowMillis = TimeUnit.NANOSECONDS - .toMillis(CURRENT_TIME_NANOS + TimeUnit.SECONDS.toNanos(3)); - when(mockCert.getNotAfter()).thenReturn(new Date(threeSecondsFromNowMillis)); - provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert)); - reset(mockWatcher); - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - responsesToSend - .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION))); - when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.refreshCertificate(); - verify(mockWatcher, never()) - .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); - verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); - verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION); - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS), - eq(TimeUnit.SECONDS)); - assertThat(provider.getWatcher().getLastIdentityCert()).isNull(); - verifyReceivedMetadataValues(1); - } - - @Test - public void getCertificate_retriesWithErrors() - throws IOException, CertificateException, OperatorCreationException, - NoSuchAlgorithmException { - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - oauth2Tokens.offer(TEST_STS_TOKEN + "1"); - oauth2Tokens.offer(TEST_STS_TOKEN + "2"); - responsesToSend.offer(new ResponseThrowable(new StatusRuntimeException(Status.UNKNOWN))); - responsesToSend.offer( - new ResponseThrowable( - new Exception(new StatusRuntimeException(Status.RESOURCE_EXHAUSTED)))); - responsesToSend.offer(new ResponseList(ImmutableList.of( - CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE)))); - when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - TestScheduledFuture scheduledFutureSleep = new TestScheduledFuture<>(); - doReturn(scheduledFutureSleep).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS)); - provider.refreshCertificate(); - assertThat(receivedRequests.size()).isEqualTo(3); - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(TimeUnit.MILLISECONDS.toSeconds( - CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))), - eq(TimeUnit.SECONDS)); - verifyRetriesWithBackoff(scheduledFutureSleep, 2); - verifyMockWatcher(); - verifyReceivedMetadataValues(3); - } - - @Test - public void getCertificate_retriesWithTimeouts() - throws IOException, CertificateException, OperatorCreationException, - NoSuchAlgorithmException { - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - oauth2Tokens.offer(TEST_STS_TOKEN + "1"); - oauth2Tokens.offer(TEST_STS_TOKEN + "2"); - oauth2Tokens.offer(TEST_STS_TOKEN + "3"); - responsesToSend.offer(new ResponseToSend()); - responsesToSend.offer(new ResponseToSend()); - responsesToSend.offer(new ResponseToSend()); - responsesToSend.offer(new ResponseList(ImmutableList.of( - CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE)))); - when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - TestScheduledFuture scheduledFutureSleep = new TestScheduledFuture<>(); - doReturn(scheduledFutureSleep).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS)); - provider.refreshCertificate(); - assertThat(receivedRequests.size()).isEqualTo(4); - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(TimeUnit.MILLISECONDS.toSeconds( - CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))), - eq(TimeUnit.SECONDS)); - verifyRetriesWithBackoff(scheduledFutureSleep, 3); - verifyMockWatcher(); - verifyReceivedMetadataValues(4); - } - - private void verifyRetriesWithBackoff( - TestScheduledFuture scheduledFutureSleep, int numOfRetries) { - for (int i = 0; i < numOfRetries; i++) { - long delayValue = DELAY_VALUES[i]; - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(delayValue), - eq(TimeUnit.NANOSECONDS)); - assertThat(scheduledFutureSleep.calls.get(i).timeout).isEqualTo(delayValue); - assertThat(scheduledFutureSleep.calls.get(i).unit).isEqualTo(TimeUnit.NANOSECONDS); - } - } - - private void verifyMockWatcher() throws IOException, CertificateException { - ArgumentCaptor> certChainCaptor = ArgumentCaptor.forClass(null); - verify(mockWatcher, times(1)) - .updateCertificate(any(PrivateKey.class), certChainCaptor.capture()); - List certChain = certChainCaptor.getValue(); - assertThat(certChain).hasSize(3); - assertThat(certChain.get(0)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_0_PEM_FILE)); - assertThat(certChain.get(1)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_1_PEM_FILE)); - assertThat(certChain.get(2)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE)); - - ArgumentCaptor> rootsCaptor = ArgumentCaptor.forClass(null); - verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture()); - List roots = rootsCaptor.getValue(); - assertThat(roots).hasSize(1); - assertThat(roots.get(0)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE)); - verify(mockWatcher, never()).onError(any(Status.class)); - } - - private void verifyReceivedMetadataValues(int count) { - assertThat(receivedStsCreds).hasSize(count); - assertThat(receivedZoneValues).hasSize(count); - for (int i = 0; i < count; i++) { - assertThat(receivedStsCreds.poll()).isEqualTo("Bearer " + TEST_STS_TOKEN + i); - assertThat(receivedZoneValues.poll()).isEqualTo("location=locations/us-west2-a"); - } - } - - static class TestScheduledFuture implements ScheduledFuture { - - static class Record { - long timeout; - TimeUnit unit; - - Record(long timeout, TimeUnit unit) { - this.timeout = timeout; - this.unit = unit; - } - } - - ArrayList calls = new ArrayList<>(); - - @Override - public long getDelay(TimeUnit unit) { - return 0; - } - - @Override - public int compareTo(Delayed o) { - return 0; - } - - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - return false; - } - - @Override - public boolean isCancelled() { - return false; - } - - @Override - public boolean isDone() { - return false; - } - - @Override - public V get() { - return null; - } - - @Override - public V get(long timeout, TimeUnit unit) { - calls.add(new Record(timeout, unit)); - return null; - } - } -}