xds: add File-watcher certificate provider (#7590)

This commit is contained in:
sanjaypujare 2020-11-09 09:52:42 -08:00 committed by GitHub
parent d154aa3328
commit cffc07f5d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 956 additions and 0 deletions

View File

@ -0,0 +1,249 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.InternalLogId;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.sds.trust.CertificateUtils;
import java.io.ByteArrayInputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileTime;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
// TODO(sanjaypujare): abstract out common functionality into an an abstract superclass
/** Implementation of {@link CertificateProvider} for file watching cert provider. */
final class FileWatcherCertificateProvider extends CertificateProvider {
private static final Logger logger =
Logger.getLogger(FileWatcherCertificateProvider.class.getName());
private final SynchronizationContext syncContext;
private final ScheduledExecutorService scheduledExecutorService;
private final TimeProvider timeProvider;
private final Path certFile;
private final Path keyFile;
private final Path trustFile;
private final long refreshIntervalInSeconds;
@VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle;
private FileTime lastModifiedTimeCert;
private FileTime lastModifiedTimeKey;
private FileTime lastModifiedTimeRoot;
FileWatcherCertificateProvider(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String certFile,
String keyFile,
String trustFile,
long refreshIntervalInSeconds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider) {
super(watcher, notifyCertUpdates);
this.scheduledExecutorService =
checkNotNull(scheduledExecutorService, "scheduledExecutorService");
this.timeProvider = checkNotNull(timeProvider, "timeProvider");
this.certFile = Paths.get(checkNotNull(certFile, "certFile"));
this.keyFile = Paths.get(checkNotNull(keyFile, "keyFile"));
this.trustFile = Paths.get(checkNotNull(trustFile, "trustFile"));
this.refreshIntervalInSeconds = refreshIntervalInSeconds;
this.syncContext = createSynchronizationContext(certFile);
}
private SynchronizationContext createSynchronizationContext(String details) {
final InternalLogId logId =
InternalLogId.allocate("DynamicReloadingCertificateProvider", details);
return new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
private boolean panicMode;
@Override
public void uncaughtException(Thread t, Throwable e) {
logger.log(
Level.SEVERE,
"[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!",
e);
panic(e);
}
void panic(final Throwable t) {
if (panicMode) {
// Preserve the first panic information
return;
}
panicMode = true;
close();
}
});
}
@Override
public void start() {
scheduleNextRefreshCertificate(/* delayInSeconds= */0);
}
@Override
public void close() {
if (scheduledHandle != null) {
scheduledHandle.cancel();
scheduledHandle = null;
}
getWatcher().close();
}
private void scheduleNextRefreshCertificate(long delayInSeconds) {
RefreshCertificateTask runnable = new RefreshCertificateTask();
scheduledHandle =
syncContext.schedule(runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService);
}
@VisibleForTesting
void checkAndReloadCertificates() {
try {
try {
FileTime currentCertTime = Files.getLastModifiedTime(certFile);
FileTime currentKeyTime = Files.getLastModifiedTime(keyFile);
if (!currentCertTime.equals(lastModifiedTimeCert)
&& !currentKeyTime.equals(lastModifiedTimeKey)) {
byte[] certFileContents = Files.readAllBytes(certFile);
byte[] keyFileContents = Files.readAllBytes(keyFile);
FileTime currentCertTime2 = Files.getLastModifiedTime(certFile);
FileTime currentKeyTime2 = Files.getLastModifiedTime(keyFile);
if (!currentCertTime2.equals(currentCertTime)) {
return;
}
if (!currentKeyTime2.equals(currentKeyTime)) {
return;
}
try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents);
ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) {
PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream);
X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream);
getWatcher().updateCertificate(privateKey, Arrays.asList(certs));
}
lastModifiedTimeCert = currentCertTime;
lastModifiedTimeKey = currentKeyTime;
}
} catch (Throwable t) {
generateErrorIfCurrentCertExpired(t);
}
try {
FileTime currentRootTime = Files.getLastModifiedTime(trustFile);
if (currentRootTime.equals(lastModifiedTimeRoot)) {
return;
}
byte[] rootFileContents = Files.readAllBytes(trustFile);
FileTime currentRootTime2 = Files.getLastModifiedTime(trustFile);
if (!currentRootTime2.equals(currentRootTime)) {
return;
}
try (ByteArrayInputStream rootStream = new ByteArrayInputStream(rootFileContents)) {
X509Certificate[] caCerts = CertificateUtils.toX509Certificates(rootStream);
getWatcher().updateTrustedRoots(Arrays.asList(caCerts));
}
lastModifiedTimeRoot = currentRootTime;
} catch (Throwable t) {
getWatcher().onError(Status.fromThrowable(t));
}
} finally {
scheduleNextRefreshCertificate(refreshIntervalInSeconds);
}
}
private void generateErrorIfCurrentCertExpired(Throwable t) {
X509Certificate currentCert = getWatcher().getLastIdentityCert();
if (currentCert != null) {
long delaySeconds = computeDelaySecondsToCertExpiry(currentCert);
if (delaySeconds > refreshIntervalInSeconds) {
logger.log(Level.FINER, "reload certificate error", t);
return;
}
// The current cert is going to expire in less than {@link refreshIntervalInSeconds}
// Clear the current cert and notify our watchers thru {@code onError}
getWatcher().clearValues();
}
getWatcher().onError(Status.fromThrowable(t));
}
@SuppressWarnings("JdkObsolete")
private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) {
checkNotNull(lastCert, "lastCert");
return TimeUnit.NANOSECONDS.toSeconds(
TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime())
- timeProvider.currentTimeNanos());
}
@VisibleForTesting
class RefreshCertificateTask implements Runnable {
@Override
public void run() {
checkAndReloadCertificates();
}
}
abstract static class Factory {
private static final Factory DEFAULT_INSTANCE =
new Factory() {
@Override
FileWatcherCertificateProvider create(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String certFile,
String keyFile,
String trustFile,
long refreshIntervalInSeconds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider) {
return new FileWatcherCertificateProvider(
watcher,
notifyCertUpdates,
certFile,
keyFile,
trustFile,
refreshIntervalInSeconds,
scheduledExecutorService,
timeProvider);
}
};
static Factory getInstance() {
return DEFAULT_INSTANCE;
}
abstract FileWatcherCertificateProvider create(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String certFile,
String keyFile,
String trustFile,
long refreshIntervalInSeconds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider);
}
}

