diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java index 7267511898..bb6a636e96 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java @@ -21,17 +21,17 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.util.concurrent.ThreadFactoryBuilder; import io.grpc.xds.Bootstrapper; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; +import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import java.io.IOException; import java.util.concurrent.Executors; /** Factory to create client-side SslContextProvider from UpstreamTlsContext. */ final class ClientSslContextProviderFactory - implements SslContextProviderFactory { + implements ValueFactory { /** Creates an SslContextProvider from the given UpstreamTlsContext. */ @Override - public SslContextProvider createSslContextProvider(UpstreamTlsContext upstreamTlsContext) { + public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); checkNotNull( upstreamTlsContext.getCommonTlsContext(), diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/Closeable.java b/xds/src/main/java/io/grpc/xds/internal/sds/Closeable.java new file mode 100644 index 0000000000..a2dc8178c3 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/sds/Closeable.java @@ -0,0 +1,23 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.sds; + +interface Closeable extends java.io.Closeable { + + @Override + public void close(); +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMap.java b/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingMap.java similarity index 51% rename from xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMap.java rename to xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingMap.java index 49921f9c8f..45dee9cc50 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMap.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingMap.java @@ -26,38 +26,38 @@ import javax.annotation.CheckReturnValue; import javax.annotation.concurrent.ThreadSafe; /** - * A map for managing {@link SslContextProvider}s as reference-counted shared resources. + * A map for managing reference-counted shared resources - typically providers. * - *

A key (of generic type K) identifies a {@link SslContextProvider}. The map also depends on a - * factory {@link SslContextProviderFactory} to create a new instance of {@link SslContextProvider} - * as needed. {@link SslContextProvider}s are ref-counted and closed by calling {@link - * SslContextProvider#close()} when ref-count reaches zero. + *

A key (of generic type K) identifies a provider (of generic type V). The map also depends on a + * factory {@link ValueFactory} to create a new instance of V as needed. Values are ref-counted and + * closed by calling {@link Closeable#close()} when ref-count reaches zero. * * @param Key type for the map + * @param Value type for the map - it should be a {@link Closeable} */ @ThreadSafe -final class ReferenceCountingSslContextProviderMap { +final class ReferenceCountingMap { - private final Map instances = new HashMap<>(); - private final SslContextProviderFactory sslContextProviderFactory; + private final Map> instances = new HashMap<>(); + private final ValueFactory valueFactory; - ReferenceCountingSslContextProviderMap(SslContextProviderFactory sslContextProviderFactory) { - checkNotNull(sslContextProviderFactory, "sslContextProviderFactory"); - this.sslContextProviderFactory = sslContextProviderFactory; + ReferenceCountingMap(ValueFactory valueFactory) { + checkNotNull(valueFactory, "valueFactory"); + this.valueFactory = valueFactory; } /** - * Gets an existing instance of {@link SslContextProvider}. If it doesn't exist, creates a new one - * using the provided {@link SslContextProviderFactory<K>} + * Gets an existing instance of a provider. If it doesn't exist, creates a new one + * using the provided {@link ValueFactory <K, V>} */ @CheckReturnValue - public SslContextProvider get(K key) { + public V get(K key) { checkNotNull(key, "key"); return getInternal(key); } /** - * Releases an instance of the given {@link SslContextProvider}. + * Releases an instance of the given value. * *

The instance must have been obtained from {@link #get(K)}. Otherwise will throw * IllegalArgumentException. @@ -69,30 +69,30 @@ final class ReferenceCountingSslContextProviderMap { * @param value the instance to be released * @return a null which the caller can use to clear the reference to that instance. */ - public SslContextProvider release(K key, SslContextProvider value) { + public V release(K key, V value) { checkNotNull(key, "key"); checkNotNull(value, "value"); return releaseInternal(key, value); } - private synchronized SslContextProvider getInternal(K key) { - Instance instance = instances.get(key); + private synchronized V getInternal(K key) { + Instance instance = instances.get(key); if (instance == null) { - instance = new Instance(sslContextProviderFactory.createSslContextProvider(key)); + instance = new Instance<>(valueFactory.create(key)); instances.put(key, instance); - return instance.sslContextProvider; + return instance.value; } else { return instance.acquire(); } } - private synchronized SslContextProvider releaseInternal(K key, SslContextProvider instance) { - Instance cached = instances.get(key); + private synchronized V releaseInternal(K key, V value) { + Instance cached = instances.get(key); checkArgument(cached != null, "No cached instance found for %s", key); - checkArgument(instance == cached.sslContextProvider, "Releasing the wrong instance"); + checkArgument(value == cached.value, "Releasing the wrong instance"); if (cached.release()) { try { - cached.sslContextProvider.close(); + cached.value.close(); } finally { instances.remove(key); } @@ -101,19 +101,19 @@ final class ReferenceCountingSslContextProviderMap { return null; } - /** A factory to create an SslContextProvider from the given key. */ - public interface SslContextProviderFactory { - SslContextProvider createSslContextProvider(K key); + /** A factory to create a value from the given key. */ + public interface ValueFactory { + V create(K key); } - private static class Instance { - final SslContextProvider sslContextProvider; + private static final class Instance { + final V value; private int refCount; - /** Increment refCount and acquire a reference to sslContextProvider. */ - SslContextProvider acquire() { + /** Increment refCount and acquire a reference to value. */ + V acquire() { refCount++; - return sslContextProvider; + return value; } /** Decrement refCount and return true if it has reached 0. */ @@ -122,8 +122,8 @@ final class ReferenceCountingSslContextProviderMap { return --refCount == 0; } - Instance(SslContextProvider sslContextProvider) { - this.sslContextProvider = sslContextProvider; + Instance(V value) { + this.value = value; this.refCount = 1; } } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java index 6eb75b066a..26ff694cf5 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java @@ -203,7 +203,7 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl } @Override - void close() { + public void close() { if (certSdsClient != null) { certSdsClient.cancelSecretWatch(this); certSdsClient.shutdown(); diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java index ec713afdb6..fb271a32dc 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java @@ -21,17 +21,17 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.util.concurrent.ThreadFactoryBuilder; import io.grpc.xds.Bootstrapper; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; -import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; +import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import java.io.IOException; import java.util.concurrent.Executors; /** Factory to create server-side SslContextProvider from DownstreamTlsContext. */ final class ServerSslContextProviderFactory - implements SslContextProviderFactory { + implements ValueFactory { /** Creates a SslContextProvider from the given DownstreamTlsContext. */ @Override - public SslContextProvider createSslContextProvider( + public SslContextProvider create( DownstreamTlsContext downstreamTlsContext) { checkNotNull(downstreamTlsContext, "downstreamTlsContext"); checkNotNull( diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java index 9c319dde87..6c73b7c737 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java @@ -41,7 +41,7 @@ import java.util.logging.Logger; * stream that is receiving the requested secret(s) or it could represent file-system based * secret(s) that are dynamic. */ -public abstract class SslContextProvider { +public abstract class SslContextProvider implements Closeable { private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName()); @@ -93,7 +93,8 @@ public abstract class SslContextProvider { } /** Closes this provider and releases any resources. */ - void close() {} + @Override + public abstract void close(); /** * Registers a callback on the given executor. The callback will run when SslContext becomes diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java b/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java index 4a098b1aec..a870e9d36f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java @@ -21,20 +21,20 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; +import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; /** * Class to manage {@link SslContextProvider} objects created from inputs we get from xDS. Used by * gRPC-xds to access the SslContext's and is not public API. This manager manages the life-cycle of * {@link SslContextProvider} objects as shared resources via ref-counting as described in {@link - * ReferenceCountingSslContextProviderMap}. + * ReferenceCountingMap}. */ public final class TlsContextManagerImpl implements TlsContextManager { private static TlsContextManagerImpl instance; - private final ReferenceCountingSslContextProviderMap mapForClients; - private final ReferenceCountingSslContextProviderMap mapForServers; + private final ReferenceCountingMap mapForClients; + private final ReferenceCountingMap mapForServers; private TlsContextManagerImpl() { this(new ClientSslContextProviderFactory(), new ServerSslContextProviderFactory()); @@ -42,12 +42,12 @@ public final class TlsContextManagerImpl implements TlsContextManager { @VisibleForTesting TlsContextManagerImpl( - SslContextProviderFactory clientFactory, - SslContextProviderFactory serverFactory) { + ValueFactory clientFactory, + ValueFactory serverFactory) { checkNotNull(clientFactory, "clientFactory"); checkNotNull(serverFactory, "serverFactory"); - mapForClients = new ReferenceCountingSslContextProviderMap<>(clientFactory); - mapForServers = new ReferenceCountingSslContextProviderMap<>(serverFactory); + mapForClients = new ReferenceCountingMap<>(clientFactory); + mapForServers = new ReferenceCountingMap<>(serverFactory); } /** Gets the TlsContextManagerImpl singleton. */ diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java index 5a481bebe9..8fe011f0ee 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java @@ -42,7 +42,7 @@ public class ClientSslContextProviderFactoryTest { CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); + clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider).isNotNull(); } @@ -56,7 +56,7 @@ public class ClientSslContextProviderFactoryTest { try { SslContextProvider unused = - clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); + clientSslContextProviderFactory.create(upstreamTlsContext); Assert.fail("no exception thrown"); } catch (UnsupportedOperationException expected) { assertThat(expected) @@ -78,7 +78,7 @@ public class ClientSslContextProviderFactoryTest { try { SslContextProvider unused = - clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); + clientSslContextProviderFactory.create(upstreamTlsContext); Assert.fail("no exception thrown"); } catch (UnsupportedOperationException expected) { assertThat(expected) diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMapTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingMapTest.java similarity index 81% rename from xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMapTest.java rename to xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingMapTest.java index e05e2ee50d..b94aefd215 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingSslContextProviderMapTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingMapTest.java @@ -24,7 +24,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; +import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -34,25 +34,26 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; -/** Unit tests for {@link ReferenceCountingSslContextProviderMap}. */ +/** Unit tests for {@link ReferenceCountingMap}. */ @RunWith(JUnit4.class) -public class ReferenceCountingSslContextProviderMapTest { +public class ReferenceCountingMapTest { @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); - @Mock SslContextProviderFactory mockFactory; + @Mock + ValueFactory mockFactory; - ReferenceCountingSslContextProviderMap map; + ReferenceCountingMap map; @Before public void setUp() { - map = new ReferenceCountingSslContextProviderMap<>(mockFactory); + map = new ReferenceCountingMap<>(mockFactory); } @Test public void referenceCountingMap_getAndRelease_closeCalled() throws InterruptedException { SslContextProvider valueFor3 = getTypedMock(); - when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); + when(mockFactory.create(3)).thenReturn(valueFor3); SslContextProvider val = map.get(3); assertThat(val).isSameInstanceAs(valueFor3); verify(valueFor3, never()).close(); @@ -73,8 +74,8 @@ public class ReferenceCountingSslContextProviderMapTest { public void referenceCountingMap_distinctElements() throws InterruptedException { SslContextProvider valueFor3 = getTypedMock(); SslContextProvider valueFor4 = getTypedMock(); - when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); - when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); + when(mockFactory.create(3)).thenReturn(valueFor3); + when(mockFactory.create(4)).thenReturn(valueFor4); SslContextProvider val3 = map.get(3); assertThat(val3).isSameInstanceAs(valueFor3); SslContextProvider val4 = map.get(4); @@ -91,8 +92,8 @@ public class ReferenceCountingSslContextProviderMapTest { throws InterruptedException { SslContextProvider valueFor3 = getTypedMock(); SslContextProvider valueFor4 = getTypedMock(); - when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); - when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); + when(mockFactory.create(3)).thenReturn(valueFor3); + when(mockFactory.create(4)).thenReturn(valueFor4); SslContextProvider unused = map.get(3); SslContextProvider val4 = map.get(4); // now provide wrong key (3) and value (val4) combination @@ -107,7 +108,7 @@ public class ReferenceCountingSslContextProviderMapTest { @Test public void referenceCountingMap_excessRelease_expectException() throws InterruptedException { SslContextProvider valueFor4 = getTypedMock(); - when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); + when(mockFactory.create(4)).thenReturn(valueFor4); SslContextProvider val = map.get(4); assertThat(val).isSameInstanceAs(valueFor4); // at this point ref-count is 1 @@ -124,7 +125,7 @@ public class ReferenceCountingSslContextProviderMapTest { @Test public void referenceCountingMap_releaseAndGet_differentInstance() throws InterruptedException { SslContextProvider valueFor4 = getTypedMock(); - when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); + when(mockFactory.create(4)).thenReturn(valueFor4); SslContextProvider val = map.get(4); assertThat(val).isSameInstanceAs(valueFor4); // at this point ref-count is 1 @@ -132,7 +133,7 @@ public class ReferenceCountingSslContextProviderMapTest { // at this point ref-count is 0 and val is removed // should get another instance for 4 SslContextProvider valueFor4a = getTypedMock(); - when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4a); + when(mockFactory.create(4)).thenReturn(valueFor4a); val = map.get(4); assertThat(val).isSameInstanceAs(valueFor4a); // verify it is a different instance from before diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java index 46acd6e033..1c2b58f7b2 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java @@ -42,7 +42,7 @@ public class ServerSslContextProviderFactoryTest { SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); SslContextProvider sslContextProvider = - serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); + serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider).isNotNull(); } @@ -57,7 +57,7 @@ public class ServerSslContextProviderFactoryTest { try { SslContextProvider unused = - serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); + serverSslContextProviderFactory.create(downstreamTlsContext); Assert.fail("no exception thrown"); } catch (UnsupportedOperationException expected) { assertThat(expected) @@ -77,7 +77,7 @@ public class ServerSslContextProviderFactoryTest { try { SslContextProvider unused = - serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); + serverSslContextProviderFactory.create(downstreamTlsContext); Assert.fail("no exception thrown"); } catch (UnsupportedOperationException expected) { assertThat(expected) diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java index d9ba5afb12..d69892b7c5 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java @@ -32,7 +32,7 @@ import static org.mockito.Mockito.when; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; +import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import java.lang.reflect.Field; import org.junit.Before; import org.junit.Rule; @@ -49,9 +49,9 @@ public class TlsContextManagerTest { @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); - @Mock SslContextProviderFactory mockClientFactory; + @Mock ValueFactory mockClientFactory; - @Mock SslContextProviderFactory mockServerFactory; + @Mock ValueFactory mockServerFactory; @Before public void clearInstance() throws NoSuchFieldException, IllegalAccessException { @@ -141,7 +141,7 @@ public class TlsContextManagerTest { TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(mockClientFactory, mockServerFactory); SslContextProvider mockProvider = mock(SslContextProvider.class); - when(mockServerFactory.createSslContextProvider(downstreamTlsContext)).thenReturn(mockProvider); + when(mockServerFactory.create(downstreamTlsContext)).thenReturn(mockProvider); SslContextProvider serverSecretProvider = tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); assertThat(serverSecretProvider).isSameInstanceAs(mockProvider); @@ -160,7 +160,7 @@ public class TlsContextManagerTest { TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(mockClientFactory, mockServerFactory); SslContextProvider mockProvider = mock(SslContextProvider.class); - when(mockClientFactory.createSslContextProvider(upstreamTlsContext)).thenReturn(mockProvider); + when(mockClientFactory.create(upstreamTlsContext)).thenReturn(mockProvider); SslContextProvider clientSecretProvider = tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); assertThat(clientSecretProvider).isSameInstanceAs(mockProvider);