xds: remove syncContext and just use the executorService (#8006)

This commit is contained in:
sanjaypujare 2021-03-24 12:41:58 -07:00 committed by GitHub
parent c4dec7517f
commit b7afbc30d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 75 additions and 54 deletions

View File

@ -19,9 +19,7 @@ package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.grpc.InternalLogId;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.internal.TimeProvider; import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.sds.trust.CertificateUtils; import io.grpc.xds.internal.sds.trust.CertificateUtils;
@ -34,27 +32,28 @@ import java.security.PrivateKey;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Arrays; import java.util.Arrays;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
// TODO(sanjaypujare): abstract out common functionality into an an abstract superclass // TODO(sanjaypujare): abstract out common functionality into an an abstract superclass
/** Implementation of {@link CertificateProvider} for file watching cert provider. */ /** Implementation of {@link CertificateProvider} for file watching cert provider. */
final class FileWatcherCertificateProvider extends CertificateProvider { final class FileWatcherCertificateProvider extends CertificateProvider implements Runnable {
private static final Logger logger = private static final Logger logger =
Logger.getLogger(FileWatcherCertificateProvider.class.getName()); Logger.getLogger(FileWatcherCertificateProvider.class.getName());
private final SynchronizationContext syncContext;
private final ScheduledExecutorService scheduledExecutorService; private final ScheduledExecutorService scheduledExecutorService;
private final TimeProvider timeProvider; private final TimeProvider timeProvider;
private final Path certFile; private final Path certFile;
private final Path keyFile; private final Path keyFile;
private final Path trustFile; private final Path trustFile;
private final long refreshIntervalInSeconds; private final long refreshIntervalInSeconds;
@VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle; @VisibleForTesting ScheduledFuture<?> scheduledFuture;
private FileTime lastModifiedTimeCert; private FileTime lastModifiedTimeCert;
private FileTime lastModifiedTimeKey; private FileTime lastModifiedTimeKey;
private FileTime lastModifiedTimeRoot; private FileTime lastModifiedTimeRoot;
private boolean shutdown;
FileWatcherCertificateProvider( FileWatcherCertificateProvider(
DistributorWatcher watcher, DistributorWatcher watcher,
@ -73,34 +72,6 @@ final class FileWatcherCertificateProvider extends CertificateProvider {
this.keyFile = Paths.get(checkNotNull(keyFile, "keyFile")); this.keyFile = Paths.get(checkNotNull(keyFile, "keyFile"));
this.trustFile = Paths.get(checkNotNull(trustFile, "trustFile")); this.trustFile = Paths.get(checkNotNull(trustFile, "trustFile"));
this.refreshIntervalInSeconds = refreshIntervalInSeconds; this.refreshIntervalInSeconds = refreshIntervalInSeconds;
this.syncContext = createSynchronizationContext(certFile);
}
private SynchronizationContext createSynchronizationContext(String details) {
final InternalLogId logId =
InternalLogId.allocate("DynamicReloadingCertificateProvider", details);
return new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
private boolean panicMode;
@Override
public void uncaughtException(Thread t, Throwable e) {
logger.log(
Level.SEVERE,
"[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!",
e);
panic(e);
}
void panic(final Throwable t) {
if (panicMode) {
// Preserve the first panic information
return;
}
panicMode = true;
close();
}
});
} }
@Override @Override
@ -109,18 +80,19 @@ final class FileWatcherCertificateProvider extends CertificateProvider {
} }
@Override @Override
public void close() { public synchronized void close() {
if (scheduledHandle != null) { shutdown = true;
scheduledHandle.cancel(); if (scheduledFuture != null) {
scheduledHandle = null; scheduledFuture.cancel(true);
scheduledFuture = null;
} }
getWatcher().close(); getWatcher().close();
} }
private void scheduleNextRefreshCertificate(long delayInSeconds) { private synchronized void scheduleNextRefreshCertificate(long delayInSeconds) {
RefreshCertificateTask runnable = new RefreshCertificateTask(); if (!shutdown) {
scheduledHandle = scheduledFuture = scheduledExecutorService.schedule(this, delayInSeconds, TimeUnit.SECONDS);
syncContext.schedule(runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService); }
} }
@VisibleForTesting @VisibleForTesting
@ -199,11 +171,17 @@ final class FileWatcherCertificateProvider extends CertificateProvider {
- timeProvider.currentTimeNanos()); - timeProvider.currentTimeNanos());
} }
@VisibleForTesting @Override
class RefreshCertificateTask implements Runnable { public void run() {
@Override if (!shutdown) {
public void run() { try {
checkAndReloadCertificates(); checkAndReloadCertificates();
} catch (Throwable t) {
logger.log(Level.SEVERE, "Uncaught exception!", t);
if (t instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
}
} }
} }

View File

@ -104,6 +104,7 @@ final class FileWatcherCertificateProviderProvider implements CertificateProvide
try { try {
Duration duration = Durations.parse(refreshIntervalString); Duration duration = Durations.parse(refreshIntervalString);
configObj.refrehInterval = duration.getSeconds(); configObj.refrehInterval = duration.getSeconds();
checkArgument(configObj.refrehInterval > 0L, "refreshInterval needs to be greater than 0");
} catch (ParseException e) { } catch (ParseException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} }

View File

@ -109,6 +109,22 @@ public class FileWatcherCertificateProviderProviderTest {
eq(timeProvider)); eq(timeProvider));
} }
@Test
public void createProvider_zeroRefreshInterval() throws IOException {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
@SuppressWarnings("unchecked")
Map<String, ?> map = (Map<String, ?>) JsonParser.parse(ZERO_REFRESH_INTERVAL);
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
when(scheduledExecutorServiceFactory.create()).thenReturn(mockService);
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (IllegalArgumentException iae) {
assertThat(iae).hasMessageThat().isEqualTo("refreshInterval needs to be greater than 0");
}
}
@Test @Test
public void createProvider_missingCert_expectException() throws IOException { public void createProvider_missingCert_expectException() throws IOException {
CertificateProvider.DistributorWatcher distWatcher = CertificateProvider.DistributorWatcher distWatcher =
@ -183,4 +199,12 @@ public class FileWatcherCertificateProviderProviderTest {
+ " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\"," + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\","
+ " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"" + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\""
+ " }"; + " }";
private static final String ZERO_REFRESH_INTERVAL =
"{\n"
+ " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates2.pem\","
+ " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key3.pem\","
+ " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates4.pem\","
+ " \"refresh_interval\": \"0s\""
+ " }";
} }

View File

@ -134,7 +134,7 @@ public class FileWatcherCertificateProviderTest {
populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
provider.checkAndReloadCertificates(); provider.checkAndReloadCertificates();
verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE); verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE);
verifyTimeServiceAndScheduledHandle(); verifyTimeServiceAndScheduledFuture();
reset(mockWatcher, timeService); reset(mockWatcher, timeService);
doReturn(scheduledFuture) doReturn(scheduledFuture)
@ -142,7 +142,7 @@ public class FileWatcherCertificateProviderTest {
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.checkAndReloadCertificates(); provider.checkAndReloadCertificates();
verifyWatcherErrorUpdates(null, null, 0, 0, (String[]) null); verifyWatcherErrorUpdates(null, null, 0, 0, (String[]) null);
verifyTimeServiceAndScheduledHandle(); verifyTimeServiceAndScheduledFuture();
} }
@Test @Test
@ -163,9 +163,26 @@ public class FileWatcherCertificateProviderTest {
populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false);
provider.checkAndReloadCertificates(); provider.checkAndReloadCertificates();
verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE); verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE);
verifyTimeServiceAndScheduledHandle(); verifyTimeServiceAndScheduledFuture();
} }
@Test
public void closeDoesNotScheduleNext() throws IOException, CertificateException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
new MeshCaCertificateProviderTest.TestScheduledFuture<>();
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
provider.close();
provider.checkAndReloadCertificates();
verify(mockWatcher, never())
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
verify(timeService, never()).schedule(any(Runnable.class), any(Long.TYPE), any(TimeUnit.class));
}
@Test @Test
public void rootFileUpdateOnly() throws IOException, CertificateException, InterruptedException { public void rootFileUpdateOnly() throws IOException, CertificateException, InterruptedException {
MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture = MeshCaCertificateProviderTest.TestScheduledFuture<?> scheduledFuture =
@ -184,7 +201,7 @@ public class FileWatcherCertificateProviderTest {
populateTarget(null, null, SERVER_1_PEM_FILE, false, false, false); populateTarget(null, null, SERVER_1_PEM_FILE, false, false, false);
provider.checkAndReloadCertificates(); provider.checkAndReloadCertificates();
verifyWatcherUpdates(null, SERVER_1_PEM_FILE); verifyWatcherUpdates(null, SERVER_1_PEM_FILE);
verifyTimeServiceAndScheduledHandle(); verifyTimeServiceAndScheduledFuture();
} }
@Test @Test
@ -206,7 +223,7 @@ public class FileWatcherCertificateProviderTest {
populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, false, false, false); populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, false, false, false);
provider.checkAndReloadCertificates(); provider.checkAndReloadCertificates();
verifyWatcherUpdates(SERVER_0_PEM_FILE, null); verifyWatcherUpdates(SERVER_0_PEM_FILE, null);
verifyTimeServiceAndScheduledHandle(); verifyTimeServiceAndScheduledFuture();
} }
@Test @Test
@ -341,10 +358,11 @@ public class FileWatcherCertificateProviderTest {
} }
} }
private void verifyTimeServiceAndScheduledHandle() { private void verifyTimeServiceAndScheduledFuture() {
verify(timeService, times(1)).schedule(any(Runnable.class), eq(600L), eq(TimeUnit.SECONDS)); verify(timeService, times(1)).schedule(any(Runnable.class), eq(600L), eq(TimeUnit.SECONDS));
assertThat(provider.scheduledHandle).isNotNull(); assertThat(provider.scheduledFuture).isNotNull();
assertThat(provider.scheduledHandle.isPending()).isTrue(); assertThat(provider.scheduledFuture.isDone()).isFalse();
assertThat(provider.scheduledFuture.isCancelled()).isFalse();
} }
private void verifyWatcherUpdates(String certPemFile, String rootPemFile) private void verifyWatcherUpdates(String certPemFile, String rootPemFile)