View File

@ -0,0 +1,145 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.protobuf.Duration;
import com.google.protobuf.util.Durations;
import io.grpc.internal.JsonUtil;
import io.grpc.internal.TimeProvider;
import java.text.ParseException;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
/**
* Provider of {@link FileWatcherCertificateProvider}s.
*/
final class FileWatcherCertificateProviderProvider implements CertificateProviderProvider {
private static final String CERT_FILE_KEY = "certificate_file";
private static final String KEY_FILE_KEY = "private_key_file";
private static final String ROOT_FILE_KEY = "ca_certificate_file";
private static final String REFRESH_INTERVAL_KEY = "refresh_interval";
@VisibleForTesting static final long REFRESH_INTERVAL_DEFAULT = 600L;
static final String FILE_WATCHER_PROVIDER_NAME = "file_watcher";
static {
CertificateProviderRegistry.getInstance()
.register(
new FileWatcherCertificateProviderProvider(
FileWatcherCertificateProvider.Factory.getInstance(),
ScheduledExecutorServiceFactory.DEFAULT_INSTANCE,
TimeProvider.SYSTEM_TIME_PROVIDER));
}
final FileWatcherCertificateProvider.Factory fileWatcherCertificateProviderFactory;
private final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory;
private final TimeProvider timeProvider;
@VisibleForTesting
FileWatcherCertificateProviderProvider(
FileWatcherCertificateProvider.Factory fileWatcherCertificateProviderFactory,
ScheduledExecutorServiceFactory scheduledExecutorServiceFactory,
TimeProvider timeProvider) {
this.fileWatcherCertificateProviderFactory = fileWatcherCertificateProviderFactory;
this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory;
this.timeProvider = timeProvider;
}
@Override
public String getName() {
return FILE_WATCHER_PROVIDER_NAME;
}
@Override
public CertificateProvider createCertificateProvider(
Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) {
Config configObj = validateAndTranslateConfig(config);
return fileWatcherCertificateProviderFactory.create(
watcher,
notifyCertUpdates,
configObj.certFile,
configObj.keyFile,
configObj.rootFile,
configObj.refrehInterval,
scheduledExecutorServiceFactory.create(),
timeProvider);
}
private static String checkForNullAndGet(Map<String, ?> map, String key) {
return checkNotNull(JsonUtil.getString(map, key), "'" + key + "' is required in the config");
}
private static Config validateAndTranslateConfig(Object config) {
checkArgument(config instanceof Map, "Only Map supported for config");
@SuppressWarnings("unchecked") Map<String, ?> map = (Map<String, ?>)config;
Config configObj = new Config();
configObj.certFile = checkForNullAndGet(map, CERT_FILE_KEY);
configObj.keyFile = checkForNullAndGet(map, KEY_FILE_KEY);
configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY);
String refreshIntervalString = JsonUtil.getString(map, REFRESH_INTERVAL_KEY);
if (refreshIntervalString != null) {
try {
Duration duration = Durations.parse(refreshIntervalString);
configObj.refrehInterval = duration.getSeconds();
} catch (ParseException e) {
throw new IllegalArgumentException(e);
}
}
if (configObj.refrehInterval == null) {
configObj.refrehInterval = REFRESH_INTERVAL_DEFAULT;
}
return configObj;
}
abstract static class ScheduledExecutorServiceFactory {
private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE =
new ScheduledExecutorServiceFactory() {
@Override
ScheduledExecutorService create() {
return Executors.newSingleThreadScheduledExecutor(
new ThreadFactoryBuilder()
.setNameFormat("fileWatcher" + "-%d")
.setDaemon(true)
.build());
}
};
abstract ScheduledExecutorService create();
}
/** POJO class for storing various config values. */
@VisibleForTesting
static class Config {
String certFile;
String keyFile;
String rootFile;
Long refrehInterval;
}
}

