diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java new file mode 100644 index 0000000000..6580185794 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java @@ -0,0 +1,249 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.InternalLogId; +import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.internal.TimeProvider; +import io.grpc.xds.internal.sds.trust.CertificateUtils; + +import java.io.ByteArrayInputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.attribute.FileTime; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +// TODO(sanjaypujare): abstract out common functionality into an an abstract superclass +/** Implementation of {@link CertificateProvider} for file watching cert provider. */ +final class FileWatcherCertificateProvider extends CertificateProvider { + private static final Logger logger = + Logger.getLogger(FileWatcherCertificateProvider.class.getName()); + + private final SynchronizationContext syncContext; + private final ScheduledExecutorService scheduledExecutorService; + private final TimeProvider timeProvider; + private final Path certFile; + private final Path keyFile; + private final Path trustFile; + private final long refreshIntervalInSeconds; + @VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle; + private FileTime lastModifiedTimeCert; + private FileTime lastModifiedTimeKey; + private FileTime lastModifiedTimeRoot; + + FileWatcherCertificateProvider( + DistributorWatcher watcher, + boolean notifyCertUpdates, + String certFile, + String keyFile, + String trustFile, + long refreshIntervalInSeconds, + ScheduledExecutorService scheduledExecutorService, + TimeProvider timeProvider) { + super(watcher, notifyCertUpdates); + this.scheduledExecutorService = + checkNotNull(scheduledExecutorService, "scheduledExecutorService"); + this.timeProvider = checkNotNull(timeProvider, "timeProvider"); + this.certFile = Paths.get(checkNotNull(certFile, "certFile")); + this.keyFile = Paths.get(checkNotNull(keyFile, "keyFile")); + this.trustFile = Paths.get(checkNotNull(trustFile, "trustFile")); + this.refreshIntervalInSeconds = refreshIntervalInSeconds; + this.syncContext = createSynchronizationContext(certFile); + } + + private SynchronizationContext createSynchronizationContext(String details) { + final InternalLogId logId = + InternalLogId.allocate("DynamicReloadingCertificateProvider", details); + return new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + private boolean panicMode; + + @Override + public void uncaughtException(Thread t, Throwable e) { + logger.log( + Level.SEVERE, + "[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!", + e); + panic(e); + } + + void panic(final Throwable t) { + if (panicMode) { + // Preserve the first panic information + return; + } + panicMode = true; + close(); + } + }); + } + + @Override + public void start() { + scheduleNextRefreshCertificate(/* delayInSeconds= */0); + } + + @Override + public void close() { + if (scheduledHandle != null) { + scheduledHandle.cancel(); + scheduledHandle = null; + } + getWatcher().close(); + } + + private void scheduleNextRefreshCertificate(long delayInSeconds) { + RefreshCertificateTask runnable = new RefreshCertificateTask(); + scheduledHandle = + syncContext.schedule(runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService); + } + + @VisibleForTesting + void checkAndReloadCertificates() { + try { + try { + FileTime currentCertTime = Files.getLastModifiedTime(certFile); + FileTime currentKeyTime = Files.getLastModifiedTime(keyFile); + if (!currentCertTime.equals(lastModifiedTimeCert) + && !currentKeyTime.equals(lastModifiedTimeKey)) { + byte[] certFileContents = Files.readAllBytes(certFile); + byte[] keyFileContents = Files.readAllBytes(keyFile); + FileTime currentCertTime2 = Files.getLastModifiedTime(certFile); + FileTime currentKeyTime2 = Files.getLastModifiedTime(keyFile); + if (!currentCertTime2.equals(currentCertTime)) { + return; + } + if (!currentKeyTime2.equals(currentKeyTime)) { + return; + } + try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents); + ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) { + PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream); + X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream); + getWatcher().updateCertificate(privateKey, Arrays.asList(certs)); + } + lastModifiedTimeCert = currentCertTime; + lastModifiedTimeKey = currentKeyTime; + } + } catch (Throwable t) { + generateErrorIfCurrentCertExpired(t); + } + try { + FileTime currentRootTime = Files.getLastModifiedTime(trustFile); + if (currentRootTime.equals(lastModifiedTimeRoot)) { + return; + } + byte[] rootFileContents = Files.readAllBytes(trustFile); + FileTime currentRootTime2 = Files.getLastModifiedTime(trustFile); + if (!currentRootTime2.equals(currentRootTime)) { + return; + } + try (ByteArrayInputStream rootStream = new ByteArrayInputStream(rootFileContents)) { + X509Certificate[] caCerts = CertificateUtils.toX509Certificates(rootStream); + getWatcher().updateTrustedRoots(Arrays.asList(caCerts)); + } + lastModifiedTimeRoot = currentRootTime; + } catch (Throwable t) { + getWatcher().onError(Status.fromThrowable(t)); + } + } finally { + scheduleNextRefreshCertificate(refreshIntervalInSeconds); + } + } + + private void generateErrorIfCurrentCertExpired(Throwable t) { + X509Certificate currentCert = getWatcher().getLastIdentityCert(); + if (currentCert != null) { + long delaySeconds = computeDelaySecondsToCertExpiry(currentCert); + if (delaySeconds > refreshIntervalInSeconds) { + logger.log(Level.FINER, "reload certificate error", t); + return; + } + // The current cert is going to expire in less than {@link refreshIntervalInSeconds} + // Clear the current cert and notify our watchers thru {@code onError} + getWatcher().clearValues(); + } + getWatcher().onError(Status.fromThrowable(t)); + } + + @SuppressWarnings("JdkObsolete") + private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) { + checkNotNull(lastCert, "lastCert"); + return TimeUnit.NANOSECONDS.toSeconds( + TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime()) + - timeProvider.currentTimeNanos()); + } + + @VisibleForTesting + class RefreshCertificateTask implements Runnable { + @Override + public void run() { + checkAndReloadCertificates(); + } + } + + abstract static class Factory { + private static final Factory DEFAULT_INSTANCE = + new Factory() { + @Override + FileWatcherCertificateProvider create( + DistributorWatcher watcher, + boolean notifyCertUpdates, + String certFile, + String keyFile, + String trustFile, + long refreshIntervalInSeconds, + ScheduledExecutorService scheduledExecutorService, + TimeProvider timeProvider) { + return new FileWatcherCertificateProvider( + watcher, + notifyCertUpdates, + certFile, + keyFile, + trustFile, + refreshIntervalInSeconds, + scheduledExecutorService, + timeProvider); + } + }; + + static Factory getInstance() { + return DEFAULT_INSTANCE; + } + + abstract FileWatcherCertificateProvider create( + DistributorWatcher watcher, + boolean notifyCertUpdates, + String certFile, + String keyFile, + String trustFile, + long refreshIntervalInSeconds, + ScheduledExecutorService scheduledExecutorService, + TimeProvider timeProvider); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProvider.java new file mode 100644 index 0000000000..493055acb2 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProvider.java @@ -0,0 +1,145 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import com.google.protobuf.Duration; +import com.google.protobuf.util.Durations; +import io.grpc.internal.JsonUtil; +import io.grpc.internal.TimeProvider; +import java.text.ParseException; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; + +/** + * Provider of {@link FileWatcherCertificateProvider}s. + */ +final class FileWatcherCertificateProviderProvider implements CertificateProviderProvider { + + private static final String CERT_FILE_KEY = "certificate_file"; + private static final String KEY_FILE_KEY = "private_key_file"; + private static final String ROOT_FILE_KEY = "ca_certificate_file"; + private static final String REFRESH_INTERVAL_KEY = "refresh_interval"; + + @VisibleForTesting static final long REFRESH_INTERVAL_DEFAULT = 600L; + + + static final String FILE_WATCHER_PROVIDER_NAME = "file_watcher"; + + static { + CertificateProviderRegistry.getInstance() + .register( + new FileWatcherCertificateProviderProvider( + FileWatcherCertificateProvider.Factory.getInstance(), + ScheduledExecutorServiceFactory.DEFAULT_INSTANCE, + TimeProvider.SYSTEM_TIME_PROVIDER)); + } + + final FileWatcherCertificateProvider.Factory fileWatcherCertificateProviderFactory; + private final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory; + private final TimeProvider timeProvider; + + @VisibleForTesting + FileWatcherCertificateProviderProvider( + FileWatcherCertificateProvider.Factory fileWatcherCertificateProviderFactory, + ScheduledExecutorServiceFactory scheduledExecutorServiceFactory, + TimeProvider timeProvider) { + this.fileWatcherCertificateProviderFactory = fileWatcherCertificateProviderFactory; + this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory; + this.timeProvider = timeProvider; + } + + @Override + public String getName() { + return FILE_WATCHER_PROVIDER_NAME; + } + + @Override + public CertificateProvider createCertificateProvider( + Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) { + + Config configObj = validateAndTranslateConfig(config); + return fileWatcherCertificateProviderFactory.create( + watcher, + notifyCertUpdates, + configObj.certFile, + configObj.keyFile, + configObj.rootFile, + configObj.refrehInterval, + scheduledExecutorServiceFactory.create(), + timeProvider); + } + + private static String checkForNullAndGet(Map map, String key) { + return checkNotNull(JsonUtil.getString(map, key), "'" + key + "' is required in the config"); + } + + private static Config validateAndTranslateConfig(Object config) { + checkArgument(config instanceof Map, "Only Map supported for config"); + @SuppressWarnings("unchecked") Map map = (Map)config; + + Config configObj = new Config(); + configObj.certFile = checkForNullAndGet(map, CERT_FILE_KEY); + configObj.keyFile = checkForNullAndGet(map, KEY_FILE_KEY); + configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY); + String refreshIntervalString = JsonUtil.getString(map, REFRESH_INTERVAL_KEY); + if (refreshIntervalString != null) { + try { + Duration duration = Durations.parse(refreshIntervalString); + configObj.refrehInterval = duration.getSeconds(); + } catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + if (configObj.refrehInterval == null) { + configObj.refrehInterval = REFRESH_INTERVAL_DEFAULT; + } + return configObj; + } + + abstract static class ScheduledExecutorServiceFactory { + + private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE = + new ScheduledExecutorServiceFactory() { + + @Override + ScheduledExecutorService create() { + return Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setNameFormat("fileWatcher" + "-%d") + .setDaemon(true) + .build()); + } + }; + + abstract ScheduledExecutorService create(); + } + + /** POJO class for storing various config values. */ + @VisibleForTesting + static class Config { + String certFile; + String keyFile; + String rootFile; + Long refrehInterval; + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProviderTest.java new file mode 100644 index 0000000000..b3ab599952 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProviderTest.java @@ -0,0 +1,186 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.grpc.internal.JsonParser; +import io.grpc.internal.TimeProvider; +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ScheduledExecutorService; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Unit tests for {@link FileWatcherCertificateProviderProvider}. */ +@RunWith(JUnit4.class) +public class FileWatcherCertificateProviderProviderTest { + + @Mock FileWatcherCertificateProvider.Factory fileWatcherCertificateProviderFactory; + @Mock private FileWatcherCertificateProviderProvider.ScheduledExecutorServiceFactory + scheduledExecutorServiceFactory; + @Mock private TimeProvider timeProvider; + + private FileWatcherCertificateProviderProvider provider; + + @Before + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + provider = + new FileWatcherCertificateProviderProvider( + fileWatcherCertificateProviderFactory, scheduledExecutorServiceFactory, timeProvider); + } + + @Test + public void providerRegisteredName() { + CertificateProviderProvider certProviderProvider = + CertificateProviderRegistry.getInstance() + .getProvider(FileWatcherCertificateProviderProvider.FILE_WATCHER_PROVIDER_NAME); + assertThat(certProviderProvider).isInstanceOf(FileWatcherCertificateProviderProvider.class); + FileWatcherCertificateProviderProvider fileWatcherCertificateProviderProvider = + (FileWatcherCertificateProviderProvider) certProviderProvider; + assertThat(fileWatcherCertificateProviderProvider.fileWatcherCertificateProviderFactory) + .isSameInstanceAs(FileWatcherCertificateProvider.Factory.getInstance()); + } + + @Test + public void createProvider_minimalConfig() throws IOException { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(MINIMAL_FILE_WATCHER_CONFIG); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, true); + verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(true), + eq("/var/run/gke-spiffe/certs/certificates.pem"), + eq("/var/run/gke-spiffe/certs/private_key.pem"), + eq("/var/run/gke-spiffe/certs/ca_certificates.pem"), + eq(600L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_fullConfig() throws IOException { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(FULL_FILE_WATCHER_CONFIG); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, true); + verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(true), + eq("/var/run/gke-spiffe/certs/certificates2.pem"), + eq("/var/run/gke-spiffe/certs/private_key3.pem"), + eq("/var/run/gke-spiffe/certs/ca_certificates4.pem"), + eq(7890L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_missingCert_expectException() throws IOException { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(MISSING_CERT_CONFIG); + try { + provider.createCertificateProvider(map, distWatcher, true); + fail("exception expected"); + } catch (NullPointerException npe) { + assertThat(npe).hasMessageThat().isEqualTo("'certificate_file' is required in the config"); + } + } + + @Test + public void createProvider_missingKey_expectException() throws IOException { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(MISSING_KEY_CONFIG); + try { + provider.createCertificateProvider(map, distWatcher, true); + fail("exception expected"); + } catch (NullPointerException npe) { + assertThat(npe).hasMessageThat().isEqualTo("'private_key_file' is required in the config"); + } + } + + @Test + public void createProvider_missingRoot_expectException() throws IOException { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(MISSING_ROOT_CONFIG); + try { + provider.createCertificateProvider(map, distWatcher, true); + fail("exception expected"); + } catch (NullPointerException npe) { + assertThat(npe).hasMessageThat().isEqualTo("'ca_certificate_file' is required in the config"); + } + } + + private static final String MINIMAL_FILE_WATCHER_CONFIG = + "{\n" + + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\"," + + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"," + + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\"" + + " }"; + + private static final String FULL_FILE_WATCHER_CONFIG = + "{\n" + + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates2.pem\"," + + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key3.pem\"," + + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates4.pem\"," + + " \"refresh_interval\": \"7890s\"" + + " }"; + + private static final String MISSING_CERT_CONFIG = + "{\n" + + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"," + + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\"" + + " }"; + + private static final String MISSING_KEY_CONFIG = + "{\n" + + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\"," + + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\"" + + " }"; + + private static final String MISSING_ROOT_CONFIG = + "{\n" + + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\"," + + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"" + + " }"; +} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java new file mode 100644 index 0000000000..e2cb464734 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java @@ -0,0 +1,376 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static java.nio.file.StandardCopyOption.REPLACE_EXISTING; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.grpc.Status; +import io.grpc.internal.TimeProvider; +import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher; +import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Paths; +import java.security.PrivateKey; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Unit tests for {@link FileWatcherCertificateProvider}. */ +@RunWith(JUnit4.class) +public class FileWatcherCertificateProviderTest { + private static final String CERT_FILE = "cert.pem"; + private static final String KEY_FILE = "key.pem"; + private static final String ROOT_FILE = "root.pem"; + + @Mock private CertificateProvider.Watcher mockWatcher; + @Mock private ScheduledExecutorService timeService; + @Mock private TimeProvider timeProvider; + + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); + + private String certFile; + private String keyFile; + private String rootFile; + + private FileWatcherCertificateProvider provider; + + @Before + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + + DistributorWatcher watcher = new DistributorWatcher(); + watcher.addWatcher(mockWatcher); + + certFile = new File(tempFolder.getRoot(), CERT_FILE).getAbsolutePath(); + keyFile = new File(tempFolder.getRoot(), KEY_FILE).getAbsolutePath(); + rootFile = new File(tempFolder.getRoot(), ROOT_FILE).getAbsolutePath(); + provider = + new FileWatcherCertificateProvider( + watcher, true, certFile, keyFile, rootFile, 600L, timeService, timeProvider); + } + + private void populateTarget( + String certFileSource, + String keyFileSource, + String rootFileSource, + boolean deleteCurCert, + boolean deleteCurKey, + boolean deleteCurRoot) + throws IOException { + if (deleteCurCert) { + Files.delete(Paths.get(certFile)); + } + if (certFileSource != null) { + certFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(certFileSource); + Files.copy(Paths.get(certFileSource), Paths.get(certFile), REPLACE_EXISTING); + } + if (deleteCurKey) { + Files.delete(Paths.get(keyFile)); + } + if (keyFileSource != null) { + keyFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(keyFileSource); + Files.copy(Paths.get(keyFileSource), Paths.get(keyFile), REPLACE_EXISTING); + } + if (deleteCurRoot) { + Files.delete(Paths.get(rootFile)); + } + if (rootFileSource != null) { + rootFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(rootFileSource); + Files.copy(Paths.get(rootFileSource), Paths.get(rootFile), REPLACE_EXISTING); + } + } + + @Test + public void getCertificateAndCheckUpdates() throws IOException, CertificateException { + MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = + new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + provider.checkAndReloadCertificates(); + verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE); + verifyTimeServiceAndScheduledHandle(); + + reset(mockWatcher, timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + provider.checkAndReloadCertificates(); + verifyWatcherErrorUpdates(null, null, 0, 0, (String[]) null); + verifyTimeServiceAndScheduledHandle(); + } + + @Test + public void allUpdateSecondTime() throws IOException, CertificateException, InterruptedException { + MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = + new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + provider.checkAndReloadCertificates(); + + reset(mockWatcher, timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + Thread.sleep(1000L); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); + provider.checkAndReloadCertificates(); + verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE); + verifyTimeServiceAndScheduledHandle(); + } + + @Test + public void rootFileUpdateOnly() throws IOException, CertificateException, InterruptedException { + MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = + new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + provider.checkAndReloadCertificates(); + + reset(mockWatcher, timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + Thread.sleep(1000L); + populateTarget(null, null, SERVER_1_PEM_FILE, false, false, false); + provider.checkAndReloadCertificates(); + verifyWatcherUpdates(null, SERVER_1_PEM_FILE); + verifyTimeServiceAndScheduledHandle(); + } + + @Test + public void certAndKeyFileUpdateOnly() + throws IOException, CertificateException, InterruptedException { + MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = + new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + provider.checkAndReloadCertificates(); + + reset(mockWatcher, timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + Thread.sleep(1000L); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, false, false, false); + provider.checkAndReloadCertificates(); + verifyWatcherUpdates(SERVER_0_PEM_FILE, null); + verifyTimeServiceAndScheduledHandle(); + } + + @Test + public void getCertificate_initialMissingCertFile() throws IOException { + MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = + new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + when(timeProvider.currentTimeNanos()) + .thenReturn(TimeProvider.SYSTEM_TIME_PROVIDER.currentTimeNanos()); + provider.checkAndReloadCertificates(); + verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 0, 1, "cert.pem"); + } + + @Test + public void getCertificate_missingCertFile() throws IOException, InterruptedException { + commonErrorTest( + null, CLIENT_KEY_FILE, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "cert.pem"); + } + + @Test + public void getCertificate_missingKeyFile() throws IOException, InterruptedException { + commonErrorTest( + CLIENT_PEM_FILE, null, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "key.pem"); + } + + @Test + public void getCertificate_badKeyFile() throws IOException, InterruptedException { + commonErrorTest( + CLIENT_PEM_FILE, + SERVER_0_PEM_FILE, + CA_PEM_FILE, + java.security.KeyException.class, + 0, + 1, + 0, + 0, + "could not find a PKCS #8 private key in input stream"); + } + + @Test + public void getCertificate_missingRootFile() throws IOException, InterruptedException { + MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = + new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); + provider.checkAndReloadCertificates(); + + reset(mockWatcher); + Thread.sleep(1000L); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, false, false, true); + when(timeProvider.currentTimeNanos()) + .thenReturn( + TimeUnit.MILLISECONDS.toNanos( + MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L)); + provider.checkAndReloadCertificates(); + verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 1, 0, "root.pem"); + } + + private void commonErrorTest( + String certFile, + String keyFile, + String rootFile, + Class throwableType, + int firstUpdateCertCount, + int firstUpdateRootCount, + int secondUpdateCertCount, + int secondUpdateRootCount, + String... causeMessages) + throws IOException, InterruptedException { + MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = + new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); + provider.checkAndReloadCertificates(); + + reset(mockWatcher); + Thread.sleep(1000L); + populateTarget( + certFile, keyFile, rootFile, certFile == null, keyFile == null, rootFile == null); + when(timeProvider.currentTimeNanos()) + .thenReturn( + TimeUnit.MILLISECONDS.toNanos( + MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L)); + provider.checkAndReloadCertificates(); + verifyWatcherErrorUpdates( + null, null, firstUpdateCertCount, firstUpdateRootCount, (String[]) null); + + reset(mockWatcher, timeProvider); + when(timeProvider.currentTimeNanos()) + .thenReturn( + TimeUnit.MILLISECONDS.toNanos( + MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 590_000L)); + provider.checkAndReloadCertificates(); + verifyWatcherErrorUpdates( + Status.Code.UNKNOWN, + throwableType, + secondUpdateCertCount, + secondUpdateRootCount, + causeMessages); + } + + private void verifyWatcherErrorUpdates( + Status.Code code, + Class throwableType, + int updateCertCount, + int updateRootCount, + String... causeMessages) { + verify(mockWatcher, times(updateCertCount)) + .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); + verify(mockWatcher, times(updateRootCount)) + .updateTrustedRoots(ArgumentMatchers.anyList()); + if (code == null && throwableType == null && causeMessages == null) { + verify(mockWatcher, never()).onError(any(Status.class)); + } else { + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(null); + verify(mockWatcher, times(1)).onError(statusCaptor.capture()); + Status status = statusCaptor.getValue(); + assertThat(status.getCode()).isEqualTo(code); + Throwable cause = status.getCause(); + assertThat(cause).isInstanceOf(throwableType); + for (String causeMessage : causeMessages) { + assertThat(cause).hasMessageThat().contains(causeMessage); + cause = cause.getCause(); + } + } + } + + private void verifyTimeServiceAndScheduledHandle() { + verify(timeService, times(1)).schedule(any(Runnable.class), eq(600L), eq(TimeUnit.SECONDS)); + assertThat(provider.scheduledHandle).isNotNull(); + assertThat(provider.scheduledHandle.isPending()).isTrue(); + } + + private void verifyWatcherUpdates(String certPemFile, String rootPemFile) + throws IOException, CertificateException { + if (certPemFile != null) { + ArgumentCaptor> certChainCaptor = ArgumentCaptor.forClass(null); + verify(mockWatcher, times(1)) + .updateCertificate(any(PrivateKey.class), certChainCaptor.capture()); + List certChain = certChainCaptor.getValue(); + assertThat(certChain).hasSize(1); + assertThat(certChain.get(0)) + .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(certPemFile)); + } else { + verify(mockWatcher, never()) + .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); + } + if (rootPemFile != null) { + ArgumentCaptor> rootsCaptor = ArgumentCaptor.forClass(null); + verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture()); + List roots = rootsCaptor.getValue(); + assertThat(roots).hasSize(1); + assertThat(roots.get(0)) + .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(rootPemFile)); + verify(mockWatcher, never()).onError(any(Status.class)); + } else { + verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); + } + } +}