diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java index 823d38d1e9..0bea184fb9 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java @@ -16,11 +16,14 @@ package io.grpc.xds.internal.certprovider; +import com.google.common.annotations.VisibleForTesting; import io.grpc.Status; -import java.io.Closeable; +import io.grpc.xds.internal.sds.Closeable; import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.HashSet; import java.util.List; +import java.util.Set; /** * A plug-in that provides certificates required by the xDS security component and created @@ -33,7 +36,7 @@ import java.util.List; */ public abstract class CertificateProvider implements Closeable { - /** A watcher is registered via the constructor to receive updates for the certificates. */ + /** A watcher is registered to receive certificate updates. */ public interface Watcher { void updateCertificate(PrivateKey key, List certChain); @@ -42,6 +45,41 @@ public abstract class CertificateProvider implements Closeable { void onError(Status errorStatus); } + @VisibleForTesting + static final class DistributorWatcher implements Watcher { + @VisibleForTesting + final Set downsstreamWatchers = new HashSet<>(); + + synchronized void addWatcher(Watcher watcher) { + downsstreamWatchers.add(watcher); + } + + synchronized void removeWatcher(Watcher watcher) { + downsstreamWatchers.remove(watcher); + } + + @Override + public void updateCertificate(PrivateKey key, List certChain) { + for (Watcher watcher : downsstreamWatchers) { + watcher.updateCertificate(key, certChain); + } + } + + @Override + public void updateTrustedRoots(List trustedRoots) { + for (Watcher watcher : downsstreamWatchers) { + watcher.updateTrustedRoots(trustedRoots); + } + } + + @Override + public void onError(Status errorStatus) { + for (Watcher watcher : downsstreamWatchers) { + watcher.onError(errorStatus); + } + } + } + /** * Concrete subclasses will call this to register the {@link Watcher}. * @@ -51,7 +89,7 @@ public abstract class CertificateProvider implements Closeable { * Used by server-side and mTLS client-side. Note the Provider is always required * to call updateTrustedRoots to provide trusted-root updates. */ - protected CertificateProvider(Watcher watcher, boolean notifyCertUpdates) { + protected CertificateProvider(DistributorWatcher watcher, boolean notifyCertUpdates) { this.watcher = watcher; this.notifyCertUpdates = notifyCertUpdates; } @@ -60,6 +98,16 @@ public abstract class CertificateProvider implements Closeable { @Override public abstract void close(); - protected final Watcher watcher; - protected final boolean notifyCertUpdates; + private final DistributorWatcher watcher; + private final boolean notifyCertUpdates; + + public DistributorWatcher getWatcher() { + return watcher; + } + + public boolean isNotifyCertUpdates() { + return notifyCertUpdates; + } + + } diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderProvider.java index c99738ef8d..92b2d4d6aa 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderProvider.java @@ -32,8 +32,14 @@ interface CertificateProviderProvider { * @param config configuration needed by the Provider to create the CertificateProvider. A form of * JSON that the Provider understands e.g. a string or a key-value Map. * @param watcher A {@link Watcher} to receive updates from the CertificateProvider - * @param notifyCertUpdates See {@link CertificateProvider#CertificateProvider(Watcher, boolean)} + * @param notifyCertUpdates if true, the provider is required to call the watcher’s + * updateCertificate method. Implies the Provider is capable of minting certificates. Used + * by server-side and mTLS client-side. Note the Provider is always required to call + * updateTrustedRoots to provide trusted-root updates. + * @throws IllegalArgumentException in case of errors in processing config. + * @throws UnsupportedOperationException if the plugin is incapable of sending cert updates when + * notifyCertUpdates is true. */ CertificateProvider createCertificateProvider( - Object config, Watcher watcher, boolean notifyCertUpdates); + Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates); } diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java new file mode 100644 index 0000000000..36db37e5db --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import java.util.LinkedHashMap; +import javax.annotation.Nullable; +import javax.annotation.concurrent.ThreadSafe; + +/** Maintains {@link CertificateProvider}s for all registered plugins. */ +@ThreadSafe +public final class CertificateProviderRegistry { + private static CertificateProviderRegistry instance; + private final LinkedHashMap providers = + new LinkedHashMap<>(); + + @VisibleForTesting + CertificateProviderRegistry() { + } + + /** Returns the singleton registry. */ + public static synchronized CertificateProviderRegistry getInstance() { + if (instance == null) { + instance = new CertificateProviderRegistry(); + } + return instance; + } + + /** + * Register a {@link CertificateProviderProvider}. + * + *

If a provider with the same {@link CertificateProviderProvider#getName name} was already + * registered, this method will overwrite that provider. + */ + public synchronized void register(CertificateProviderProvider certificateProviderProvider) { + checkNotNull(certificateProviderProvider, "certificateProviderProvider"); + providers.put(certificateProviderProvider.getName(), certificateProviderProvider); + } + + /** + * Deregisters a provider. No-op if the provider is not in the registry. + * + * @param certificateProviderProvider the provider that was added to the registry via + * {@link #register}. + */ + public synchronized void deregister(CertificateProviderProvider certificateProviderProvider) { + checkNotNull(certificateProviderProvider, "certificateProviderProvider"); + providers.remove(certificateProviderProvider.getName()); + } + + /** + * Returns the CertificateProviderProvider for the given name, or {@code null} if no + * provider is found. Each provider declares its name via {@link + * CertificateProviderProvider#getName}. This is an internal method of the Registry + * *only* used by the framework. + */ + @Nullable + synchronized CertificateProviderProvider getProvider(String name) { + return providers.get(checkNotNull(name, "name")); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java new file mode 100644 index 0000000000..04542b05dd --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java @@ -0,0 +1,195 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.xds.internal.certprovider.CertificateProvider.Watcher; +import io.grpc.xds.internal.sds.ReferenceCountingMap; + +import java.util.Objects; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.concurrent.ThreadSafe; + +/** + * Global map of all ref-counted {@link CertificateProvider}s that have been instantiated in + * the application. Also propagates updates received from a {@link CertificateProvider} to all + * the {@link Watcher}s registered for that CertificateProvider. The Store is meant to be + * used internally by gRPC and *not* a public API. + */ +@ThreadSafe +public final class CertificateProviderStore { + private static final Logger logger = Logger.getLogger(CertificateProviderStore.class.getName()); + + private static CertificateProviderStore instance; + private final CertificateProviderRegistry certificateProviderRegistry; + private final ReferenceCountingMap certProviderMap; + + /** Opaque Handle returned by {@link #createOrGetProvider}. */ + @VisibleForTesting + final class Handle implements java.io.Closeable { + private final CertProviderKey key; + private final Watcher watcher; + @VisibleForTesting + final CertificateProvider certProvider; + + private Handle(CertProviderKey key, Watcher watcher, CertificateProvider certProvider) { + this.key = key; + this.watcher = watcher; + this.certProvider = certProvider; + } + + /** + * Removes the associated {@link Watcher} for the {@link CertificateProvider} and + * decrements the ref-count. Releases the {@link CertificateProvider} if the ref-count + * has reached 0. + */ + @Override + public synchronized void close() { + CertificateProvider.DistributorWatcher distWatcher = certProvider.getWatcher(); + distWatcher.removeWatcher(watcher); + certProviderMap.release(key, certProvider); + } + } + + private static final class CertProviderKey { + private final String certName; + private final String pluginName; + private final boolean notifyCertUpdates; + private final Object config; + + private CertProviderKey( + String certName, String pluginName, boolean notifyCertUpdates, Object config) { + this.certName = certName; + this.pluginName = pluginName; + this.notifyCertUpdates = notifyCertUpdates; + this.config = config; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof CertProviderKey)) { + return false; + } + CertProviderKey that = (CertProviderKey) o; + return notifyCertUpdates == that.notifyCertUpdates + && Objects.equals(certName, that.certName) + && Objects.equals(pluginName, that.pluginName) + && Objects.equals(config, that.config); + } + + @Override + public int hashCode() { + return Objects.hash(certName, pluginName, notifyCertUpdates, config); + } + + @Override + public String toString() { + return "CertProviderKey{" + + "certName='" + + certName + + '\'' + + ", pluginName='" + + pluginName + + '\'' + + ", notifyCertUpdates=" + + notifyCertUpdates + + ", config=" + + config + + '}'; + } + } + + private final class CertProviderFactory + implements ReferenceCountingMap.ValueFactory { + + private CertProviderFactory() { + } + + @Override + public CertificateProvider create(CertProviderKey key) { + CertificateProviderProvider certProviderProvider = + certificateProviderRegistry.getProvider(key.pluginName); + if (certProviderProvider == null) { + throw new IllegalArgumentException("Provider not found."); + } + return certProviderProvider.createCertificateProvider( + key.config, new CertificateProvider.DistributorWatcher(), key.notifyCertUpdates); + } + } + + @VisibleForTesting + CertificateProviderStore(CertificateProviderRegistry certificateProviderRegistry) { + this.certificateProviderRegistry = certificateProviderRegistry; + certProviderMap = new ReferenceCountingMap<>(new CertProviderFactory()); + } + + /** + * Creates or retrieves a {@link CertificateProvider} instance, increments its ref-count and + * registers the watcher passed. Returns a {@link Handle} that can be {@link Handle#close()}d when + * the instance is no longer needed by the caller. + * + * @param notifyCertUpdates when true, the caller is interested in identity cert updates. When + * false, the caller cannot depend on receiving the {@link Watcher#updateCertificate} + * callbacks but may still receive these callbacks which should be ignored. + * @throws IllegalArgumentException in case of errors in processing config or the plugin is + * incapable of sending cert updates when notifyCertUpdates is true. + * @throws UnsupportedOperationException if the plugin is incapable of sending cert updates when + * notifyCertUpdates is true. + */ + public synchronized Handle createOrGetProvider( + String certName, + String pluginName, + Object config, + Watcher watcher, + boolean notifyCertUpdates) { + if (!notifyCertUpdates) { + // we try to get a provider first for notifyCertUpdates==true always + try { + return createProviderHelper(certName, pluginName, config, watcher, true); + } catch (UnsupportedOperationException uoe) { + // ignore & log exception and fall thru to create a provider with actual value + logger.log(Level.FINE, "Trying to get provider for notifyCertUpdates==true", uoe); + } + } + return createProviderHelper(certName, pluginName, config, watcher, notifyCertUpdates); + } + + private synchronized Handle createProviderHelper( + String certName, + String pluginName, + Object config, + Watcher watcher, + boolean notifyCertUpdates) { + CertProviderKey key = new CertProviderKey(certName, pluginName, notifyCertUpdates, config); + CertificateProvider provider = certProviderMap.get(key); + CertificateProvider.DistributorWatcher distWatcher = provider.getWatcher(); + distWatcher.addWatcher(watcher); + return new Handle(key, watcher, provider); + } + + /** Gets the CertificateProviderStore singleton instance. */ + public static synchronized CertificateProviderStore getInstance() { + if (instance == null) { + instance = new CertificateProviderStore(CertificateProviderRegistry.getInstance()); + } + return instance; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/Closeable.java b/xds/src/main/java/io/grpc/xds/internal/sds/Closeable.java index a2dc8178c3..c3695cecaf 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/Closeable.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/Closeable.java @@ -16,7 +16,7 @@ package io.grpc.xds.internal.sds; -interface Closeable extends java.io.Closeable { +public interface Closeable extends java.io.Closeable { @Override public void close(); diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingMap.java b/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingMap.java index 45dee9cc50..9f520ea7bc 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingMap.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingMap.java @@ -36,12 +36,12 @@ import javax.annotation.concurrent.ThreadSafe; * @param Value type for the map - it should be a {@link Closeable} */ @ThreadSafe -final class ReferenceCountingMap { +public final class ReferenceCountingMap { private final Map> instances = new HashMap<>(); private final ValueFactory valueFactory; - ReferenceCountingMap(ValueFactory valueFactory) { + public ReferenceCountingMap(ValueFactory valueFactory) { checkNotNull(valueFactory, "valueFactory"); this.valueFactory = valueFactory; } diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java new file mode 100644 index 0000000000..d9099dc901 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java @@ -0,0 +1,331 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyListOf; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.grpc.Status; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.List; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +/** Unit tests for {@link CertificateProviderStore}. */ +@RunWith(JUnit4.class) +public class CertificateProviderStoreTest { + + private CertificateProviderRegistry certificateProviderRegistry; + private CertificateProviderStore certificateProviderStore; + private boolean throwExceptionForCertUpdates; + + private class TestCertificateProvider extends CertificateProvider { + Object config; + CertificateProviderProvider certProviderProvider; + int closeCalled = 0; + + protected TestCertificateProvider( + CertificateProvider.DistributorWatcher watcher, + boolean notifyCertUpdates, + Object config, + CertificateProviderProvider certificateProviderProvider) { + super(watcher, notifyCertUpdates); + if (throwExceptionForCertUpdates && notifyCertUpdates) { + throw new UnsupportedOperationException("Provider does not support Certificate Updates."); + } + this.config = config; + this.certProviderProvider = certificateProviderProvider; + } + + @Override + public void close() { + closeCalled++; + } + } + + @Before + public void setUp() { + certificateProviderRegistry = new CertificateProviderRegistry(); + certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry); + throwExceptionForCertUpdates = false; + } + + @Test + public void pluginNotRegistered_expectException() { + CertificateProvider.Watcher mockWatcher = mock(CertificateProvider.Watcher.class); + try { + CertificateProviderStore.Handle unused = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher, true); + fail("exception expected"); + } catch (IllegalArgumentException expected) { + assertThat(expected).hasMessageThat().isEqualTo("Provider not found."); + } + } + + @Test + public void pluginUnregistered_expectException() { + CertificateProviderProvider certificateProviderProvider = registerPlugin("plugin1"); + CertificateProvider.Watcher mockWatcher = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher, true); + handle.close(); + certificateProviderRegistry.deregister(certificateProviderProvider); + try { + handle = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher, true); + fail("exception expected"); + } catch (IllegalArgumentException expected) { + assertThat(expected).hasMessageThat().isEqualTo("Provider not found."); + } + } + + @Test + public void notifyCertUpdatesNotSupported_expectException() { + CertificateProviderProvider unused = registerPlugin("plugin1"); + throwExceptionForCertUpdates = true; + CertificateProvider.Watcher mockWatcher = mock(CertificateProvider.Watcher.class); + try { + CertificateProviderStore.Handle unused1 = + certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher, true); + fail("exception expected"); + } catch (UnsupportedOperationException expected) { + assertThat(expected) + .hasMessageThat() + .isEqualTo("Provider does not support Certificate Updates."); + } + } + + @Test + public void notifyCertUpdatesNotSupported_expectExceptionOnSecondCall() { + CertificateProviderProvider unused = registerPlugin("plugin1"); + throwExceptionForCertUpdates = true; + CertificateProvider.Watcher mockWatcher = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle1 = + certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher, false); + try { + CertificateProviderStore.Handle unused1 = + certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher, true); + fail("exception expected"); + } catch (UnsupportedOperationException expected) { + assertThat(expected) + .hasMessageThat() + .isEqualTo("Provider does not support Certificate Updates."); + } + handle1.close(); + } + + @Test + @SuppressWarnings("deprecation") + public void onePluginSameConfig_sameInstance() { + registerPlugin("plugin1"); + CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher1, true); + CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher2, true); + assertThat(handle1).isNotSameInstanceAs(handle2); + assertThat(handle1.certProvider).isSameInstanceAs(handle2.certProvider); + assertThat(handle1.certProvider).isInstanceOf(TestCertificateProvider.class); + TestCertificateProvider testCertificateProvider = + (TestCertificateProvider) handle1.certProvider; + CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher(); + assertThat(distWatcher.downsstreamWatchers.size()).isEqualTo(2); + PrivateKey testKey = mock(PrivateKey.class); + X509Certificate cert = mock(X509Certificate.class); + List testList = ImmutableList.of(cert); + testCertificateProvider.getWatcher().updateCertificate(testKey, testList); + verify(mockWatcher1, times(1)).updateCertificate(eq(testKey), eq(testList)); + verify(mockWatcher2, times(1)).updateCertificate(eq(testKey), eq(testList)); + reset(mockWatcher1); + reset(mockWatcher2); + testCertificateProvider.getWatcher().updateTrustedRoots(testList); + verify(mockWatcher1, times(1)).updateTrustedRoots(eq(testList)); + verify(mockWatcher2, times(1)).updateTrustedRoots(eq(testList)); + reset(mockWatcher1); + reset(mockWatcher2); + handle1.close(); + assertThat(testCertificateProvider.closeCalled).isEqualTo(0); + assertThat(distWatcher.downsstreamWatchers.size()).isEqualTo(1); + testCertificateProvider.getWatcher().updateCertificate(testKey, testList); + verify(mockWatcher1, never()) + .updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class)); + verify(mockWatcher2, times(1)).updateCertificate(eq(testKey), eq(testList)); + testCertificateProvider.getWatcher().updateTrustedRoots(testList); + verify(mockWatcher2, times(1)).updateTrustedRoots(eq(testList)); + handle2.close(); + assertThat(testCertificateProvider.closeCalled).isEqualTo(1); + } + + @Test + public void onePluginTwoInstances_notifyError() { + registerPlugin("plugin1"); + CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher1, true); + CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher2, true); + TestCertificateProvider testCertificateProvider = + (TestCertificateProvider) handle1.certProvider; + testCertificateProvider.getWatcher().onError(Status.CANCELLED); + verify(mockWatcher1, times(1)).onError(eq(Status.CANCELLED)); + verify(mockWatcher2, times(1)).onError(eq(Status.CANCELLED)); + handle1.close(); + handle2.close(); + } + + @Test + public void onePluginDifferentConfig_differentInstance() { + CertificateProviderProvider certProviderProvider = registerPlugin("plugin1"); + CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher1, true); + CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config2", mockWatcher2, true); + checkDifferentInstances( + mockWatcher1, handle1, certProviderProvider, mockWatcher2, handle2, certProviderProvider); + } + + @Test + public void onePluginDifferentCertName_differentInstance() { + CertificateProviderProvider certProviderProvider = registerPlugin("plugin1"); + CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher1, true); + CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider( + "cert-name2", "plugin1", "config", mockWatcher2, true); + checkDifferentInstances( + mockWatcher1, handle1, certProviderProvider, mockWatcher2, handle2, certProviderProvider); + } + + @Test + public void onePluginDifferentNotifyValue_sameInstance() { + CertificateProviderProvider unused = registerPlugin("plugin1"); + CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher1, true); + CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher2, false); + assertThat(handle1).isNotSameInstanceAs(handle2); + assertThat(handle1.certProvider).isSameInstanceAs(handle2.certProvider); + } + + @Test + public void twoPlugins_differentInstance() { + CertificateProviderProvider certProviderProvider1 = registerPlugin("plugin1"); + CertificateProviderProvider certProviderProvider2 = registerPlugin("plugin2"); + CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin1", "config", mockWatcher1, true); + CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class); + CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider( + "cert-name1", "plugin2", "config", mockWatcher2, true); + checkDifferentInstances( + mockWatcher1, handle1, certProviderProvider1, mockWatcher2, handle2, certProviderProvider2); + } + + @SuppressWarnings("deprecation") + private void checkDifferentInstances( + CertificateProvider.Watcher mockWatcher1, + CertificateProviderStore.Handle handle1, + CertificateProviderProvider certProviderProvider1, + CertificateProvider.Watcher mockWatcher2, + CertificateProviderStore.Handle handle2, + CertificateProviderProvider certProviderProvider2) { + assertThat(handle1.certProvider).isNotSameInstanceAs(handle2.certProvider); + TestCertificateProvider testCertificateProvider1 = + (TestCertificateProvider) handle1.certProvider; + TestCertificateProvider testCertificateProvider2 = + (TestCertificateProvider) handle2.certProvider; + assertThat(testCertificateProvider1.certProviderProvider) + .isSameInstanceAs(certProviderProvider1); + assertThat(testCertificateProvider2.certProviderProvider) + .isSameInstanceAs(certProviderProvider2); + CertificateProvider.DistributorWatcher distWatcher1 = testCertificateProvider1.getWatcher(); + assertThat(distWatcher1.downsstreamWatchers.size()).isEqualTo(1); + CertificateProvider.DistributorWatcher distWatcher2 = testCertificateProvider2.getWatcher(); + assertThat(distWatcher2.downsstreamWatchers.size()).isEqualTo(1); + PrivateKey testKey1 = mock(PrivateKey.class); + X509Certificate cert1 = mock(X509Certificate.class); + List testList1 = ImmutableList.of(cert1); + testCertificateProvider1.getWatcher().updateCertificate(testKey1, testList1); + verify(mockWatcher1, times(1)).updateCertificate(eq(testKey1), eq(testList1)); + verify(mockWatcher2, never()) + .updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class)); + reset(mockWatcher1); + + PrivateKey testKey2 = mock(PrivateKey.class); + X509Certificate cert2 = mock(X509Certificate.class); + List testList2 = ImmutableList.of(cert2); + testCertificateProvider2.getWatcher().updateCertificate(testKey2, testList2); + verify(mockWatcher2, times(1)).updateCertificate(eq(testKey2), eq(testList2)); + verify(mockWatcher1, never()) + .updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class)); + handle2.close(); + assertThat(testCertificateProvider2.closeCalled).isEqualTo(1); + handle1.close(); + assertThat(testCertificateProvider1.closeCalled).isEqualTo(1); + } + + private CertificateProviderProvider registerPlugin(String pluginName) { + final CertificateProviderProvider certProviderProvider = + mock(CertificateProviderProvider.class); + when(certProviderProvider.getName()).thenReturn(pluginName); + when(certProviderProvider.createCertificateProvider( + any(Object.class), + any(CertificateProvider.DistributorWatcher.class), + any(Boolean.TYPE))) + .then( + new Answer() { + + @Override + public CertificateProvider answer(InvocationOnMock invocation) throws Throwable { + Object[] args = invocation.getArguments(); + Object config = args[0]; + CertificateProvider.DistributorWatcher watcher = + (CertificateProvider.DistributorWatcher) args[1]; + boolean notifyCertUpdates = (Boolean) args[2]; + return new TestCertificateProvider( + watcher, notifyCertUpdates, config, certProviderProvider); + } + }); + certificateProviderRegistry.register(certProviderProvider); + return certProviderProvider; + } +}