View File

@ -0,0 +1,186 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.internal.JsonParser;
import io.grpc.internal.TimeProvider;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/** Unit tests for {@link FileWatcherCertificateProviderProvider}. */
@RunWith(JUnit4.class)
public class FileWatcherCertificateProviderProviderTest {
@Mock FileWatcherCertificateProvider.Factory fileWatcherCertificateProviderFactory;
@Mock private FileWatcherCertificateProviderProvider.ScheduledExecutorServiceFactory
scheduledExecutorServiceFactory;
@Mock private TimeProvider timeProvider;
private FileWatcherCertificateProviderProvider provider;
@Before
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);
provider =
new FileWatcherCertificateProviderProvider(
fileWatcherCertificateProviderFactory, scheduledExecutorServiceFactory, timeProvider);
}
@Test
public void providerRegisteredName() {
CertificateProviderProvider certProviderProvider =
CertificateProviderRegistry.getInstance()
.getProvider(FileWatcherCertificateProviderProvider.FILE_WATCHER_PROVIDER_NAME);
assertThat(certProviderProvider).isInstanceOf(FileWatcherCertificateProviderProvider.class);
FileWatcherCertificateProviderProvider fileWatcherCertificateProviderProvider =
(FileWatcherCertificateProviderProvider) certProviderProvider;
assertThat(fileWatcherCertificateProviderProvider.fileWatcherCertificateProviderFactory)
.isSameInstanceAs(FileWatcherCertificateProvider.Factory.getInstance());
}
@Test
public void createProvider_minimalConfig() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MINIMAL_FILE_WATCHER_CONFIG);
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
when(scheduledExecutorServiceFactory.create()).thenReturn(mockService);
provider.createCertificateProvider(map, distWatcher, true);
verify(fileWatcherCertificateProviderFactory, times(1))
.create(
eq(distWatcher),
eq(true),
eq("/var/run/gke-spiffe/certs/certificates.pem"),
eq("/var/run/gke-spiffe/certs/private_key.pem"),
eq("/var/run/gke-spiffe/certs/ca_certificates.pem"),
eq(600L),
eq(mockService),
eq(timeProvider));
}
@Test
public void createProvider_fullConfig() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(FULL_FILE_WATCHER_CONFIG);
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
when(scheduledExecutorServiceFactory.create()).thenReturn(mockService);
provider.createCertificateProvider(map, distWatcher, true);
verify(fileWatcherCertificateProviderFactory, times(1))
.create(
eq(distWatcher),
eq(true),
eq("/var/run/gke-spiffe/certs/certificates2.pem"),
eq("/var/run/gke-spiffe/certs/private_key3.pem"),
eq("/var/run/gke-spiffe/certs/ca_certificates4.pem"),
eq(7890L),
eq(mockService),
eq(timeProvider));
}
@Test
public void createProvider_missingCert_expectException() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MISSING_CERT_CONFIG);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException npe) {
assertThat(npe).hasMessageThat().isEqualTo("'certificate_file' is required in the config");
}
}
@Test
public void createProvider_missingKey_expectException() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MISSING_KEY_CONFIG);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException npe) {
assertThat(npe).hasMessageThat().isEqualTo("'private_key_file' is required in the config");
}
}
@Test
public void createProvider_missingRoot_expectException() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(MISSING_ROOT_CONFIG);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException npe) {
assertThat(npe).hasMessageThat().isEqualTo("'ca_certificate_file' is required in the config");
}
}
private static final String MINIMAL_FILE_WATCHER_CONFIG =
"{\n"
+ " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\","
+ " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\","
+ " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\""
+ " }";
private static final String FULL_FILE_WATCHER_CONFIG =
"{\n"
+ " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates2.pem\","
+ " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key3.pem\","
+ " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates4.pem\","
+ " \"refresh_interval\": \"7890s\""
+ " }";
private static final String MISSING_CERT_CONFIG =
"{\n"
+ " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\","
+ " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\""
+ " }";
private static final String MISSING_KEY_CONFIG =
"{\n"
+ " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\","
+ " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\""
+ " }";
private static final String MISSING_ROOT_CONFIG =
"{\n"
+ " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\","
+ " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\""
+ " }";
}

