xds: convert and rename ReferenceCountingSslContextProviderMap to generic ReferenceCountingMap (#7181)

This commit is contained in:
sanjaypujare 2020-07-06 18:08:25 -07:00 committed by GitHub
parent 784e6b62f4
commit 2dc670163f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 101 additions and 76 deletions

View File

@ -21,17 +21,17 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.xds.Bootstrapper; import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
/** Factory to create client-side SslContextProvider from UpstreamTlsContext. */ /** Factory to create client-side SslContextProvider from UpstreamTlsContext. */
final class ClientSslContextProviderFactory final class ClientSslContextProviderFactory
implements SslContextProviderFactory<UpstreamTlsContext> { implements ValueFactory<UpstreamTlsContext, SslContextProvider> {
/** Creates an SslContextProvider from the given UpstreamTlsContext. */ /** Creates an SslContextProvider from the given UpstreamTlsContext. */
@Override @Override
public SslContextProvider createSslContextProvider(UpstreamTlsContext upstreamTlsContext) { public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext"); checkNotNull(upstreamTlsContext, "upstreamTlsContext");
checkNotNull( checkNotNull(
upstreamTlsContext.getCommonTlsContext(), upstreamTlsContext.getCommonTlsContext(),

View File

@ -0,0 +1,23 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
interface Closeable extends java.io.Closeable {
@Override
public void close();
}

View File

@ -26,38 +26,38 @@ import javax.annotation.CheckReturnValue;
import javax.annotation.concurrent.ThreadSafe; import javax.annotation.concurrent.ThreadSafe;
/** /**
* A map for managing {@link SslContextProvider}s as reference-counted shared resources. * A map for managing reference-counted shared resources - typically providers.
* *
* <p>A key (of generic type K) identifies a {@link SslContextProvider}. The map also depends on a * <p>A key (of generic type K) identifies a provider (of generic type V). The map also depends on a
* factory {@link SslContextProviderFactory} to create a new instance of {@link SslContextProvider} * factory {@link ValueFactory} to create a new instance of V as needed. Values are ref-counted and
* as needed. {@link SslContextProvider}s are ref-counted and closed by calling {@link * closed by calling {@link Closeable#close()} when ref-count reaches zero.
* SslContextProvider#close()} when ref-count reaches zero.
* *
* @param <K> Key type for the map * @param <K> Key type for the map
* @param <V> Value type for the map - it should be a {@link Closeable}
*/ */
@ThreadSafe @ThreadSafe
final class ReferenceCountingSslContextProviderMap<K> { final class ReferenceCountingMap<K, V extends Closeable> {
private final Map<K, Instance> instances = new HashMap<>(); private final Map<K, Instance<V>> instances = new HashMap<>();
private final SslContextProviderFactory<K> sslContextProviderFactory; private final ValueFactory<K, V> valueFactory;
ReferenceCountingSslContextProviderMap(SslContextProviderFactory<K> sslContextProviderFactory) { ReferenceCountingMap(ValueFactory<K, V> valueFactory) {
checkNotNull(sslContextProviderFactory, "sslContextProviderFactory"); checkNotNull(valueFactory, "valueFactory");
this.sslContextProviderFactory = sslContextProviderFactory; this.valueFactory = valueFactory;
} }
/** /**
* Gets an existing instance of {@link SslContextProvider}. If it doesn't exist, creates a new one * Gets an existing instance of a provider. If it doesn't exist, creates a new one
* using the provided {@link SslContextProviderFactory&lt;K&gt;} * using the provided {@link ValueFactory &lt;K, V&gt;}
*/ */
@CheckReturnValue @CheckReturnValue
public SslContextProvider get(K key) { public V get(K key) {
checkNotNull(key, "key"); checkNotNull(key, "key");
return getInternal(key); return getInternal(key);
} }
/** /**
* Releases an instance of the given {@link SslContextProvider}. * Releases an instance of the given value.
* *
* <p>The instance must have been obtained from {@link #get(K)}. Otherwise will throw * <p>The instance must have been obtained from {@link #get(K)}. Otherwise will throw
* IllegalArgumentException. * IllegalArgumentException.
@ -69,30 +69,30 @@ final class ReferenceCountingSslContextProviderMap<K> {
* @param value the instance to be released * @param value the instance to be released
* @return a null which the caller can use to clear the reference to that instance. * @return a null which the caller can use to clear the reference to that instance.
*/ */
public SslContextProvider release(K key, SslContextProvider value) { public V release(K key, V value) {
checkNotNull(key, "key"); checkNotNull(key, "key");
checkNotNull(value, "value"); checkNotNull(value, "value");
return releaseInternal(key, value); return releaseInternal(key, value);
} }
private synchronized SslContextProvider getInternal(K key) { private synchronized V getInternal(K key) {
Instance instance = instances.get(key); Instance<V> instance = instances.get(key);
if (instance == null) { if (instance == null) {
instance = new Instance(sslContextProviderFactory.createSslContextProvider(key)); instance = new Instance<>(valueFactory.create(key));
instances.put(key, instance); instances.put(key, instance);
return instance.sslContextProvider; return instance.value;
} else { } else {
return instance.acquire(); return instance.acquire();
} }
} }
private synchronized SslContextProvider releaseInternal(K key, SslContextProvider instance) { private synchronized V releaseInternal(K key, V value) {
Instance cached = instances.get(key); Instance<V> cached = instances.get(key);
checkArgument(cached != null, "No cached instance found for %s", key); checkArgument(cached != null, "No cached instance found for %s", key);
checkArgument(instance == cached.sslContextProvider, "Releasing the wrong instance"); checkArgument(value == cached.value, "Releasing the wrong instance");
if (cached.release()) { if (cached.release()) {
try { try {
cached.sslContextProvider.close(); cached.value.close();
} finally { } finally {
instances.remove(key); instances.remove(key);
} }
@ -101,19 +101,19 @@ final class ReferenceCountingSslContextProviderMap<K> {
return null; return null;
} }
/** A factory to create an SslContextProvider from the given key. */ /** A factory to create a value from the given key. */
public interface SslContextProviderFactory<K> { public interface ValueFactory<K, V extends Closeable> {
SslContextProvider createSslContextProvider(K key); V create(K key);
} }
private static class Instance { private static final class Instance<V extends Closeable> {
final SslContextProvider sslContextProvider; final V value;
private int refCount; private int refCount;
/** Increment refCount and acquire a reference to sslContextProvider. */ /** Increment refCount and acquire a reference to value. */
SslContextProvider acquire() { V acquire() {
refCount++; refCount++;
return sslContextProvider; return value;
} }
/** Decrement refCount and return true if it has reached 0. */ /** Decrement refCount and return true if it has reached 0. */
@ -122,8 +122,8 @@ final class ReferenceCountingSslContextProviderMap<K> {
return --refCount == 0; return --refCount == 0;
} }
Instance(SslContextProvider sslContextProvider) { Instance(V value) {
this.sslContextProvider = sslContextProvider; this.value = value;
this.refCount = 1; this.refCount = 1;
} }
} }

View File

@ -203,7 +203,7 @@ abstract class SdsSslContextProvider extends SslContextProvider implements SdsCl
} }
@Override @Override
void close() { public void close() {
if (certSdsClient != null) { if (certSdsClient != null) {
certSdsClient.cancelSecretWatch(this); certSdsClient.cancelSecretWatch(this);
certSdsClient.shutdown(); certSdsClient.shutdown();

View File

@ -21,17 +21,17 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.xds.Bootstrapper; import io.grpc.xds.Bootstrapper;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
/** Factory to create server-side SslContextProvider from DownstreamTlsContext. */ /** Factory to create server-side SslContextProvider from DownstreamTlsContext. */
final class ServerSslContextProviderFactory final class ServerSslContextProviderFactory
implements SslContextProviderFactory<DownstreamTlsContext> { implements ValueFactory<DownstreamTlsContext, SslContextProvider> {
/** Creates a SslContextProvider from the given DownstreamTlsContext. */ /** Creates a SslContextProvider from the given DownstreamTlsContext. */
@Override @Override
public SslContextProvider createSslContextProvider( public SslContextProvider create(
DownstreamTlsContext downstreamTlsContext) { DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext"); checkNotNull(downstreamTlsContext, "downstreamTlsContext");
checkNotNull( checkNotNull(

View File

@ -41,7 +41,7 @@ import java.util.logging.Logger;
* stream that is receiving the requested secret(s) or it could represent file-system based * stream that is receiving the requested secret(s) or it could represent file-system based
* secret(s) that are dynamic. * secret(s) that are dynamic.
*/ */
public abstract class SslContextProvider { public abstract class SslContextProvider implements Closeable {
private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName()); private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName());
@ -93,7 +93,8 @@ public abstract class SslContextProvider {
} }
/** Closes this provider and releases any resources. */ /** Closes this provider and releases any resources. */
void close() {} @Override
public abstract void close();
/** /**
* Registers a callback on the given executor. The callback will run when SslContext becomes * Registers a callback on the given executor. The callback will run when SslContext becomes

View File

@ -21,20 +21,20 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
/** /**
* Class to manage {@link SslContextProvider} objects created from inputs we get from xDS. Used by * Class to manage {@link SslContextProvider} objects created from inputs we get from xDS. Used by
* gRPC-xds to access the SslContext's and is not public API. This manager manages the life-cycle of * gRPC-xds to access the SslContext's and is not public API. This manager manages the life-cycle of
* {@link SslContextProvider} objects as shared resources via ref-counting as described in {@link * {@link SslContextProvider} objects as shared resources via ref-counting as described in {@link
* ReferenceCountingSslContextProviderMap}. * ReferenceCountingMap}.
*/ */
public final class TlsContextManagerImpl implements TlsContextManager { public final class TlsContextManagerImpl implements TlsContextManager {
private static TlsContextManagerImpl instance; private static TlsContextManagerImpl instance;
private final ReferenceCountingSslContextProviderMap<UpstreamTlsContext> mapForClients; private final ReferenceCountingMap<UpstreamTlsContext, SslContextProvider> mapForClients;
private final ReferenceCountingSslContextProviderMap<DownstreamTlsContext> mapForServers; private final ReferenceCountingMap<DownstreamTlsContext, SslContextProvider> mapForServers;
private TlsContextManagerImpl() { private TlsContextManagerImpl() {
this(new ClientSslContextProviderFactory(), new ServerSslContextProviderFactory()); this(new ClientSslContextProviderFactory(), new ServerSslContextProviderFactory());
@ -42,12 +42,12 @@ public final class TlsContextManagerImpl implements TlsContextManager {
@VisibleForTesting @VisibleForTesting
TlsContextManagerImpl( TlsContextManagerImpl(
SslContextProviderFactory<UpstreamTlsContext> clientFactory, ValueFactory<UpstreamTlsContext, SslContextProvider> clientFactory,
SslContextProviderFactory<DownstreamTlsContext> serverFactory) { ValueFactory<DownstreamTlsContext, SslContextProvider> serverFactory) {
checkNotNull(clientFactory, "clientFactory"); checkNotNull(clientFactory, "clientFactory");
checkNotNull(serverFactory, "serverFactory"); checkNotNull(serverFactory, "serverFactory");
mapForClients = new ReferenceCountingSslContextProviderMap<>(clientFactory); mapForClients = new ReferenceCountingMap<>(clientFactory);
mapForServers = new ReferenceCountingSslContextProviderMap<>(serverFactory); mapForServers = new ReferenceCountingMap<>(serverFactory);
} }
/** Gets the TlsContextManagerImpl singleton. */ /** Gets the TlsContextManagerImpl singleton. */

View File

@ -42,7 +42,7 @@ public class ClientSslContextProviderFactoryTest {
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider sslContextProvider = SslContextProvider sslContextProvider =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isNotNull(); assertThat(sslContextProvider).isNotNull();
} }
@ -56,7 +56,7 @@ public class ClientSslContextProviderFactoryTest {
try { try {
SslContextProvider unused = SslContextProvider unused =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (UnsupportedOperationException expected) {
assertThat(expected) assertThat(expected)
@ -78,7 +78,7 @@ public class ClientSslContextProviderFactoryTest {
try { try {
SslContextProvider unused = SslContextProvider unused =
clientSslContextProviderFactory.createSslContextProvider(upstreamTlsContext); clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (UnsupportedOperationException expected) {
assertThat(expected) assertThat(expected)

View File

@ -24,7 +24,7 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
@ -34,25 +34,26 @@ import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule; import org.mockito.junit.MockitoRule;
/** Unit tests for {@link ReferenceCountingSslContextProviderMap}. */ /** Unit tests for {@link ReferenceCountingMap}. */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class ReferenceCountingSslContextProviderMapTest { public class ReferenceCountingMapTest {
@Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule();
@Mock SslContextProviderFactory<Integer> mockFactory; @Mock
ValueFactory<Integer, SslContextProvider> mockFactory;
ReferenceCountingSslContextProviderMap<Integer> map; ReferenceCountingMap<Integer, SslContextProvider> map;
@Before @Before
public void setUp() { public void setUp() {
map = new ReferenceCountingSslContextProviderMap<>(mockFactory); map = new ReferenceCountingMap<>(mockFactory);
} }
@Test @Test
public void referenceCountingMap_getAndRelease_closeCalled() throws InterruptedException { public void referenceCountingMap_getAndRelease_closeCalled() throws InterruptedException {
SslContextProvider valueFor3 = getTypedMock(); SslContextProvider valueFor3 = getTypedMock();
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); when(mockFactory.create(3)).thenReturn(valueFor3);
SslContextProvider val = map.get(3); SslContextProvider val = map.get(3);
assertThat(val).isSameInstanceAs(valueFor3); assertThat(val).isSameInstanceAs(valueFor3);
verify(valueFor3, never()).close(); verify(valueFor3, never()).close();
@ -73,8 +74,8 @@ public class ReferenceCountingSslContextProviderMapTest {
public void referenceCountingMap_distinctElements() throws InterruptedException { public void referenceCountingMap_distinctElements() throws InterruptedException {
SslContextProvider valueFor3 = getTypedMock(); SslContextProvider valueFor3 = getTypedMock();
SslContextProvider valueFor4 = getTypedMock(); SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); when(mockFactory.create(3)).thenReturn(valueFor3);
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); when(mockFactory.create(4)).thenReturn(valueFor4);
SslContextProvider val3 = map.get(3); SslContextProvider val3 = map.get(3);
assertThat(val3).isSameInstanceAs(valueFor3); assertThat(val3).isSameInstanceAs(valueFor3);
SslContextProvider val4 = map.get(4); SslContextProvider val4 = map.get(4);
@ -91,8 +92,8 @@ public class ReferenceCountingSslContextProviderMapTest {
throws InterruptedException { throws InterruptedException {
SslContextProvider valueFor3 = getTypedMock(); SslContextProvider valueFor3 = getTypedMock();
SslContextProvider valueFor4 = getTypedMock(); SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(3)).thenReturn(valueFor3); when(mockFactory.create(3)).thenReturn(valueFor3);
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); when(mockFactory.create(4)).thenReturn(valueFor4);
SslContextProvider unused = map.get(3); SslContextProvider unused = map.get(3);
SslContextProvider val4 = map.get(4); SslContextProvider val4 = map.get(4);
// now provide wrong key (3) and value (val4) combination // now provide wrong key (3) and value (val4) combination
@ -107,7 +108,7 @@ public class ReferenceCountingSslContextProviderMapTest {
@Test @Test
public void referenceCountingMap_excessRelease_expectException() throws InterruptedException { public void referenceCountingMap_excessRelease_expectException() throws InterruptedException {
SslContextProvider valueFor4 = getTypedMock(); SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); when(mockFactory.create(4)).thenReturn(valueFor4);
SslContextProvider val = map.get(4); SslContextProvider val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4); assertThat(val).isSameInstanceAs(valueFor4);
// at this point ref-count is 1 // at this point ref-count is 1
@ -124,7 +125,7 @@ public class ReferenceCountingSslContextProviderMapTest {
@Test @Test
public void referenceCountingMap_releaseAndGet_differentInstance() throws InterruptedException { public void referenceCountingMap_releaseAndGet_differentInstance() throws InterruptedException {
SslContextProvider valueFor4 = getTypedMock(); SslContextProvider valueFor4 = getTypedMock();
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4); when(mockFactory.create(4)).thenReturn(valueFor4);
SslContextProvider val = map.get(4); SslContextProvider val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4); assertThat(val).isSameInstanceAs(valueFor4);
// at this point ref-count is 1 // at this point ref-count is 1
@ -132,7 +133,7 @@ public class ReferenceCountingSslContextProviderMapTest {
// at this point ref-count is 0 and val is removed // at this point ref-count is 0 and val is removed
// should get another instance for 4 // should get another instance for 4
SslContextProvider valueFor4a = getTypedMock(); SslContextProvider valueFor4a = getTypedMock();
when(mockFactory.createSslContextProvider(4)).thenReturn(valueFor4a); when(mockFactory.create(4)).thenReturn(valueFor4a);
val = map.get(4); val = map.get(4);
assertThat(val).isSameInstanceAs(valueFor4a); assertThat(val).isSameInstanceAs(valueFor4a);
// verify it is a different instance from before // verify it is a different instance from before

View File

@ -42,7 +42,7 @@ public class ServerSslContextProviderFactoryTest {
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
SslContextProvider sslContextProvider = SslContextProvider sslContextProvider =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isNotNull(); assertThat(sslContextProvider).isNotNull();
} }
@ -57,7 +57,7 @@ public class ServerSslContextProviderFactoryTest {
try { try {
SslContextProvider unused = SslContextProvider unused =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (UnsupportedOperationException expected) {
assertThat(expected) assertThat(expected)
@ -77,7 +77,7 @@ public class ServerSslContextProviderFactoryTest {
try { try {
SslContextProvider unused = SslContextProvider unused =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) { } catch (UnsupportedOperationException expected) {
assertThat(expected) assertThat(expected)

View File

@ -32,7 +32,7 @@ import static org.mockito.Mockito.when;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
@ -49,9 +49,9 @@ public class TlsContextManagerTest {
@Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule();
@Mock SslContextProviderFactory<UpstreamTlsContext> mockClientFactory; @Mock ValueFactory<UpstreamTlsContext, SslContextProvider> mockClientFactory;
@Mock SslContextProviderFactory<DownstreamTlsContext> mockServerFactory; @Mock ValueFactory<DownstreamTlsContext, SslContextProvider> mockServerFactory;
@Before @Before
public void clearInstance() throws NoSuchFieldException, IllegalAccessException { public void clearInstance() throws NoSuchFieldException, IllegalAccessException {
@ -141,7 +141,7 @@ public class TlsContextManagerTest {
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory); new TlsContextManagerImpl(mockClientFactory, mockServerFactory);
SslContextProvider mockProvider = mock(SslContextProvider.class); SslContextProvider mockProvider = mock(SslContextProvider.class);
when(mockServerFactory.createSslContextProvider(downstreamTlsContext)).thenReturn(mockProvider); when(mockServerFactory.create(downstreamTlsContext)).thenReturn(mockProvider);
SslContextProvider serverSecretProvider = SslContextProvider serverSecretProvider =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider).isSameInstanceAs(mockProvider); assertThat(serverSecretProvider).isSameInstanceAs(mockProvider);
@ -160,7 +160,7 @@ public class TlsContextManagerTest {
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory); new TlsContextManagerImpl(mockClientFactory, mockServerFactory);
SslContextProvider mockProvider = mock(SslContextProvider.class); SslContextProvider mockProvider = mock(SslContextProvider.class);
when(mockClientFactory.createSslContextProvider(upstreamTlsContext)).thenReturn(mockProvider); when(mockClientFactory.create(upstreamTlsContext)).thenReturn(mockProvider);
SslContextProvider clientSecretProvider = SslContextProvider clientSecretProvider =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider).isSameInstanceAs(mockProvider); assertThat(clientSecretProvider).isSameInstanceAs(mockProvider);