diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java index 6580185794..bbcb521c0d 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java @@ -19,9 +19,7 @@ package io.grpc.xds.internal.certprovider; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; -import io.grpc.InternalLogId; import io.grpc.Status; -import io.grpc.SynchronizationContext; import io.grpc.internal.TimeProvider; import io.grpc.xds.internal.sds.trust.CertificateUtils; @@ -34,27 +32,28 @@ import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.Arrays; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; // TODO(sanjaypujare): abstract out common functionality into an an abstract superclass /** Implementation of {@link CertificateProvider} for file watching cert provider. */ -final class FileWatcherCertificateProvider extends CertificateProvider { +final class FileWatcherCertificateProvider extends CertificateProvider implements Runnable { private static final Logger logger = Logger.getLogger(FileWatcherCertificateProvider.class.getName()); - private final SynchronizationContext syncContext; private final ScheduledExecutorService scheduledExecutorService; private final TimeProvider timeProvider; private final Path certFile; private final Path keyFile; private final Path trustFile; private final long refreshIntervalInSeconds; - @VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle; + @VisibleForTesting ScheduledFuture scheduledFuture; private FileTime lastModifiedTimeCert; private FileTime lastModifiedTimeKey; private FileTime lastModifiedTimeRoot; + private boolean shutdown; FileWatcherCertificateProvider( DistributorWatcher watcher, @@ -73,34 +72,6 @@ final class FileWatcherCertificateProvider extends CertificateProvider { this.keyFile = Paths.get(checkNotNull(keyFile, "keyFile")); this.trustFile = Paths.get(checkNotNull(trustFile, "trustFile")); this.refreshIntervalInSeconds = refreshIntervalInSeconds; - this.syncContext = createSynchronizationContext(certFile); - } - - private SynchronizationContext createSynchronizationContext(String details) { - final InternalLogId logId = - InternalLogId.allocate("DynamicReloadingCertificateProvider", details); - return new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { - private boolean panicMode; - - @Override - public void uncaughtException(Thread t, Throwable e) { - logger.log( - Level.SEVERE, - "[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!", - e); - panic(e); - } - - void panic(final Throwable t) { - if (panicMode) { - // Preserve the first panic information - return; - } - panicMode = true; - close(); - } - }); } @Override @@ -109,18 +80,19 @@ final class FileWatcherCertificateProvider extends CertificateProvider { } @Override - public void close() { - if (scheduledHandle != null) { - scheduledHandle.cancel(); - scheduledHandle = null; + public synchronized void close() { + shutdown = true; + if (scheduledFuture != null) { + scheduledFuture.cancel(true); + scheduledFuture = null; } getWatcher().close(); } - private void scheduleNextRefreshCertificate(long delayInSeconds) { - RefreshCertificateTask runnable = new RefreshCertificateTask(); - scheduledHandle = - syncContext.schedule(runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService); + private synchronized void scheduleNextRefreshCertificate(long delayInSeconds) { + if (!shutdown) { + scheduledFuture = scheduledExecutorService.schedule(this, delayInSeconds, TimeUnit.SECONDS); + } } @VisibleForTesting @@ -199,11 +171,17 @@ final class FileWatcherCertificateProvider extends CertificateProvider { - timeProvider.currentTimeNanos()); } - @VisibleForTesting - class RefreshCertificateTask implements Runnable { - @Override - public void run() { - checkAndReloadCertificates(); + @Override + public void run() { + if (!shutdown) { + try { + checkAndReloadCertificates(); + } catch (Throwable t) { + logger.log(Level.SEVERE, "Uncaught exception!", t); + if (t instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + } } } diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProvider.java index bfef3677b9..c1b0ce3f50 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProvider.java @@ -104,6 +104,7 @@ final class FileWatcherCertificateProviderProvider implements CertificateProvide try { Duration duration = Durations.parse(refreshIntervalString); configObj.refrehInterval = duration.getSeconds(); + checkArgument(configObj.refrehInterval > 0L, "refreshInterval needs to be greater than 0"); } catch (ParseException e) { throw new IllegalArgumentException(e); } diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProviderTest.java index b3ab599952..d113b52005 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProviderTest.java @@ -109,6 +109,22 @@ public class FileWatcherCertificateProviderProviderTest { eq(timeProvider)); } + @Test + public void createProvider_zeroRefreshInterval() throws IOException { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(ZERO_REFRESH_INTERVAL); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + try { + provider.createCertificateProvider(map, distWatcher, true); + fail("exception expected"); + } catch (IllegalArgumentException iae) { + assertThat(iae).hasMessageThat().isEqualTo("refreshInterval needs to be greater than 0"); + } + } + @Test public void createProvider_missingCert_expectException() throws IOException { CertificateProvider.DistributorWatcher distWatcher = @@ -183,4 +199,12 @@ public class FileWatcherCertificateProviderProviderTest { + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\"," + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"" + " }"; + + private static final String ZERO_REFRESH_INTERVAL = + "{\n" + + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates2.pem\"," + + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key3.pem\"," + + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates4.pem\"," + + " \"refresh_interval\": \"0s\"" + + " }"; } diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java index e2cb464734..474c05d048 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java @@ -134,7 +134,7 @@ public class FileWatcherCertificateProviderTest { populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); provider.checkAndReloadCertificates(); verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE); - verifyTimeServiceAndScheduledHandle(); + verifyTimeServiceAndScheduledFuture(); reset(mockWatcher, timeService); doReturn(scheduledFuture) @@ -142,7 +142,7 @@ public class FileWatcherCertificateProviderTest { .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates(null, null, 0, 0, (String[]) null); - verifyTimeServiceAndScheduledHandle(); + verifyTimeServiceAndScheduledFuture(); } @Test @@ -163,9 +163,26 @@ public class FileWatcherCertificateProviderTest { populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); provider.checkAndReloadCertificates(); verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE); - verifyTimeServiceAndScheduledHandle(); + verifyTimeServiceAndScheduledFuture(); } + @Test + public void closeDoesNotScheduleNext() throws IOException, CertificateException { + MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = + new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + provider.close(); + provider.checkAndReloadCertificates(); + verify(mockWatcher, never()) + .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); + verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); + verify(timeService, never()).schedule(any(Runnable.class), any(Long.TYPE), any(TimeUnit.class)); + } + + @Test public void rootFileUpdateOnly() throws IOException, CertificateException, InterruptedException { MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = @@ -184,7 +201,7 @@ public class FileWatcherCertificateProviderTest { populateTarget(null, null, SERVER_1_PEM_FILE, false, false, false); provider.checkAndReloadCertificates(); verifyWatcherUpdates(null, SERVER_1_PEM_FILE); - verifyTimeServiceAndScheduledHandle(); + verifyTimeServiceAndScheduledFuture(); } @Test @@ -206,7 +223,7 @@ public class FileWatcherCertificateProviderTest { populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, false, false, false); provider.checkAndReloadCertificates(); verifyWatcherUpdates(SERVER_0_PEM_FILE, null); - verifyTimeServiceAndScheduledHandle(); + verifyTimeServiceAndScheduledFuture(); } @Test @@ -341,10 +358,11 @@ public class FileWatcherCertificateProviderTest { } } - private void verifyTimeServiceAndScheduledHandle() { + private void verifyTimeServiceAndScheduledFuture() { verify(timeService, times(1)).schedule(any(Runnable.class), eq(600L), eq(TimeUnit.SECONDS)); - assertThat(provider.scheduledHandle).isNotNull(); - assertThat(provider.scheduledHandle.isPending()).isTrue(); + assertThat(provider.scheduledFuture).isNotNull(); + assertThat(provider.scheduledFuture.isDone()).isFalse(); + assertThat(provider.scheduledFuture.isCancelled()).isFalse(); } private void verifyWatcherUpdates(String certPemFile, String rootPemFile)