View File

@ -0,0 +1,376 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static java.nio.file.StandardCopyOption.REPLACE_EXISTING;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.Status;
import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.NoSuchFileException;
import java.nio.file.Paths;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/** Unit tests for {@link FileWatcherCertificateProvider}. */
@RunWith(JUnit4.class)
public class FileWatcherCertificateProviderTest {
private static final String CERT_FILE = "cert.pem";
private static final String KEY_FILE = "key.pem";
private static final String ROOT_FILE = "root.pem";
@Mock private CertificateProvider.Watcher mockWatcher;
@Mock private ScheduledExecutorService timeService;
@Mock private TimeProvider timeProvider;
@Rule public TemporaryFolder tempFolder = new TemporaryFolder();
private String certFile;
private String keyFile;
private String rootFile;
private FileWatcherCertificateProvider provider;
@Before
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);
DistributorWatcher watcher = new DistributorWatcher();
watcher.addWatcher(mockWatcher);
certFile = new File(tempFolder.getRoot(), CERT_FILE).getAbsolutePath();
keyFile = new File(tempFolder.getRoot(), KEY_FILE).getAbsolutePath();
rootFile = new File(tempFolder.getRoot(), ROOT_FILE).getAbsolutePath();
provider =
new FileWatcherCertificateProvider(
watcher, true, certFile, keyFile, rootFile, 600L, timeService, timeProvider);
}
private void populateTarget(
String certFileSource,
String keyFileSource,
String rootFileSource,
boolean deleteCurCert,
boolean deleteCurKey,
boolean deleteCurRoot)
throws IOException {
if (deleteCurCert) {
Files.delete(Paths.get(certFile));
}
if (certFileSource != null) {
certFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(certFileSource);
Files.copy(Paths.get(certFileSource), Paths.get(certFile), REPLACE_EXISTING);
}
if (deleteCurKey) {
Files.delete(Paths.get(keyFile));
}
if (keyFileSource != null) {
keyFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(keyFileSource);
Files.copy(Paths.get(keyFileSource), Paths.get(keyFile), REPLACE_EXISTING);
}
if (deleteCurRoot) {
Files.delete(Paths.get(rootFile));
}
if (rootFileSource != null) {
rootFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(rootFileSource);
Files.copy(Paths.get(rootFileSource), Paths.get(rootFile), REPLACE_EXISTING);
}
}
@Test
public void getCertificateAndCheckUpdates() throws IOException, CertificateException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
provider.checkAndReloadCertificates();
verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE);
verifyTimeServiceAndScheduledHandle();
reset(mockWatcher, timeService);
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates(null, null, 0, 0, (String[]) null);
verifyTimeServiceAndScheduledHandle();
}
@Test
public void allUpdateSecondTime() throws IOException, CertificateException, InterruptedException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
provider.checkAndReloadCertificates();
reset(mockWatcher, timeService);
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
Thread.sleep(1000L);
populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false);
provider.checkAndReloadCertificates();
verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE);
verifyTimeServiceAndScheduledHandle();
}
@Test
public void rootFileUpdateOnly() throws IOException, CertificateException, InterruptedException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
provider.checkAndReloadCertificates();
reset(mockWatcher, timeService);
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
Thread.sleep(1000L);
populateTarget(null, null, SERVER_1_PEM_FILE, false, false, false);
provider.checkAndReloadCertificates();
verifyWatcherUpdates(null, SERVER_1_PEM_FILE);
verifyTimeServiceAndScheduledHandle();
}
@Test
public void certAndKeyFileUpdateOnly()
throws IOException, CertificateException, InterruptedException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
provider.checkAndReloadCertificates();
reset(mockWatcher, timeService);
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
Thread.sleep(1000L);
populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, false, false, false);
provider.checkAndReloadCertificates();
verifyWatcherUpdates(SERVER_0_PEM_FILE, null);
verifyTimeServiceAndScheduledHandle();
}
@Test
public void getCertificate_initialMissingCertFile() throws IOException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
when(timeProvider.currentTimeNanos())
.thenReturn(TimeProvider.SYSTEM_TIME_PROVIDER.currentTimeNanos());
provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 0, 1, "cert.pem");
}
@Test
public void getCertificate_missingCertFile() throws IOException, InterruptedException {
commonErrorTest(
null, CLIENT_KEY_FILE, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "cert.pem");
}
@Test
public void getCertificate_missingKeyFile() throws IOException, InterruptedException {
commonErrorTest(
CLIENT_PEM_FILE, null, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "key.pem");
}
@Test
public void getCertificate_badKeyFile() throws IOException, InterruptedException {
commonErrorTest(
CLIENT_PEM_FILE,
SERVER_0_PEM_FILE,
CA_PEM_FILE,
java.security.KeyException.class,
0,
1,
0,
0,
"could not find a PKCS #8 private key in input stream");
}
@Test
public void getCertificate_missingRootFile() throws IOException, InterruptedException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false);
provider.checkAndReloadCertificates();
reset(mockWatcher);
Thread.sleep(1000L);
populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, false, false, true);
when(timeProvider.currentTimeNanos())
.thenReturn(
TimeUnit.MILLISECONDS.toNanos(
MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L));
provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 1, 0, "root.pem");
}
private void commonErrorTest(
String certFile,
String keyFile,
String rootFile,
Class<?> throwableType,
int firstUpdateCertCount,
int firstUpdateRootCount,
int secondUpdateCertCount,
int secondUpdateRootCount,
String... causeMessages)
throws IOException, InterruptedException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false);
provider.checkAndReloadCertificates();
reset(mockWatcher);
Thread.sleep(1000L);
populateTarget(
certFile, keyFile, rootFile, certFile == null, keyFile == null, rootFile == null);
when(timeProvider.currentTimeNanos())
.thenReturn(
TimeUnit.MILLISECONDS.toNanos(
MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L));
provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates(
null, null, firstUpdateCertCount, firstUpdateRootCount, (String[]) null);
reset(mockWatcher, timeProvider);
when(timeProvider.currentTimeNanos())
.thenReturn(
TimeUnit.MILLISECONDS.toNanos(
MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 590_000L));
provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates(
Status.Code.UNKNOWN,
throwableType,
secondUpdateCertCount,
secondUpdateRootCount,
causeMessages);
}
private void verifyWatcherErrorUpdates(
Status.Code code,
Class<?> throwableType,
int updateCertCount,
int updateRootCount,
String... causeMessages) {
verify(mockWatcher, times(updateCertCount))
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, times(updateRootCount))
.updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
if (code == null && throwableType == null && causeMessages == null) {
verify(mockWatcher, never()).onError(any(Status.class));
} else {
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(null);
verify(mockWatcher, times(1)).onError(statusCaptor.capture());
Status status = statusCaptor.getValue();
assertThat(status.getCode()).isEqualTo(code);
Throwable cause = status.getCause();
assertThat(cause).isInstanceOf(throwableType);
for (String causeMessage : causeMessages) {
assertThat(cause).hasMessageThat().contains(causeMessage);
cause = cause.getCause();
}
}
}
private void verifyTimeServiceAndScheduledHandle() {
verify(timeService, times(1)).schedule(any(Runnable.class), eq(600L), eq(TimeUnit.SECONDS));
assertThat(provider.scheduledHandle).isNotNull();
assertThat(provider.scheduledHandle.isPending()).isTrue();
}
private void verifyWatcherUpdates(String certPemFile, String rootPemFile)
throws IOException, CertificateException {
if (certPemFile != null) {
ArgumentCaptor<List<X509Certificate>> certChainCaptor = ArgumentCaptor.forClass(null);
verify(mockWatcher, times(1))
.updateCertificate(any(PrivateKey.class), certChainCaptor.capture());
List<X509Certificate> certChain = certChainCaptor.getValue();
assertThat(certChain).hasSize(1);
assertThat(certChain.get(0))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(certPemFile));
} else {
verify(mockWatcher, never())
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
}
if (rootPemFile != null) {
ArgumentCaptor<List<X509Certificate>> rootsCaptor = ArgumentCaptor.forClass(null);
verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture());
List<X509Certificate> roots = rootsCaptor.getValue();
assertThat(roots).hasSize(1);
assertThat(roots.get(0))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(rootPemFile));
verify(mockWatcher, never()).onError(any(Status.class));
} else {
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
}
}
}