xds: convert and rename ReferenceCountingSslContextProviderMap to generic ReferenceCountingMap (#7181)

This commit is contained in:
sanjaypujare 2020-07-06 18:08:25 -07:00 committed by GitHub
parent 784e6b62f4
commit 2dc670163f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 101 additions and 76 deletions

View File

@ -21,17 +21,17 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory;
import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import java.io.IOException;
import java.util.concurrent.Executors;
/** Factory to create client-side SslContextProvider from UpstreamTlsContext. */
final class ClientSslContextProviderFactory
implements SslContextProviderFactory<UpstreamTlsContext> {
implements ValueFactory<UpstreamTlsContext, SslContextProvider> {
/** Creates an SslContextProvider from the given UpstreamTlsContext. */
@Override
public SslContextProvider createSslContextProvider(UpstreamTlsContext upstreamTlsContext) {
public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
checkNotNull(
upstreamTlsContext.getCommonTlsContext(),

View File

@ -0,0 +1,23 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
interface Closeable extends java.io.Closeable {
@Override
public void close();
}

View File

@ -26,38 +26,38 @@ import javax.annotation.CheckReturnValue;
import javax.annotation.concurrent.ThreadSafe;
/**
* A map for managing {@link SslContextProvider}s as reference-counted shared resources.
* A map for managing reference-counted shared resources - typically providers.
*
* <p>A key (of generic type K) identifies a {@link SslContextProvider}. The map also depends on a
* factory {@link SslContextProviderFactory} to create a new instance of {@link SslContextProvider}
* as needed. {@link SslContextProvider}s are ref-counted and closed by calling {@link
* SslContextProvider#close()} when ref-count reaches zero.
* <p>A key (of generic type K) identifies a provider (of generic type V). The map also depends on a
* factory {@link ValueFactory} to create a new instance of V as needed. Values are ref-counted and
* closed by calling {@link Closeable#close()} when ref-count reaches zero.
*
* @param <K> Key type for the map
* @param <V> Value type for the map - it should be a {@link Closeable}
*/
@ThreadSafe
final class ReferenceCountingSslContextProviderMap<K> {
final class ReferenceCountingMap<K, V extends Closeable> {
private final Map<K, Instance> instances = new HashMap<>();
private final SslContextProviderFactory<K> sslContextProviderFactory;
private final Map<K, Instance<V>> instances = new HashMap<>();
private final ValueFactory<K, V> valueFactory;
ReferenceCountingSslContextProviderMap(SslContextProviderFactory<K> sslContextProviderFactory) {
checkNotNull(sslContextProviderFactory, "sslContextProviderFactory");
this.sslContextProviderFactory = sslContextProviderFactory;
ReferenceCountingMap(ValueFactory<K, V> valueFactory) {
checkNotNull(valueFactory, "valueFactory");
this.valueFactory = valueFactory;
}
/**
* Gets an existing instance of {@link SslContextProvider}. If it doesn't exist, creates a new one
* using the provided {@link SslContextProviderFactory&lt;K&gt;}
* Gets an existing instance of a provider. If it doesn't exist, creates a new one
* using the provided {@link ValueFactory &lt;K, V&gt;}
*/
@CheckReturnValue
public SslContextProvider get(K key) {
public V get(K key) {
checkNotNull(key, "key");
return getInternal(key);
}
/**
* Releases an instance of the given {@link SslContextProvider}.
* Releases an instance of the given value.
*
* <p>The instance must have been obtained from {@link #get(K)}. Otherwise will throw
* IllegalArgumentException.
@ -69,30 +69,30 @@ final class ReferenceCountingSslContextProviderMap<K> {
* @param value the instance to be released
* @return a null which the caller can use to clear the reference to that instance.
*/
public SslContextProvider release(K key, SslContextProvider value) {
public V release(K key, V value) {
checkNotNull(key, "key");
checkNotNull(value, "value");
return releaseInternal(key, value);
}
private synchronized SslContextProvider getInternal(K key) {
Instance instance = instances.get(key);
private synchronized V getInternal(K key) {
Instance<V> instance = instances.get(key);
if (instance == null) {
instance = new Instance(sslContextProviderFactory.createSslContextProvider(key));
instance = new Instance<>(valueFactory.create(key));
instances.put(key, instance);
return instance.sslContextProvider;
return instance.value;
} else {
return instance.acquire();
}
}
private synchronized SslContextProvider releaseInternal(K key, SslContextProvider instance) {
Instance cached = instances.get(key);
private synchronized V releaseInternal(K key, V value) {
Instance<V> cached = instances.get(key);
checkArgument(cached != null, "No cached instance found for %s", key);
checkArgument(instance == cached.sslContextProvider, "Releasing the wrong instance");
checkArgument(value == cached.value, "Releasing the wrong instance");
if (cached.release()) {
try {
cached.sslContextProvider.close();
cached.value.close();
} finally {
instances.remove(key);
}
@ -101,19 +101,19 @@ final class ReferenceCountingSslContextProviderMap<K> {
return null;
}
/** A factory to create an SslContextProvider from the given key. */
public interface SslContextProviderFactory<K> {
SslContextProvider createSslContextProvider(K key);
/** A factory to create a value from the given key. */
public interface ValueFactory<K, V extends Closeable> {
V create(K key);
}
private static class Instance {
final SslContextProvider sslContextProvider;
private static final class Instance<V extends Closeable> {
final V value;
private int refCount;
/** Increment refCount and acquire a reference to sslContextProvider. */
SslContextProvider acquire() {
/** Increment refCount and acquire a reference to value. */
V acquire() {
refCount++;
return sslContextProvider;
return value;
}
/** Decrement refCount and return true if it has reached 0. */
@ -122,8 +122,8 @@ final class ReferenceCountingSslContextProviderMap<K> {
return --refCount == 0;
}
Instance(SslContextProvider sslContextProvider) {
this.sslContextProvider = sslContextProvider;
Instance(V value) {
this.value = value;
this.refCount = 1;
}
}

View File

@ -203,7 +203,7 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl
}
@Override
void close() {
public void close() {
if (certSdsClient != null) {
certSdsClient.cancelSecretWatch(this);
certSdsClient.shutdown();

View File

@ -21,17 +21,17 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory;
import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import java.io.IOException;
import java.util.concurrent.Executors;
/** Factory to create server-side SslContextProvider from DownstreamTlsContext. */
final class ServerSslContextProviderFactory
implements SslContextProviderFactory<DownstreamTlsContext> {
implements ValueFactory<DownstreamTlsContext, SslContextProvider> {
/** Creates a SslContextProvider from the given DownstreamTlsContext. */
@Override
public SslContextProvider createSslContextProvider(
public SslContextProvider create(
DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
checkNotNull(

View File

@ -41,7 +41,7 @@ import java.util.logging.Logger;
* stream that is receiving the requested secret(s) or it could represent file-system based
* secret(s) that are dynamic.
*/
public abstract class SslContextProvider {
public abstract class SslContextProvider implements Closeable {
private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName());
@ -93,7 +93,8 @@ public abstract class SslContextProvider {
}
/** Closes this provider and releases any resources. */
void close() {}
@Override
public abstract void close();
/**
* Registers a callback on the given executor. The callback will run when SslContext becomes

View File

@ -21,20 +21,20 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory;
import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
/**
* Class to manage {@link SslContextProvider} objects created from inputs we get from xDS. Used by
* gRPC-xds to access the SslContext's and is not public API. This manager manages the life-cycle of
* {@link SslContextProvider} objects as shared resources via ref-counting as described in {@link
* ReferenceCountingSslContextProviderMap}.
* ReferenceCountingMap}.
*/
public final class TlsContextManagerImpl implements TlsContextManager {
private static TlsContextManagerImpl instance;
private final ReferenceCountingSslContextProviderMap<UpstreamTlsContext> mapForClients;
private final ReferenceCountingSslContextProviderMap<DownstreamTlsContext> mapForServers;
private final ReferenceCountingMap<UpstreamTlsContext, SslContextProvider> mapForClients;
private final ReferenceCountingMap<DownstreamTlsContext, SslContextProvider> mapForServers;
private TlsContextManagerImpl() {
this(new ClientSslContextProviderFactory(), new ServerSslContextProviderFactory());
@ -42,12 +42,12 @@ public final class TlsContextManagerImpl implements TlsContextManager {
@VisibleForTesting
TlsContextManagerImpl(
SslContextProviderFactory<UpstreamTlsContext> clientFactory,
SslContextProviderFactory<DownstreamTlsContext> serverFactory) {
ValueFactory<UpstreamTlsContext, SslContextProvider> clientFactory,
ValueFactory<DownstreamTlsContext, SslContextProvider> serverFactory) {
checkNotNull(clientFactory, "clientFactory");
checkNotNull(serverFactory, "serverFactory");
mapForClients = new ReferenceCountingSslContextProviderMap<>(clientFactory);
mapForServers = new ReferenceCountingSslContextProviderMap<>(serverFactory);
mapForClients = new ReferenceCountingMap<>(clientFactory);
mapForServers = new ReferenceCountingMap<>(serverFactory);
}
/** Gets the TlsContextManagerImpl singleton. */

View File

@ -42,7 +42,7 @@ public class ClientSslContextProviderFactoryTest {
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext);
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isNotNull();
}
@ -56,7 +56,7 @@ public class ClientSslContextProviderFactoryTest {
try {
SslContextProvider unused =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext);
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
assertThat(expected)
@ -78,7 +78,7 @@ public class ClientSslContextProviderFactoryTest {
try {
SslContextProvider unused =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext);
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
assertThat(expected)

View File

@ -24,7 +24,7 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory;
import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@ -34,25 +34,26 @@ import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
/** Unit tests for {@link ReferenceCountingSslContextProviderMap}. */
/** Unit tests for {@link ReferenceCountingMap}. */
@RunWith(JUnit4.class)
public class ReferenceCountingSslContextProviderMapTest {
public class ReferenceCountingMapTest {
@Rule public final MockitoRule mockitoRule = MockitoJUnit.rule();
@Mock SslContextProviderFactory<Integer> mockFactory;
@Mock
ValueFactory<Integer, SslContextProvider> mockFactory;
ReferenceCountingSslContextProviderMap<Integer> map;
ReferenceCountingMap<Integer, SslContextProvider> map;
@Before
public void setUp() {
map = new ReferenceCountingSslContextProviderMap<>(mockFactory);
map = new ReferenceCountingMap<>(mockFactory);
}
@Test
public void referenceCountingMap_getAndRelease_closeCalled() throws InterruptedException {
SslContextProvider valueFor3 = getTypedMock();
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3);
when(mockFactory.create(3)).thenReturn(valueFor3);
SslContextProvider val = map.get(3);
assertThat(val).isSameInstanceAs(valueFor3);
verify(valueFor3, never()).close();
@ -73,8 +74,8 @@ public class ReferenceCountingSslContextProviderMapTest {
public void referenceCountingMap_distinctElements() throws InterruptedException {
SslContextProvider valueFor3 = getTypedMock();
SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3);
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
when(mockFactory.create(3)).thenReturn(valueFor3);
when(mockFactory.create(4)).thenReturn(valueFor4);
SslContextProvider val3 = map.get(3);
assertThat(val3).isSameInstanceAs(valueFor3);
SslContextProvider val4 = map.get(4);
@ -91,8 +92,8 @@ public class ReferenceCountingSslContextProviderMapTest {
throws InterruptedException {
SslContextProvider valueFor3 = getTypedMock();
SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3);
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
when(mockFactory.create(3)).thenReturn(valueFor3);
when(mockFactory.create(4)).thenReturn(valueFor4);
SslContextProvider unused = map.get(3);
SslContextProvider val4 = map.get(4);
// now provide wrong key (3) and value (val4) combination
@ -107,7 +108,7 @@ public class ReferenceCountingSslContextProviderMapTest {
@Test
public void referenceCountingMap_excessRelease_expectException() throws InterruptedException {
SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
when(mockFactory.create(4)).thenReturn(valueFor4);
SslContextProvider val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4);
// at this point ref-count is 1
@ -124,7 +125,7 @@ public class ReferenceCountingSslContextProviderMapTest {
@Test
public void referenceCountingMap_releaseAndGet_differentInstance() throws InterruptedException {
SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4);
when(mockFactory.create(4)).thenReturn(valueFor4);
SslContextProvider val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4);
// at this point ref-count is 1
@ -132,7 +133,7 @@ public class ReferenceCountingSslContextProviderMapTest {
// at this point ref-count is 0 and val is removed
// should get another instance for 4
SslContextProvider valueFor4a = getTypedMock();
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4a);
when(mockFactory.create(4)).thenReturn(valueFor4a);
val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4a);
// verify it is a different instance from before

View File

@ -42,7 +42,7 @@ public class ServerSslContextProviderFactoryTest {
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext);
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isNotNull();
}
@ -57,7 +57,7 @@ public class ServerSslContextProviderFactoryTest {
try {
SslContextProvider unused =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext);
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
assertThat(expected)
@ -77,7 +77,7 @@ public class ServerSslContextProviderFactoryTest {
try {
SslContextProvider unused =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext);
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
assertThat(expected)

View File

@ -32,7 +32,7 @@ import static org.mockito.Mockito.when;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory;
import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import java.lang.reflect.Field;
import org.junit.Before;
import org.junit.Rule;
@ -49,9 +49,9 @@ public class TlsContextManagerTest {
@Rule public final MockitoRule mockitoRule = MockitoJUnit.rule();
@Mock SslContextProviderFactory<UpstreamTlsContext> mockClientFactory;
@Mock ValueFactory<UpstreamTlsContext, SslContextProvider> mockClientFactory;
@Mock SslContextProviderFactory<DownstreamTlsContext> mockServerFactory;
@Mock ValueFactory<DownstreamTlsContext, SslContextProvider> mockServerFactory;
@Before
public void clearInstance() throws NoSuchFieldException, IllegalAccessException {
@ -141,7 +141,7 @@ public class TlsContextManagerTest {
TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory);
SslContextProvider mockProvider = mock(SslContextProvider.class);
when(mockServerFactory.createSslContextProvider(downstreamTlsContext)).thenReturn(mockProvider);
when(mockServerFactory.create(downstreamTlsContext)).thenReturn(mockProvider);
SslContextProvider serverSecretProvider =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider).isSameInstanceAs(mockProvider);
@ -160,7 +160,7 @@ public class TlsContextManagerTest {
TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory);
SslContextProvider mockProvider = mock(SslContextProvider.class);
when(mockClientFactory.createSslContextProvider(upstreamTlsContext)).thenReturn(mockProvider);
when(mockClientFactory.create(upstreamTlsContext)).thenReturn(mockProvider);
SslContextProvider clientSecretProvider =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider).isSameInstanceAs(mockProvider);