xds: replace DownstreamTlsContext by SslContextProviderSupplier in the Listener (#8205)

This commit is contained in:
sanjaypujare 2021-05-26 14:42:47 -07:00 committed by GitHub
parent 6aeeba805f
commit 328071bbce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 535 additions and 324 deletions

View File

@ -287,10 +287,10 @@ final class ClientXdsClient extends AbstractXdsClient {
"HttpConnectionManager neither has inlined route_config nor RDS."); "HttpConnectionManager neither has inlined route_config nor RDS.");
} }
private static LdsUpdate processServerSideListener(Listener listener) private LdsUpdate processServerSideListener(Listener listener)
throws ResourceInvalidException { throws ResourceInvalidException {
StructOrError<EnvoyServerProtoData.Listener> convertedListener = StructOrError<EnvoyServerProtoData.Listener> convertedListener =
parseServerSideListener(listener); parseServerSideListener(listener, tlsContextManager);
if (convertedListener.getErrorDetail() != null) { if (convertedListener.getErrorDetail() != null) {
throw new ResourceInvalidException(convertedListener.getErrorDetail()); throw new ResourceInvalidException(convertedListener.getErrorDetail());
} }
@ -369,10 +369,10 @@ final class ClientXdsClient extends AbstractXdsClient {
} }
@VisibleForTesting static StructOrError<EnvoyServerProtoData.Listener> parseServerSideListener( @VisibleForTesting static StructOrError<EnvoyServerProtoData.Listener> parseServerSideListener(
Listener listener) { Listener listener, TlsContextManager tlsContextManager) {
try { try {
return StructOrError.fromStruct( return StructOrError.fromStruct(
EnvoyServerProtoData.Listener.fromEnvoyProtoListener(listener)); EnvoyServerProtoData.Listener.fromEnvoyProtoListener(listener, tlsContextManager));
} catch (InvalidProtocolBufferException e) { } catch (InvalidProtocolBufferException e) {
return StructOrError.fromError( return StructOrError.fromError(
"Failed to unpack Listener " + listener.getName() + ":" + e.getMessage()); "Failed to unpack Listener " + listener.getName() + ":" + e.getMessage());

View File

@ -28,6 +28,7 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3
import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.grpc.Internal; import io.grpc.Internal;
import io.grpc.xds.internal.sds.SslContextProviderSupplier;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.util.ArrayList; import java.util.ArrayList;
@ -354,17 +355,21 @@ public final class EnvoyServerProtoData {
// TODO(sanjaypujare): flatten structure by moving FilterChainMatch class members here. // TODO(sanjaypujare): flatten structure by moving FilterChainMatch class members here.
private final FilterChainMatch filterChainMatch; private final FilterChainMatch filterChainMatch;
@Nullable @Nullable
private final DownstreamTlsContext downstreamTlsContext; private final SslContextProviderSupplier sslContextProviderSupplier;
@VisibleForTesting @VisibleForTesting
FilterChain( FilterChain(
FilterChainMatch filterChainMatch, @Nullable DownstreamTlsContext downstreamTlsContext) { FilterChainMatch filterChainMatch, @Nullable DownstreamTlsContext downstreamTlsContext,
TlsContextManager tlsContextManager) {
SslContextProviderSupplier sslContextProviderSupplier1 = downstreamTlsContext == null ? null
: new SslContextProviderSupplier(downstreamTlsContext, tlsContextManager);
this.filterChainMatch = filterChainMatch; this.filterChainMatch = filterChainMatch;
this.downstreamTlsContext = downstreamTlsContext; this.sslContextProviderSupplier = sslContextProviderSupplier1;
} }
static FilterChain fromEnvoyProtoFilterChain( static FilterChain fromEnvoyProtoFilterChain(
io.envoyproxy.envoy.config.listener.v3.FilterChain proto, boolean isDefaultFilterChain) io.envoyproxy.envoy.config.listener.v3.FilterChain proto,
TlsContextManager tlsContextManager, boolean isDefaultFilterChain)
throws InvalidProtocolBufferException { throws InvalidProtocolBufferException {
if (!isDefaultFilterChain && proto.getFiltersList().isEmpty()) { if (!isDefaultFilterChain && proto.getFiltersList().isEmpty()) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
@ -380,7 +385,8 @@ public final class EnvoyServerProtoData {
} }
return new FilterChain( return new FilterChain(
FilterChainMatch.fromEnvoyProtoFilterChainMatch(proto.getFilterChainMatch()), FilterChainMatch.fromEnvoyProtoFilterChainMatch(proto.getFilterChainMatch()),
getTlsContextFromFilterChain(proto) getTlsContextFromFilterChain(proto),
tlsContextManager
); );
} }
@ -456,9 +462,8 @@ public final class EnvoyServerProtoData {
return filterChainMatch; return filterChainMatch;
} }
@Nullable public SslContextProviderSupplier getSslContextProviderSupplier() {
public DownstreamTlsContext getDownstreamTlsContext() { return sslContextProviderSupplier;
return downstreamTlsContext;
} }
@Override @Override
@ -471,19 +476,19 @@ public final class EnvoyServerProtoData {
} }
FilterChain that = (FilterChain) o; FilterChain that = (FilterChain) o;
return java.util.Objects.equals(filterChainMatch, that.filterChainMatch) return java.util.Objects.equals(filterChainMatch, that.filterChainMatch)
&& java.util.Objects.equals(downstreamTlsContext, that.downstreamTlsContext); && java.util.Objects.equals(sslContextProviderSupplier, that.sslContextProviderSupplier);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return java.util.Objects.hash(filterChainMatch, downstreamTlsContext); return java.util.Objects.hash(filterChainMatch, sslContextProviderSupplier);
} }
@Override @Override
public String toString() { public String toString() {
return "FilterChain{" return "FilterChain{"
+ "filterChainMatch=" + filterChainMatch + "filterChainMatch=" + filterChainMatch
+ ", downstreamTlsContext=" + downstreamTlsContext + ", sslContextProviderSupplier=" + sslContextProviderSupplier
+ '}'; + '}';
} }
} }
@ -524,7 +529,8 @@ public final class EnvoyServerProtoData {
return null; return null;
} }
static Listener fromEnvoyProtoListener(io.envoyproxy.envoy.config.listener.v3.Listener proto) static Listener fromEnvoyProtoListener(io.envoyproxy.envoy.config.listener.v3.Listener proto,
TlsContextManager tlsContextManager)
throws InvalidProtocolBufferException { throws InvalidProtocolBufferException {
if (!proto.getTrafficDirection().equals(TrafficDirection.INBOUND)) { if (!proto.getTrafficDirection().equals(TrafficDirection.INBOUND)) {
throw new IllegalArgumentException("Listener " + proto.getName() + " is not INBOUND"); throw new IllegalArgumentException("Listener " + proto.getName() + " is not INBOUND");
@ -537,21 +543,25 @@ public final class EnvoyServerProtoData {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Listener " + proto.getName() + " cannot have use_original_dst set to true"); "Listener " + proto.getName() + " cannot have use_original_dst set to true");
} }
List<FilterChain> filterChains = validateAndSelectFilterChains(proto.getFilterChainsList()); List<FilterChain> filterChains = validateAndSelectFilterChains(proto.getFilterChainsList(),
tlsContextManager);
return new Listener( return new Listener(
proto.getName(), proto.getName(),
convertEnvoyAddressToString(proto.getAddress()), convertEnvoyAddressToString(proto.getAddress()),
filterChains, FilterChain.fromEnvoyProtoFilterChain(proto.getDefaultFilterChain(), true)); filterChains, FilterChain
.fromEnvoyProtoFilterChain(proto.getDefaultFilterChain(), tlsContextManager, true));
} }
private static List<FilterChain> validateAndSelectFilterChains( private static List<FilterChain> validateAndSelectFilterChains(
List<io.envoyproxy.envoy.config.listener.v3.FilterChain> inputFilterChains) List<io.envoyproxy.envoy.config.listener.v3.FilterChain> inputFilterChains,
TlsContextManager tlsContextManager)
throws InvalidProtocolBufferException { throws InvalidProtocolBufferException {
List<FilterChain> filterChains = new ArrayList<>(inputFilterChains.size()); List<FilterChain> filterChains = new ArrayList<>(inputFilterChains.size());
for (io.envoyproxy.envoy.config.listener.v3.FilterChain filterChain : for (io.envoyproxy.envoy.config.listener.v3.FilterChain filterChain :
inputFilterChains) { inputFilterChains) {
if (isAcceptable(filterChain.getFilterChainMatch())) { if (isAcceptable(filterChain.getFilterChainMatch())) {
filterChains.add(FilterChain.fromEnvoyProtoFilterChain(filterChain, false)); filterChains
.add(FilterChain.fromEnvoyProtoFilterChain(filterChain, tlsContextManager, false));
} }
} }
return filterChains; return filterChains;

View File

@ -27,9 +27,9 @@ import io.grpc.Status;
import io.grpc.internal.ObjectPool; import io.grpc.internal.ObjectPool;
import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder;
import io.grpc.xds.EnvoyServerProtoData.CidrRange; import io.grpc.xds.EnvoyServerProtoData.CidrRange;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.FilterChain; import io.grpc.xds.EnvoyServerProtoData.FilterChain;
import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch;
import io.grpc.xds.internal.sds.SslContextProviderSupplier;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.epoll.Epoll; import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.epoll.EpollEventLoopGroup;
@ -74,6 +74,7 @@ public final class XdsClientWrapperForServerSds {
private ScheduledExecutorService timeService; private ScheduledExecutorService timeService;
private XdsClient.LdsResourceWatcher listenerWatcher; private XdsClient.LdsResourceWatcher listenerWatcher;
private boolean newServerApi; private boolean newServerApi;
private String grpcServerResourceId;
@VisibleForTesting final Set<ServerWatcher> serverWatchers = new HashSet<>(); @VisibleForTesting final Set<ServerWatcher> serverWatchers = new HashSet<>();
/** /**
@ -114,14 +115,14 @@ public final class XdsClientWrapperForServerSds {
new XdsClient.LdsResourceWatcher() { new XdsClient.LdsResourceWatcher() {
@Override @Override
public void onChanged(XdsClient.LdsUpdate update) { public void onChanged(XdsClient.LdsUpdate update) {
curListener.set(update.listener); releaseOldSuppliers(curListener.getAndSet(update.listener));
reportSuccess(); reportSuccess();
} }
@Override @Override
public void onResourceDoesNotExist(String resourceName) { public void onResourceDoesNotExist(String resourceName) {
logger.log(Level.WARNING, "Resource {0} is unavailable", resourceName); logger.log(Level.WARNING, "Resource {0} is unavailable", resourceName);
curListener.set(null); releaseOldSuppliers(curListener.getAndSet(null));
reportError(Status.NOT_FOUND.asException(), true); reportError(Status.NOT_FOUND.asException(), true);
} }
@ -129,10 +130,15 @@ public final class XdsClientWrapperForServerSds {
public void onError(Status error) { public void onError(Status error) {
logger.log( logger.log(
Level.WARNING, "LdsResourceWatcher in XdsClientWrapperForServerSds: {0}", error); Level.WARNING, "LdsResourceWatcher in XdsClientWrapperForServerSds: {0}", error);
reportError(error.asException(), isResourceAbsent(error)); if (isResourceAbsent(error)) {
releaseOldSuppliers(curListener.getAndSet(null));
reportError(error.asException(), true);
} else {
reportError(error.asException(), false);
}
} }
}; };
String grpcServerResourceId = xdsClient.getBootstrapInfo() grpcServerResourceId = xdsClient.getBootstrapInfo()
.getServerListenerResourceNameTemplate(); .getServerListenerResourceNameTemplate();
newServerApi = xdsClient.getBootstrapInfo().getServers().get(0).isUseProtocolV3(); newServerApi = xdsClient.getBootstrapInfo().getServers().get(0).isUseProtocolV3();
if (newServerApi && grpcServerResourceId == null) { if (newServerApi && grpcServerResourceId == null) {
@ -145,6 +151,27 @@ public final class XdsClientWrapperForServerSds {
xdsClient.watchLdsResource(grpcServerResourceId, listenerWatcher); xdsClient.watchLdsResource(grpcServerResourceId, listenerWatcher);
} }
// go thru the old listener and release all the old SslContextProviderSupplier
private void releaseOldSuppliers(EnvoyServerProtoData.Listener oldListener) {
if (oldListener != null) {
List<FilterChain> filterChains = oldListener.getFilterChains();
for (FilterChain filterChain : filterChains) {
releaseSupplier(filterChain);
}
releaseSupplier(oldListener.getDefaultFilterChain());
}
}
private static void releaseSupplier(FilterChain filterChain) {
if (filterChain != null) {
SslContextProviderSupplier sslContextProviderSupplier =
filterChain.getSslContextProviderSupplier();
if (sslContextProviderSupplier != null) {
sslContextProviderSupplier.close();
}
}
}
/** Whether the throwable indicates our listener resource is absent/deleted. */ /** Whether the throwable indicates our listener resource is absent/deleted. */
private static boolean isResourceAbsent(Status status) { private static boolean isResourceAbsent(Status status) {
Status.Code code = status.getCode(); Status.Code code = status.getCode();
@ -162,10 +189,10 @@ public final class XdsClientWrapperForServerSds {
/** /**
* Locates the best matching FilterChain to the channel from the current listener and if found * Locates the best matching FilterChain to the channel from the current listener and if found
* returns the DownstreamTlsContext from that FilterChain, else null. * returns the SslContextProviderSupplier from that FilterChain, else null.
*/ */
@Nullable @Nullable
public DownstreamTlsContext getDownstreamTlsContext(Channel channel) { public SslContextProviderSupplier getSslContextProviderSupplier(Channel channel) {
EnvoyServerProtoData.Listener copyListener = curListener.get(); EnvoyServerProtoData.Listener copyListener = curListener.get();
if (copyListener != null && channel != null) { if (copyListener != null && channel != null) {
SocketAddress localAddress = channel.localAddress(); SocketAddress localAddress = channel.localAddress();
@ -176,7 +203,7 @@ public final class XdsClientWrapperForServerSds {
checkState( checkState(
port == localInetAddr.getPort(), port == localInetAddr.getPort(),
"Channel localAddress port does not match requested listener port"); "Channel localAddress port does not match requested listener port");
return getDownstreamTlsContext(localInetAddr, remoteInetAddr, copyListener); return getSslContextProviderSupplier(localInetAddr, remoteInetAddr, copyListener);
} }
} }
return null; return null;
@ -185,13 +212,13 @@ public final class XdsClientWrapperForServerSds {
/** /**
* Using the logic specified at * Using the logic specified at
* https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/listener/listener_components.proto.html?highlight=filter%20chain#listener-filterchainmatch * https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/listener/listener_components.proto.html?highlight=filter%20chain#listener-filterchainmatch
* locate a matching filter and return the corresponding DownstreamTlsContext or else return one * locate a matching filter and return the corresponding SslContextProviderSupplier or else
* from default filter chain. * return one from default filter chain.
* *
* @param localInetAddr dest address of the inbound connection * @param localInetAddr dest address of the inbound connection
* @param remoteInetAddr source address of the inbound connection * @param remoteInetAddr source address of the inbound connection
*/ */
private static DownstreamTlsContext getDownstreamTlsContext( private static SslContextProviderSupplier getSslContextProviderSupplier(
InetSocketAddress localInetAddr, InetSocketAddress remoteInetAddr, InetSocketAddress localInetAddr, InetSocketAddress remoteInetAddr,
EnvoyServerProtoData.Listener listener) { EnvoyServerProtoData.Listener listener) {
List<FilterChain> filterChains = listener.getFilterChains(); List<FilterChain> filterChains = listener.getFilterChains();
@ -207,9 +234,9 @@ public final class XdsClientWrapperForServerSds {
// close the connection // close the connection
throw new IllegalStateException("Found 2 matching filter-chains"); throw new IllegalStateException("Found 2 matching filter-chains");
} else if (filterChains.size() == 1) { } else if (filterChains.size() == 1) {
return filterChains.get(0).getDownstreamTlsContext(); return filterChains.get(0).getSslContextProviderSupplier();
} }
return listener.getDefaultFilterChain().getDownstreamTlsContext(); return listener.getDefaultFilterChain().getSslContextProviderSupplier();
} }
// destination_port present => Always fail match // destination_port present => Always fail match
@ -423,8 +450,10 @@ public final class XdsClientWrapperForServerSds {
public void shutdown() { public void shutdown() {
logger.log(Level.FINER, "Shutdown"); logger.log(Level.FINER, "Shutdown");
if (xdsClient != null) { if (xdsClient != null) {
xdsClient.cancelLdsResourceWatch(grpcServerResourceId, listenerWatcher);
xdsClient = xdsClientPool.returnObject(xdsClient); xdsClient = xdsClientPool.returnObject(xdsClient);
} }
releaseOldSuppliers(curListener.getAndSet(null));
if (timeService != null) { if (timeService != null) {
timeService = SharedResourceHolder.release(timeServiceResource, timeService); timeService = SharedResourceHolder.release(timeServiceResource, timeService);
} }

View File

@ -31,9 +31,7 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.InternalProtocolNegotiators;
import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.ProtocolNegotiationEvent; import io.grpc.netty.ProtocolNegotiationEvent;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.InternalXdsAttributes; import io.grpc.xds.InternalXdsAttributes;
import io.grpc.xds.TlsContextManager;
import io.grpc.xds.XdsClientWrapperForServerSds; import io.grpc.xds.XdsClientWrapperForServerSds;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerAdapter;
@ -345,11 +343,11 @@ public final class SdsProtocolNegotiators {
@Override @Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProtocolNegotiationEvent) { if (evt instanceof ProtocolNegotiationEvent) {
DownstreamTlsContext downstreamTlsContext = SslContextProviderSupplier sslContextProviderSupplier =
xdsClientWrapperForServerSds == null xdsClientWrapperForServerSds == null
? null ? null
: xdsClientWrapperForServerSds.getDownstreamTlsContext(ctx.channel()); : xdsClientWrapperForServerSds.getSslContextProviderSupplier(ctx.channel());
if (downstreamTlsContext == null) { if (sslContextProviderSupplier == null) {
if (fallbackProtocolNegotiator == null) { if (fallbackProtocolNegotiator == null) {
ctx.fireExceptionCaught(new CertStoreException("No certificate source found!")); ctx.fireExceptionCaught(new CertStoreException("No certificate source found!"));
return; return;
@ -369,8 +367,7 @@ public final class SdsProtocolNegotiators {
this, this,
null, null,
new ServerSdsHandler( new ServerSdsHandler(
grpcHandler, downstreamTlsContext, fallbackProtocolNegotiator, grpcHandler, sslContextProviderSupplier));
xdsClientWrapperForServerSds.getTlsContextManager()));
ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault(); ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault();
ctx.fireUserEventTriggered(pne); ctx.fireUserEventTriggered(pne);
return; return;
@ -385,14 +382,11 @@ public final class SdsProtocolNegotiators {
static final class ServerSdsHandler static final class ServerSdsHandler
extends InternalProtocolNegotiators.ProtocolNegotiationHandler { extends InternalProtocolNegotiators.ProtocolNegotiationHandler {
private final GrpcHttp2ConnectionHandler grpcHandler; private final GrpcHttp2ConnectionHandler grpcHandler;
private final DownstreamTlsContext downstreamTlsContext; private final SslContextProviderSupplier sslContextProviderSupplier;
private final TlsContextManager tlsContextManager;
@Nullable private final ProtocolNegotiator fallbackProtocolNegotiator;
ServerSdsHandler( ServerSdsHandler(
GrpcHttp2ConnectionHandler grpcHandler, GrpcHttp2ConnectionHandler grpcHandler,
DownstreamTlsContext downstreamTlsContext, SslContextProviderSupplier sslContextProviderSupplier) {
ProtocolNegotiator fallbackProtocolNegotiator, TlsContextManager tlsContextManager) {
super( super(
// superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
// handler but we don't have a next handler _yet_. So we "disable" superclass's behavior // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
@ -405,9 +399,7 @@ public final class SdsProtocolNegotiators {
}, grpcHandler.getNegotiationLogger()); }, grpcHandler.getNegotiationLogger());
checkNotNull(grpcHandler, "grpcHandler"); checkNotNull(grpcHandler, "grpcHandler");
this.grpcHandler = grpcHandler; this.grpcHandler = grpcHandler;
this.downstreamTlsContext = downstreamTlsContext; this.sslContextProviderSupplier = sslContextProviderSupplier;
this.fallbackProtocolNegotiator = fallbackProtocolNegotiator;
this.tlsContextManager = tlsContextManager;
} }
@Override @Override
@ -415,23 +407,7 @@ public final class SdsProtocolNegotiators {
final BufferReadsHandler bufferReads = new BufferReadsHandler(); final BufferReadsHandler bufferReads = new BufferReadsHandler();
ctx.pipeline().addBefore(ctx.name(), null, bufferReads); ctx.pipeline().addBefore(ctx.name(), null, bufferReads);
SslContextProvider sslContextProviderTemp = null; sslContextProviderSupplier.updateSslContext(
try {
sslContextProviderTemp =
tlsContextManager.findOrCreateServerSslContextProvider(downstreamTlsContext);
} catch (Exception e) {
if (fallbackProtocolNegotiator == null) {
ctx.fireExceptionCaught(new CertStoreException("No certificate source found!", e));
return;
}
logger.log(Level.INFO, "Using fallback for {0}", ctx.channel().localAddress());
// Delegate rest of handshake to fallback handler
ctx.pipeline().replace(this, null, fallbackProtocolNegotiator.newHandler(grpcHandler));
ctx.pipeline().remove(bufferReads);
return;
}
final SslContextProvider sslContextProvider = sslContextProviderTemp;
sslContextProvider.addCallback(
new SslContextProvider.Callback(ctx.executor()) { new SslContextProvider.Callback(ctx.executor()) {
@Override @Override
@ -445,7 +421,6 @@ public final class SdsProtocolNegotiators {
fireProtocolNegotiationEvent(ctx); fireProtocolNegotiationEvent(ctx);
ctx.pipeline().remove(bufferReads); ctx.pipeline().remove(bufferReads);
} }
tlsContextManager.releaseServerSslContextProvider(sslContextProvider);
} }
@Override @Override

View File

@ -19,11 +19,13 @@ package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import com.google.common.base.MoreObjects;
import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.TlsContextManager; import io.grpc.xds.TlsContextManager;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import java.util.Objects;
/** /**
* Enables Client or server side to initialize this object with the received {@link BaseTlsContext} * Enables Client or server side to initialize this object with the received {@link BaseTlsContext}
@ -52,27 +54,36 @@ public final class SslContextProviderSupplier implements Closeable {
/** Updates SslContext via the passed callback. */ /** Updates SslContext via the passed callback. */
public synchronized void updateSslContext(final SslContextProvider.Callback callback) { public synchronized void updateSslContext(final SslContextProvider.Callback callback) {
checkNotNull(callback, "callback"); checkNotNull(callback, "callback");
checkState(!shutdown, "Supplier is shutdown!"); try {
if (sslContextProvider == null) { checkState(!shutdown, "Supplier is shutdown!");
sslContextProvider = getSslContextProvider(); if (sslContextProvider == null) {
sslContextProvider = getSslContextProvider();
}
// we want to increment the ref-count so call findOrCreate again...
final SslContextProvider toRelease = getSslContextProvider();
sslContextProvider.addCallback(
new SslContextProvider.Callback(callback.getExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
callback.updateSecret(sslContext);
releaseSslContextProvider(toRelease);
}
@Override
public void onException(Throwable throwable) {
callback.onException(throwable);
releaseSslContextProvider(toRelease);
}
});
} catch (final Throwable throwable) {
callback.getExecutor().execute(new Runnable() {
@Override
public void run() {
callback.onException(throwable);
}
});
} }
// we want to increment the ref-count so call findOrCreate again...
final SslContextProvider toRelease = getSslContextProvider();
sslContextProvider.addCallback(
new SslContextProvider.Callback(callback.getExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
callback.updateSecret(sslContext);
releaseSslContextProvider(toRelease);
}
@Override
public void onException(Throwable throwable) {
callback.onException(throwable);
releaseSslContextProvider(toRelease);
}
});
} }
private void releaseSslContextProvider(SslContextProvider toRelease) { private void releaseSslContextProvider(SslContextProvider toRelease) {
@ -101,4 +112,34 @@ public final class SslContextProviderSupplier implements Closeable {
} }
shutdown = true; shutdown = true;
} }
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SslContextProviderSupplier that = (SslContextProviderSupplier) o;
return shutdown == that.shutdown
&& Objects.equals(tlsContext, that.tlsContext)
&& Objects.equals(tlsContextManager, that.tlsContextManager)
&& Objects.equals(sslContextProvider, that.sslContextProvider);
}
@Override
public int hashCode() {
return Objects.hash(tlsContext, tlsContextManager, sslContextProvider, shutdown);
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("tlsContext", tlsContext)
.add("tlsContextManager", tlsContextManager)
.add("sslContextProvider", sslContextProvider)
.add("shutdown", shutdown)
.toString();
}
} }

View File

@ -688,7 +688,7 @@ public class ClientXdsClientDataTest {
.setTrafficDirection(TrafficDirection.OUTBOUND) .setTrafficDirection(TrafficDirection.OUTBOUND)
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()).isEqualTo("Listener listener1 is not INBOUND"); assertThat(struct.getErrorDetail()).isEqualTo("Listener listener1 is not INBOUND");
} }
@ -701,7 +701,7 @@ public class ClientXdsClientDataTest {
.addListenerFilters(ListenerFilter.newBuilder().build()) .addListenerFilters(ListenerFilter.newBuilder().build())
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()) assertThat(struct.getErrorDetail())
.isEqualTo("Listener listener1 cannot have listener_filters"); .isEqualTo("Listener listener1 cannot have listener_filters");
} }
@ -715,7 +715,7 @@ public class ClientXdsClientDataTest {
.setUseOriginalDst(BoolValue.of(true)) .setUseOriginalDst(BoolValue.of(true))
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()) assertThat(struct.getErrorDetail())
.isEqualTo("Listener listener1 cannot have use_original_dst set to true"); .isEqualTo("Listener listener1 cannot have use_original_dst set to true");
} }
@ -729,7 +729,7 @@ public class ClientXdsClientDataTest {
.addFilterChains(FilterChain.newBuilder().build()) .addFilterChains(FilterChain.newBuilder().build())
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()) assertThat(struct.getErrorDetail())
.isEqualTo("filerChain has to have envoy.http_connection_manager"); .isEqualTo("filerChain has to have envoy.http_connection_manager");
} }
@ -753,7 +753,7 @@ public class ClientXdsClientDataTest {
.addFilterChains(filterChain) .addFilterChains(filterChain)
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()) assertThat(struct.getErrorDetail())
.isEqualTo("filerChain has non-unique filter name:envoy.http_connection_manager"); .isEqualTo("filerChain has non-unique filter name:envoy.http_connection_manager");
} }
@ -773,7 +773,7 @@ public class ClientXdsClientDataTest {
.addFilterChains(filterChain) .addFilterChains(filterChain)
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()) assertThat(struct.getErrorDetail())
.isEqualTo("filter envoy.http_connection_manager with config_discovery not supported"); .isEqualTo("filter envoy.http_connection_manager with config_discovery not supported");
} }
@ -789,7 +789,7 @@ public class ClientXdsClientDataTest {
.addFilterChains(filterChain) .addFilterChains(filterChain)
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()) assertThat(struct.getErrorDetail())
.isEqualTo("filter envoy.http_connection_manager expected to have typed_config"); .isEqualTo("filter envoy.http_connection_manager expected to have typed_config");
} }
@ -809,7 +809,7 @@ public class ClientXdsClientDataTest {
.addFilterChains(filterChain) .addFilterChains(filterChain)
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()) assertThat(struct.getErrorDetail())
.isEqualTo( .isEqualTo(
"filter envoy.http_connection_manager with unsupported typed_config type:badTypeUrl"); "filter envoy.http_connection_manager with unsupported typed_config type:badTypeUrl");
@ -830,7 +830,7 @@ public class ClientXdsClientDataTest {
.addFilterChains(filterChain) .addFilterChains(filterChain)
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()) assertThat(struct.getErrorDetail())
.isEqualTo("http-connection-manager has non-unique http-filter name:hf"); .isEqualTo("http-connection-manager has non-unique http-filter name:hf");
} }
@ -852,7 +852,7 @@ public class ClientXdsClientDataTest {
.addFilterChains(filterChain) .addFilterChains(filterChain)
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()) assertThat(struct.getErrorDetail())
.isEqualTo( .isEqualTo(
"http-connection-manager http-filter envoy.router uses " "http-connection-manager http-filter envoy.router uses "
@ -877,7 +877,7 @@ public class ClientXdsClientDataTest {
.addFilterChains(filterChain) .addFilterChains(filterChain)
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()) assertThat(struct.getErrorDetail())
.isEqualTo( .isEqualTo(
"http-connection-manager http-filter envoy.router has unsupported typed-config type:" "http-connection-manager http-filter envoy.router has unsupported typed-config type:"
@ -898,7 +898,7 @@ public class ClientXdsClientDataTest {
.addFilterChains(filterChain) .addFilterChains(filterChain)
.build(); .build();
StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct = StructOrError<io.grpc.xds.EnvoyServerProtoData.Listener> struct =
ClientXdsClient.parseServerSideListener(listener); ClientXdsClient.parseServerSideListener(listener, null);
assertThat(struct.getErrorDetail()) assertThat(struct.getErrorDetail())
.isEqualTo( .isEqualTo(
"http-connection-manager http-filter envoy.filters.http.router should have " "http-connection-manager http-filter envoy.filters.http.router should have "

View File

@ -234,6 +234,8 @@ public abstract class ClientXdsClientTestBase {
private CdsResourceWatcher cdsResourceWatcher; private CdsResourceWatcher cdsResourceWatcher;
@Mock @Mock
private EdsResourceWatcher edsResourceWatcher; private EdsResourceWatcher edsResourceWatcher;
@Mock
private TlsContextManager tlsContextManager;
private ManagedChannel channel; private ManagedChannel channel;
private ClientXdsClient xdsClient; private ClientXdsClient xdsClient;
@ -279,7 +281,7 @@ public abstract class ClientXdsClientTestBase {
backoffPolicyProvider, backoffPolicyProvider,
fakeClock.getStopwatchSupplier(), fakeClock.getStopwatchSupplier(),
timeProvider, timeProvider,
mock(TlsContextManager.class)); tlsContextManager);
assertThat(resourceDiscoveryCalls).isEmpty(); assertThat(resourceDiscoveryCalls).isEmpty();
assertThat(loadReportCalls).isEmpty(); assertThat(loadReportCalls).isEmpty();
@ -2021,7 +2023,7 @@ public abstract class ClientXdsClientTestBase {
ClientXdsClientTestBase.DiscoveryRpcCall call = ClientXdsClientTestBase.DiscoveryRpcCall call =
startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher);
Message listener = Message listener =
mf.buildListenerWithFilterChain( mf.buildListenerWithFilterChain(
LISTENER_RESOURCE, 7000, "0.0.0.0", "google-sds-config-default", "ROOTCA"); LISTENER_RESOURCE, 7000, "0.0.0.0", "google-sds-config-default", "ROOTCA");
List<Any> listeners = ImmutableList.of(Any.pack(listener)); List<Any> listeners = ImmutableList.of(Any.pack(listener));
call.sendResponse(ResourceType.LDS, listeners, "0", "0000"); call.sendResponse(ResourceType.LDS, listeners, "0", "0000");
@ -2030,10 +2032,11 @@ public abstract class ClientXdsClientTestBase {
ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE);
verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture());
assertThat(ldsUpdateCaptor.getValue().listener) assertThat(ldsUpdateCaptor.getValue().listener)
.isEqualTo(EnvoyServerProtoData.Listener.fromEnvoyProtoListener((Listener)listener)); .isEqualTo(EnvoyServerProtoData.Listener
.fromEnvoyProtoListener((Listener) listener, tlsContextManager));
listener = listener =
mf.buildListenerWithFilterChain( mf.buildListenerWithFilterChain(
LISTENER_RESOURCE, 7000, "0.0.0.0", "CERT2", "ROOTCA2"); LISTENER_RESOURCE, 7000, "0.0.0.0", "CERT2", "ROOTCA2");
listeners = ImmutableList.of(Any.pack(listener)); listeners = ImmutableList.of(Any.pack(listener));
call.sendResponse(ResourceType.LDS, listeners, "1", "0001"); call.sendResponse(ResourceType.LDS, listeners, "1", "0001");
@ -2043,7 +2046,8 @@ public abstract class ClientXdsClientTestBase {
ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "1", "0001", NODE); ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "1", "0001", NODE);
verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture()); verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture());
assertThat(ldsUpdateCaptor.getValue().listener) assertThat(ldsUpdateCaptor.getValue().listener)
.isEqualTo(EnvoyServerProtoData.Listener.fromEnvoyProtoListener((Listener)listener)); .isEqualTo(EnvoyServerProtoData.Listener
.fromEnvoyProtoListener((Listener) listener, tlsContextManager));
assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty();
} }

View File

@ -17,6 +17,7 @@
package io.grpc.xds; package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.mock;
import com.google.protobuf.Any; import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.InvalidProtocolBufferException;
@ -34,6 +35,7 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.SdsSecretConfig;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.Listener; import io.grpc.xds.EnvoyServerProtoData.Listener;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import io.grpc.xds.internal.sds.SslContextProviderSupplier;
import java.util.List; import java.util.List;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -61,7 +63,7 @@ public class EnvoyServerProtoDataTest {
.setTrafficDirection(TrafficDirection.INBOUND) .setTrafficDirection(TrafficDirection.INBOUND)
.build(); .build();
Listener xdsListener = Listener.fromEnvoyProtoListener(listener); Listener xdsListener = Listener.fromEnvoyProtoListener(listener, mock(TlsContextManager.class));
assertThat(xdsListener.getName()).isEqualTo("8000"); assertThat(xdsListener.getName()).isEqualTo("8000");
assertThat(xdsListener.getAddress()).isEqualTo("10.2.1.34:8000"); assertThat(xdsListener.getAddress()).isEqualTo("10.2.1.34:8000");
List<EnvoyServerProtoData.FilterChain> filterChains = xdsListener.getFilterChains(); List<EnvoyServerProtoData.FilterChain> filterChains = xdsListener.getFilterChains();
@ -81,7 +83,11 @@ public class EnvoyServerProtoDataTest {
assertThat(inFilterChainMatch.getConnectionSourceType()) assertThat(inFilterChainMatch.getConnectionSourceType())
.isEqualTo(EnvoyServerProtoData.ConnectionSourceType.EXTERNAL); .isEqualTo(EnvoyServerProtoData.ConnectionSourceType.EXTERNAL);
assertThat(inFilterChainMatch.getSourcePorts()).containsExactly(200, 300); assertThat(inFilterChainMatch.getSourcePorts()).containsExactly(200, 300);
DownstreamTlsContext inFilterTlsContext = inFilter.getDownstreamTlsContext(); SslContextProviderSupplier sslContextProviderSupplier = inFilter
.getSslContextProviderSupplier();
assertThat(sslContextProviderSupplier.getTlsContext()).isInstanceOf(DownstreamTlsContext.class);
DownstreamTlsContext inFilterTlsContext = (DownstreamTlsContext) sslContextProviderSupplier
.getTlsContext();
assertThat(inFilterTlsContext.getCommonTlsContext()).isNotNull(); assertThat(inFilterTlsContext.getCommonTlsContext()).isNotNull();
CommonTlsContext commonTlsContext = inFilterTlsContext.getCommonTlsContext(); CommonTlsContext commonTlsContext = inFilterTlsContext.getCommonTlsContext();
List<SdsSecretConfig> tlsCertSdsConfigs = commonTlsContext List<SdsSecretConfig> tlsCertSdsConfigs = commonTlsContext

View File

@ -23,6 +23,7 @@ import static org.mockito.Mockito.when;
import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.InvalidProtocolBufferException;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import io.grpc.xds.internal.sds.SslContextProviderSupplier;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import java.io.IOException; import java.io.IOException;
import java.net.InetAddress; import java.net.InetAddress;
@ -46,6 +47,7 @@ public class FilterChainMatchTest {
private static final String REMOTE_IP = "10.4.2.3"; // source private static final String REMOTE_IP = "10.4.2.3"; // source
@Mock private Channel channel; @Mock private Channel channel;
@Mock private TlsContextManager tlsContextManager;
private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; private XdsClientWrapperForServerSds xdsClientWrapperForServerSds;
private XdsClient.LdsResourceWatcher registeredWatcher; private XdsClient.LdsResourceWatcher registeredWatcher;
@ -54,7 +56,7 @@ public class FilterChainMatchTest {
public void setUp() throws IOException { public void setUp() throws IOException {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
xdsClientWrapperForServerSds = XdsServerTestHelper xdsClientWrapperForServerSds = XdsServerTestHelper
.createXdsClientWrapperForServerSds(PORT, null); .createXdsClientWrapperForServerSds(PORT, tlsContextManager);
registeredWatcher = registeredWatcher =
XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds);
} }
@ -64,6 +66,17 @@ public class FilterChainMatchTest {
xdsClientWrapperForServerSds.shutdown(); xdsClientWrapperForServerSds.shutdown();
} }
private DownstreamTlsContext getDownstreamTlsContext() {
SslContextProviderSupplier sslContextProviderSupplier =
xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel);
if (sslContextProviderSupplier != null) {
EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext();
assertThat(tlsContext).isInstanceOf(DownstreamTlsContext.class);
return (DownstreamTlsContext) tlsContext;
}
return null;
}
@Test @Test
public void singleFilterChainWithoutAlpn() throws UnknownHostException { public void singleFilterChainWithoutAlpn() throws UnknownHostException {
setupChannel(LOCAL_IP, REMOTE_IP, 15000); setupChannel(LOCAL_IP, REMOTE_IP, 15000);
@ -78,13 +91,12 @@ public class FilterChainMatchTest {
DownstreamTlsContext tlsContext = DownstreamTlsContext tlsContext =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1");
EnvoyServerProtoData.FilterChain filterChain = EnvoyServerProtoData.FilterChain filterChain =
new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext); new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext, tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener("listener1", LOCAL_IP, Arrays.asList(filterChain), null); new EnvoyServerProtoData.Listener("listener1", LOCAL_IP, Arrays.asList(filterChain), null);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContext); assertThat(tlsContext1).isSameInstanceAs(tlsContext);
} }
@ -102,13 +114,12 @@ public class FilterChainMatchTest {
DownstreamTlsContext tlsContext = DownstreamTlsContext tlsContext =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1");
EnvoyServerProtoData.FilterChain filterChain = EnvoyServerProtoData.FilterChain filterChain =
new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext); new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext, tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener("listener1", LOCAL_IP, Arrays.asList(filterChain), null); new EnvoyServerProtoData.Listener("listener1", LOCAL_IP, Arrays.asList(filterChain), null);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContext); assertThat(tlsContext1).isSameInstanceAs(tlsContext);
} }
@ -118,14 +129,13 @@ public class FilterChainMatchTest {
DownstreamTlsContext tlsContext = DownstreamTlsContext tlsContext =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1");
EnvoyServerProtoData.FilterChain filterChain = EnvoyServerProtoData.FilterChain filterChain =
new EnvoyServerProtoData.FilterChain(null, tlsContext); new EnvoyServerProtoData.FilterChain(null, tlsContext, tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", LOCAL_IP, Arrays.<EnvoyServerProtoData.FilterChain>asList(), filterChain); "listener1", LOCAL_IP, Arrays.<EnvoyServerProtoData.FilterChain>asList(), filterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContext); assertThat(tlsContext1).isSameInstanceAs(tlsContext);
} }
@ -143,18 +153,19 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainWithDestPort = EnvoyServerProtoData.FilterChain filterChainWithDestPort =
new EnvoyServerProtoData.FilterChain(filterChainMatchWithDestPort, tlsContextWithDestPort); new EnvoyServerProtoData.FilterChain(filterChainMatchWithDestPort, tlsContextWithDestPort,
tlsContextManager);
DownstreamTlsContext tlsContextForDefaultFilterChain = DownstreamTlsContext tlsContextForDefaultFilterChain =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain,
tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", LOCAL_IP, Arrays.asList(filterChainWithDestPort), defaultFilterChain); "listener1", LOCAL_IP, Arrays.asList(filterChainWithDestPort), defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain);
} }
@ -172,18 +183,19 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainWithMatch = EnvoyServerProtoData.FilterChain filterChainWithMatch =
new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch); new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch,
tlsContextManager);
DownstreamTlsContext tlsContextForDefaultFilterChain = DownstreamTlsContext tlsContextForDefaultFilterChain =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain,
tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain); "listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch); assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch);
} }
@ -203,24 +215,25 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainWithMismatch = EnvoyServerProtoData.FilterChain filterChainWithMismatch =
new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch); new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch,
tlsContextManager);
DownstreamTlsContext tlsContextForDefaultFilterChain = DownstreamTlsContext tlsContextForDefaultFilterChain =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain,
tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain); "listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain);
} }
@Test @Test
public void dest0LengthPrefixRange() public void dest0LengthPrefixRange()
throws UnknownHostException, InvalidProtocolBufferException { throws UnknownHostException, InvalidProtocolBufferException {
setupChannel(LOCAL_IP, REMOTE_IP, 15000); setupChannel(LOCAL_IP, REMOTE_IP, 15000);
DownstreamTlsContext tlsContext0Length = DownstreamTlsContext tlsContext0Length =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1");
@ -234,18 +247,19 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChain0Length = EnvoyServerProtoData.FilterChain filterChain0Length =
new EnvoyServerProtoData.FilterChain(filterChainMatch0Length, tlsContext0Length); new EnvoyServerProtoData.FilterChain(filterChainMatch0Length, tlsContext0Length,
tlsContextManager);
DownstreamTlsContext tlsContextForDefaultFilterChain = DownstreamTlsContext tlsContextForDefaultFilterChain =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain,
tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", LOCAL_IP, Arrays.asList(filterChain0Length), defaultFilterChain); "listener1", LOCAL_IP, Arrays.asList(filterChain0Length), defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContext0Length); assertThat(tlsContext1).isSameInstanceAs(tlsContext0Length);
} }
@ -264,7 +278,8 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainLessSpecific = EnvoyServerProtoData.FilterChain filterChainLessSpecific =
new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific,
tlsContextManager);
DownstreamTlsContext tlsContextMoreSpecific = DownstreamTlsContext tlsContextMoreSpecific =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
@ -277,9 +292,10 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainMoreSpecific = EnvoyServerProtoData.FilterChain filterChainMoreSpecific =
new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific); new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific,
tlsContextManager);
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, null); new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", "listener1",
@ -288,14 +304,13 @@ public class FilterChainMatchTest {
defaultFilterChain); defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific);
} }
@Test @Test
public void destPrefixRange_emptyListLessSpecific() public void destPrefixRange_emptyListLessSpecific()
throws UnknownHostException, InvalidProtocolBufferException { throws UnknownHostException, InvalidProtocolBufferException {
setupChannel(LOCAL_IP, REMOTE_IP, 15000); setupChannel(LOCAL_IP, REMOTE_IP, 15000);
DownstreamTlsContext tlsContextLessSpecific = DownstreamTlsContext tlsContextLessSpecific =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1");
@ -308,7 +323,8 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainLessSpecific = EnvoyServerProtoData.FilterChain filterChainLessSpecific =
new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific,
tlsContextManager);
DownstreamTlsContext tlsContextMoreSpecific = DownstreamTlsContext tlsContextMoreSpecific =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
@ -321,9 +337,10 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainMoreSpecific = EnvoyServerProtoData.FilterChain filterChainMoreSpecific =
new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific); new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific,
tlsContextManager);
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, null); new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", "listener1",
@ -332,8 +349,7 @@ public class FilterChainMatchTest {
defaultFilterChain); defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific);
} }
@ -352,7 +368,8 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainLessSpecific = EnvoyServerProtoData.FilterChain filterChainLessSpecific =
new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific,
tlsContextManager);
DownstreamTlsContext tlsContextMoreSpecific = DownstreamTlsContext tlsContextMoreSpecific =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
@ -365,9 +382,10 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainMoreSpecific = EnvoyServerProtoData.FilterChain filterChainMoreSpecific =
new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific); new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific,
tlsContextManager);
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, null); new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", "listener1",
@ -376,8 +394,7 @@ public class FilterChainMatchTest {
defaultFilterChain); defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific);
} }
@ -399,7 +416,7 @@ public class FilterChainMatchTest {
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 =
new EnvoyServerProtoData.FilterChain( new EnvoyServerProtoData.FilterChain(
filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2); filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2, tlsContextManager);
DownstreamTlsContext tlsContextLessSpecific = DownstreamTlsContext tlsContextLessSpecific =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
@ -412,9 +429,10 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainLessSpecific = EnvoyServerProtoData.FilterChain filterChainLessSpecific =
new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific,
tlsContextManager);
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, null); new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", "listener1",
@ -423,8 +441,7 @@ public class FilterChainMatchTest {
defaultFilterChain); defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2);
} }
@ -442,18 +459,19 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainWithMismatch = EnvoyServerProtoData.FilterChain filterChainWithMismatch =
new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch); new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch,
tlsContextManager);
DownstreamTlsContext tlsContextForDefaultFilterChain = DownstreamTlsContext tlsContextForDefaultFilterChain =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain,
tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain); "listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain);
} }
@ -471,18 +489,19 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainWithMatch = EnvoyServerProtoData.FilterChain filterChainWithMatch =
new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch); new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch,
tlsContextManager);
DownstreamTlsContext tlsContextForDefaultFilterChain = DownstreamTlsContext tlsContextForDefaultFilterChain =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain,
tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain); "listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch); assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch);
} }
@ -504,7 +523,7 @@ public class FilterChainMatchTest {
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 =
new EnvoyServerProtoData.FilterChain( new EnvoyServerProtoData.FilterChain(
filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2); filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2, tlsContextManager);
DownstreamTlsContext tlsContextLessSpecific = DownstreamTlsContext tlsContextLessSpecific =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
@ -517,9 +536,10 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainLessSpecific = EnvoyServerProtoData.FilterChain filterChainLessSpecific =
new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific,
tlsContextManager);
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, null); new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", "listener1",
@ -528,14 +548,13 @@ public class FilterChainMatchTest {
defaultFilterChain); defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2);
} }
@Test @Test
public void sourcePrefixRange_2Matchers_expectException() public void sourcePrefixRange_2Matchers_expectException()
throws UnknownHostException, InvalidProtocolBufferException { throws UnknownHostException, InvalidProtocolBufferException {
setupChannel(LOCAL_IP, REMOTE_IP, 15000); setupChannel(LOCAL_IP, REMOTE_IP, 15000);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1");
@ -550,7 +569,7 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChain1 = EnvoyServerProtoData.FilterChain filterChain1 =
new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1); new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1, tlsContextManager);
DownstreamTlsContext tlsContext2 = DownstreamTlsContext tlsContext2 =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
@ -563,16 +582,16 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChain2 = EnvoyServerProtoData.FilterChain filterChain2 =
new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2); new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2, tlsContextManager);
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, null); new EnvoyServerProtoData.FilterChain(null, null, null);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", LOCAL_IP, Arrays.asList(filterChain1, filterChain2), defaultFilterChain); "listener1", LOCAL_IP, Arrays.asList(filterChain1, filterChain2), defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
try { try {
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel);
fail("expect exception!"); fail("expect exception!");
} catch (IllegalStateException ise) { } catch (IllegalStateException ise) {
assertThat(ise).hasMessageThat().isEqualTo("Found 2 matching filter-chains"); assertThat(ise).hasMessageThat().isEqualTo("Found 2 matching filter-chains");
@ -597,7 +616,7 @@ public class FilterChainMatchTest {
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChainEmptySourcePorts = EnvoyServerProtoData.FilterChain filterChainEmptySourcePorts =
new EnvoyServerProtoData.FilterChain( new EnvoyServerProtoData.FilterChain(
filterChainMatchEmptySourcePorts, tlsContextEmptySourcePorts); filterChainMatchEmptySourcePorts, tlsContextEmptySourcePorts, tlsContextManager);
DownstreamTlsContext tlsContextSourcePortMatch = DownstreamTlsContext tlsContextSourcePortMatch =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
@ -611,9 +630,9 @@ public class FilterChainMatchTest {
Arrays.asList(7000, 15000)); Arrays.asList(7000, 15000));
EnvoyServerProtoData.FilterChain filterChainSourcePortMatch = EnvoyServerProtoData.FilterChain filterChainSourcePortMatch =
new EnvoyServerProtoData.FilterChain( new EnvoyServerProtoData.FilterChain(
filterChainMatchSourcePortMatch, tlsContextSourcePortMatch); filterChainMatchSourcePortMatch, tlsContextSourcePortMatch, tlsContextManager);
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, null); new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", "listener1",
@ -622,8 +641,7 @@ public class FilterChainMatchTest {
defaultFilterChain); defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext1 = DownstreamTlsContext tlsContext1 = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContext1).isSameInstanceAs(tlsContextSourcePortMatch); assertThat(tlsContext1).isSameInstanceAs(tlsContextSourcePortMatch);
} }
@ -660,7 +678,7 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChain1 = EnvoyServerProtoData.FilterChain filterChain1 =
new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1); new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1, tlsContextManager);
// next 5 use prefix range: 4 with prefixLen of 30 and last one with 29 // next 5 use prefix range: 4 with prefixLen of 30 and last one with 29
@ -674,7 +692,7 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChain2 = EnvoyServerProtoData.FilterChain filterChain2 =
new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2); new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2, tlsContextManager);
// has prefix ranges with one not matching and source type local: gets eliminated in step 3 // has prefix ranges with one not matching and source type local: gets eliminated in step 3
EnvoyServerProtoData.FilterChainMatch filterChainMatch3 = EnvoyServerProtoData.FilterChainMatch filterChainMatch3 =
@ -688,7 +706,7 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChain3 = EnvoyServerProtoData.FilterChain filterChain3 =
new EnvoyServerProtoData.FilterChain(filterChainMatch3, tlsContext3); new EnvoyServerProtoData.FilterChain(filterChainMatch3, tlsContext3, tlsContextManager);
// has prefix ranges with both matching and source type external but non matching source port: // has prefix ranges with both matching and source type external but non matching source port:
// gets eliminated in step 5 // gets eliminated in step 5
@ -703,7 +721,7 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.EXTERNAL, EnvoyServerProtoData.ConnectionSourceType.EXTERNAL,
Arrays.asList(16000, 9000)); Arrays.asList(16000, 9000));
EnvoyServerProtoData.FilterChain filterChain4 = EnvoyServerProtoData.FilterChain filterChain4 =
new EnvoyServerProtoData.FilterChain(filterChainMatch4, tlsContext4); new EnvoyServerProtoData.FilterChain(filterChainMatch4, tlsContext4, tlsContextManager);
// has prefix ranges with both matching and source type external and matching source port: this // has prefix ranges with both matching and source type external and matching source port: this
// gets selected // gets selected
@ -720,7 +738,7 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.asList(15000, 8000)); Arrays.asList(15000, 8000));
EnvoyServerProtoData.FilterChain filterChain5 = EnvoyServerProtoData.FilterChain filterChain5 =
new EnvoyServerProtoData.FilterChain(filterChainMatch5, tlsContext5); new EnvoyServerProtoData.FilterChain(filterChainMatch5, tlsContext5, tlsContextManager);
// has prefix range with prefixLen of 29: gets eliminated in step 2 // has prefix range with prefixLen of 29: gets eliminated in step 2
EnvoyServerProtoData.FilterChainMatch filterChainMatch6 = EnvoyServerProtoData.FilterChainMatch filterChainMatch6 =
@ -732,10 +750,10 @@ public class FilterChainMatchTest {
EnvoyServerProtoData.ConnectionSourceType.ANY, EnvoyServerProtoData.ConnectionSourceType.ANY,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain filterChain6 = EnvoyServerProtoData.FilterChain filterChain6 =
new EnvoyServerProtoData.FilterChain(filterChainMatch6, tlsContext6); new EnvoyServerProtoData.FilterChain(filterChainMatch6, tlsContext6, tlsContextManager);
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, null); new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
"listener1", "listener1",
@ -745,8 +763,7 @@ public class FilterChainMatchTest {
defaultFilterChain); defaultFilterChain);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContextPicked = DownstreamTlsContext tlsContextPicked = getDownstreamTlsContext();
xdsClientWrapperForServerSds.getDownstreamTlsContext(channel);
assertThat(tlsContextPicked).isSameInstanceAs(tlsContext5); assertThat(tlsContextPicked).isSameInstanceAs(tlsContext5);
} }

View File

@ -66,13 +66,15 @@ public class ServerWrapperForXdsTest {
private XdsServerBuilder.XdsServingStatusListener mockXdsServingStatusListener; private XdsServerBuilder.XdsServingStatusListener mockXdsServingStatusListener;
private XdsClient.LdsResourceWatcher listenerWatcher; private XdsClient.LdsResourceWatcher listenerWatcher;
private Server mockServer; private Server mockServer;
private TlsContextManager tlsContextManager;
@Before @Before
public void setUp() throws IOException { public void setUp() throws IOException {
port = XdsServerTestHelper.findFreePort(); port = XdsServerTestHelper.findFreePort();
mockDelegateBuilder = mock(ServerBuilder.class); mockDelegateBuilder = mock(ServerBuilder.class);
tlsContextManager = mock(TlsContextManager.class);
xdsClientWrapperForServerSds = XdsServerTestHelper xdsClientWrapperForServerSds = XdsServerTestHelper
.createXdsClientWrapperForServerSds(port, null); .createXdsClientWrapperForServerSds(port, tlsContextManager);
mockXdsServingStatusListener = mock(XdsServerBuilder.XdsServingStatusListener.class); mockXdsServingStatusListener = mock(XdsServerBuilder.XdsServingStatusListener.class);
listenerWatcher = listenerWatcher =
XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds);
@ -117,8 +119,8 @@ public class ServerWrapperForXdsTest {
verifyCapturedCodeAndNotServing(Status.Code.ABORTED, ServerWrapperForXds.ServingState.STARTING); verifyCapturedCodeAndNotServing(Status.Code.ABORTED, ServerWrapperForXds.ServingState.STARTING);
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"),
); tlsContextManager);
Throwable exception = future.get(2, TimeUnit.SECONDS); Throwable exception = future.get(2, TimeUnit.SECONDS);
assertThat(exception).isNull(); assertThat(exception).isNull();
assertThat(serverWrapperForXds.getCurrentServingState()) assertThat(serverWrapperForXds.getCurrentServingState())
@ -163,8 +165,8 @@ public class ServerWrapperForXdsTest {
public void run() { public void run() {
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"),
); tlsContextManager);
} }
}).start(); }).start();
assertThat(settableFutureToSignalStart.get()).isNull(); assertThat(settableFutureToSignalStart.get()).isNull();
@ -197,9 +199,9 @@ public class ServerWrapperForXdsTest {
@Override @Override
public void run() { public void run() {
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"),
); tlsContextManager);
} }
}).start(); }).start();
Throwable exception = future.get(2, TimeUnit.SECONDS); Throwable exception = future.get(2, TimeUnit.SECONDS);
@ -242,8 +244,8 @@ public class ServerWrapperForXdsTest {
Future<Throwable> future = startServerAsync(); Future<Throwable> future = startServerAsync();
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"),
); tlsContextManager);
Throwable exception = future.get(2, TimeUnit.SECONDS); Throwable exception = future.get(2, TimeUnit.SECONDS);
assertThat(exception).isNull(); assertThat(exception).isNull();
assertThat(serverWrapperForXds.getCurrentServingState()) assertThat(serverWrapperForXds.getCurrentServingState())
@ -256,8 +258,8 @@ public class ServerWrapperForXdsTest {
when(mockDelegateBuilder.build()).thenReturn(mockServer); when(mockDelegateBuilder.build()).thenReturn(mockServer);
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"),
); tlsContextManager);
Thread.sleep(100L); Thread.sleep(100L);
assertThat(serverWrapperForXds.getCurrentServingState()) assertThat(serverWrapperForXds.getCurrentServingState())
.isEqualTo(ServerWrapperForXds.ServingState.SHUTDOWN); .isEqualTo(ServerWrapperForXds.ServingState.SHUTDOWN);
@ -269,8 +271,8 @@ public class ServerWrapperForXdsTest {
Future<Throwable> future = startServerAsync(); Future<Throwable> future = startServerAsync();
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"),
); tlsContextManager);
Throwable exception = future.get(2, TimeUnit.SECONDS); Throwable exception = future.get(2, TimeUnit.SECONDS);
assertThat(exception).isNull(); assertThat(exception).isNull();
assertThat(serverWrapperForXds.getCurrentServingState()) assertThat(serverWrapperForXds.getCurrentServingState())
@ -302,8 +304,8 @@ public class ServerWrapperForXdsTest {
public void run() { public void run() {
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"),
); tlsContextManager);
} }
}).start(); }).start();
assertThat(settableFutureToSignalStart.get()).isNull(); assertThat(settableFutureToSignalStart.get()).isNull();

View File

@ -19,9 +19,11 @@ package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -30,12 +32,15 @@ import io.grpc.StatusException;
import io.grpc.inprocess.InProcessSocketAddress; import io.grpc.inprocess.InProcessSocketAddress;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import io.grpc.xds.internal.sds.SslContextProvider;
import io.grpc.xds.internal.sds.SslContextProviderSupplier;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import java.io.IOException; import java.io.IOException;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -53,15 +58,27 @@ public class XdsClientWrapperForServerSdsTestMisc {
private static final int PORT = 7000; private static final int PORT = 7000;
@Mock private Channel channel; @Mock private Channel channel;
@Mock private TlsContextManager tlsContextManager;
@Mock private XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher;
private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; private XdsClientWrapperForServerSds xdsClientWrapperForServerSds;
private XdsClient.LdsResourceWatcher registeredWatcher; private XdsClient.LdsResourceWatcher registeredWatcher;
private InetSocketAddress localAddress;
private DownstreamTlsContext tlsContext1;
private DownstreamTlsContext tlsContext2;
private DownstreamTlsContext tlsContext3;
@Before @Before
public void setUp() throws IOException { public void setUp() throws IOException {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
tlsContext1 =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1");
tlsContext2 =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2");
tlsContext3 =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3");
xdsClientWrapperForServerSds = XdsServerTestHelper xdsClientWrapperForServerSds = XdsServerTestHelper
.createXdsClientWrapperForServerSds(PORT, null); .createXdsClientWrapperForServerSds(PORT, tlsContextManager);
} }
@After @After
@ -73,7 +90,9 @@ public class XdsClientWrapperForServerSdsTestMisc {
public void nonInetSocketAddress_expectNull() throws UnknownHostException { public void nonInetSocketAddress_expectNull() throws UnknownHostException {
registeredWatcher = registeredWatcher =
XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds);
assertThat(sendListenerUpdate(new InProcessSocketAddress("test1"), null)).isNull(); assertThat(
sendListenerUpdate(new InProcessSocketAddress("test1"), null, null, tlsContextManager))
.isNull();
} }
@Test @Test
@ -83,7 +102,7 @@ public class XdsClientWrapperForServerSdsTestMisc {
try { try {
InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3");
InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT + 1); InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT + 1);
DownstreamTlsContext unused = sendListenerUpdate(localAddress, null); sendListenerUpdate(localAddress, null, null, tlsContextManager);
fail("exception expected"); fail("exception expected");
} catch (IllegalStateException expected) { } catch (IllegalStateException expected) {
assertThat(expected) assertThat(expected)
@ -114,86 +133,170 @@ public class XdsClientWrapperForServerSdsTestMisc {
null); null);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
DownstreamTlsContext tlsContext = xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); DownstreamTlsContext tlsContext = getDownstreamTlsContext();
assertThat(tlsContext).isNull(); assertThat(tlsContext).isNull();
} }
@Test
public void registerServerWatcher() throws UnknownHostException {
registeredWatcher =
XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds);
XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher =
mock(XdsClientWrapperForServerSds.ServerWatcher.class);
xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher);
InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3");
InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT);
EnvoyServerProtoData.DownstreamTlsContext tlsContext =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1");
verify(mockServerWatcher, never())
.onListenerUpdate();
DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext);
assertThat(returnedTlsContext).isSameInstanceAs(tlsContext);
verify(mockServerWatcher).onListenerUpdate();
xdsClientWrapperForServerSds.removeServerWatcher(mockServerWatcher);
}
@Test @Test
public void registerServerWatcher_afterListenerUpdate() throws UnknownHostException { public void registerServerWatcher_afterListenerUpdate() throws UnknownHostException {
registeredWatcher = registerWatcherAndCreateListenerUpdate(tlsContext1);
XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds);
InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3");
InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT);
EnvoyServerProtoData.DownstreamTlsContext tlsContext =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1");
DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext);
assertThat(returnedTlsContext).isSameInstanceAs(tlsContext);
XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher =
mock(XdsClientWrapperForServerSds.ServerWatcher.class);
xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher);
verify(mockServerWatcher).onListenerUpdate(); verify(mockServerWatcher).onListenerUpdate();
} }
@Test @Test
public void registerServerWatcher_notifyError() throws UnknownHostException { public void registerServerWatcher_notifyNotFound() throws UnknownHostException {
registeredWatcher = commonErrorCheck(true, Status.NOT_FOUND, true);
XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); }
XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher =
mock(XdsClientWrapperForServerSds.ServerWatcher.class); @Test
xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher); public void registerServerWatcher_notifyInternalError() throws UnknownHostException {
commonErrorCheck(false, Status.INTERNAL, false);
}
@Test
public void registerServerWatcher_notifyPermDeniedError() throws UnknownHostException {
commonErrorCheck(false, Status.PERMISSION_DENIED, true);
}
@Test
public void releaseOldSupplierOnChanged_noCloseDueToLazyLoading() throws UnknownHostException {
registerWatcherAndCreateListenerUpdate(tlsContext1);
XdsServerTestHelper.generateListenerUpdate(registeredWatcher, tlsContext2, tlsContextManager);
verify(tlsContextManager, never())
.findOrCreateServerSslContextProvider(any(DownstreamTlsContext.class));
}
@Test
public void releaseOldSupplierOnChangedOnShutdown_verifyClose() throws UnknownHostException {
SslContextProvider sslContextProvider1 = mock(SslContextProvider.class);
when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1)))
.thenReturn(sslContextProvider1);
registerWatcherAndCreateListenerUpdate(tlsContext1);
callUpdateSslContext(channel);
XdsServerTestHelper
.generateListenerUpdate(registeredWatcher, Arrays.<Integer>asList(1234), tlsContext2,
tlsContext3, tlsContextManager);
verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1));
reset(tlsContextManager);
SslContextProvider sslContextProvider2 = mock(SslContextProvider.class);
when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext2)))
.thenReturn(sslContextProvider2);
SslContextProvider sslContextProvider3 = mock(SslContextProvider.class);
when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext3)))
.thenReturn(sslContextProvider3);
callUpdateSslContext(channel);
InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6");
InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1111);
when(channel.remoteAddress()).thenReturn(remoteAddress);
callUpdateSslContext(channel);
XdsClient mockXdsClient = xdsClientWrapperForServerSds.getXdsClient();
xdsClientWrapperForServerSds.shutdown();
verify(mockXdsClient, times(1))
.cancelLdsResourceWatch(eq("grpc/server?udpa.resource.listening_address=0.0.0.0:" + PORT),
eq(registeredWatcher));
verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1));
verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider2));
verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider3));
}
@Test
public void releaseOldSupplierOnNotFound_verifyClose() throws UnknownHostException {
SslContextProvider sslContextProvider1 = mock(SslContextProvider.class);
when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1)))
.thenReturn(sslContextProvider1);
registerWatcherAndCreateListenerUpdate(tlsContext1);
callUpdateSslContext(channel);
registeredWatcher.onResourceDoesNotExist("not-found Error");
verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1));
}
@Test
public void releaseOldSupplierOnPermDeniedError_verifyClose() throws UnknownHostException {
SslContextProvider sslContextProvider1 = mock(SslContextProvider.class);
when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1)))
.thenReturn(sslContextProvider1);
registerWatcherAndCreateListenerUpdate(tlsContext1);
callUpdateSslContext(channel);
registeredWatcher.onError(Status.PERMISSION_DENIED);
verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1));
}
@Test
public void releaseOldSupplierOnInternalError_noClose() throws UnknownHostException {
SslContextProvider sslContextProvider1 = mock(SslContextProvider.class);
when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1)))
.thenReturn(sslContextProvider1);
registerWatcherAndCreateListenerUpdate(tlsContext1);
callUpdateSslContext(channel);
registeredWatcher.onError(Status.INTERNAL); registeredWatcher.onError(Status.INTERNAL);
verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1));
}
private void callUpdateSslContext(Channel channel) {
SslContextProviderSupplier sslContextProviderSupplier =
xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel);
assertThat(sslContextProviderSupplier).isNotNull();
SslContextProvider.Callback callback = mock(SslContextProvider.Callback.class);
sslContextProviderSupplier.updateSslContext(callback);
}
private void registerWatcherAndCreateListenerUpdate(DownstreamTlsContext tlsContext)
throws UnknownHostException {
registeredWatcher =
XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds);
InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3");
localAddress = new InetSocketAddress(ipLocalAddress, PORT);
xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher);
DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext, null,
tlsContextManager);
assertThat(returnedTlsContext).isSameInstanceAs(tlsContext);
}
private void commonErrorCheck(boolean generateResourceDoesNotExist, Status status,
boolean isAbsent) throws UnknownHostException {
registerWatcherAndCreateListenerUpdate(tlsContext1);
reset(mockServerWatcher);
if (generateResourceDoesNotExist) {
registeredWatcher.onResourceDoesNotExist("not-found Error");
} else {
registeredWatcher.onError(status);
}
ArgumentCaptor<Throwable> argCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor<Throwable> argCaptor = ArgumentCaptor.forClass(null);
verify(mockServerWatcher).onError(argCaptor.capture(), eq(false)); verify(mockServerWatcher).onError(argCaptor.capture(), eq(isAbsent));
Throwable throwable = argCaptor.getValue(); Throwable throwable = argCaptor.getValue();
assertThat(throwable).isInstanceOf(StatusException.class); assertThat(throwable).isInstanceOf(StatusException.class);
Status captured = ((StatusException)throwable).getStatus(); Status captured = ((StatusException) throwable).getStatus();
assertThat(captured.getCode()).isEqualTo(Status.Code.INTERNAL); assertThat(captured.getCode()).isEqualTo(status.getCode());
reset(mockServerWatcher); if (isAbsent) {
registeredWatcher.onResourceDoesNotExist("not-found Error"); assertThat(xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel)).isNull();
ArgumentCaptor<Throwable> argCaptor1 = ArgumentCaptor.forClass(null); } else {
verify(mockServerWatcher).onError(argCaptor1.capture(), eq(true)); assertThat(xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel)).isNotNull();
throwable = argCaptor1.getValue(); }
assertThat(throwable).isInstanceOf(StatusException.class);
captured = ((StatusException)throwable).getStatus();
assertThat(captured.getCode()).isEqualTo(Status.Code.NOT_FOUND);
InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3");
InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT);
EnvoyServerProtoData.DownstreamTlsContext tlsContext =
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1");
verify(mockServerWatcher, never())
.onListenerUpdate();
DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext);
assertThat(returnedTlsContext).isSameInstanceAs(tlsContext);
verify(mockServerWatcher).onListenerUpdate();
} }
private DownstreamTlsContext sendListenerUpdate( private DownstreamTlsContext sendListenerUpdate(
SocketAddress localAddress, DownstreamTlsContext tlsContext) throws UnknownHostException { SocketAddress localAddress, DownstreamTlsContext tlsContext,
DownstreamTlsContext tlsContextForDefaultFilterChain, TlsContextManager tlsContextManager)
throws UnknownHostException {
when(channel.localAddress()).thenReturn(localAddress); when(channel.localAddress()).thenReturn(localAddress);
InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6");
InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1234); InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1234);
when(channel.remoteAddress()).thenReturn(remoteAddress); when(channel.remoteAddress()).thenReturn(remoteAddress);
XdsServerTestHelper.generateListenerUpdate(registeredWatcher, tlsContext); XdsServerTestHelper
return xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); .generateListenerUpdate(registeredWatcher, Arrays.<Integer>asList(), tlsContext,
tlsContextForDefaultFilterChain, tlsContextManager);
return getDownstreamTlsContext();
}
private DownstreamTlsContext getDownstreamTlsContext() {
SslContextProviderSupplier sslContextProviderSupplier =
xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel);
if (sslContextProviderSupplier != null) {
EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext();
assertThat(tlsContext).isInstanceOf(DownstreamTlsContext.class);
return (DownstreamTlsContext)tlsContext;
}
return null;
} }
/** Creates XdsClientWrapperForServerSds: also used by other classes. */ /** Creates XdsClientWrapperForServerSds: also used by other classes. */
@ -203,7 +306,7 @@ public class XdsClientWrapperForServerSdsTestMisc {
XdsServerTestHelper.createXdsClientWrapperForServerSds(port, tlsContextManager); XdsServerTestHelper.createXdsClientWrapperForServerSds(port, tlsContextManager);
xdsClientWrapperForServerSds.start(); xdsClientWrapperForServerSds.start();
XdsSdsClientServerTest.generateListenerUpdateToWatcher( XdsSdsClientServerTest.generateListenerUpdateToWatcher(
downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher()); downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher(), tlsContextManager);
return xdsClientWrapperForServerSds; return xdsClientWrapperForServerSds;
} }
} }

View File

@ -113,19 +113,6 @@ public class XdsSdsClientServerTest {
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
} }
@Test
public void plaintextClientServer_withDefaultTlsContext() throws IOException, URISyntaxException {
DownstreamTlsContext defaultTlsContext =
EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext(
io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext
.getDefaultInstance());
buildServerWithTlsContext(/* downstreamTlsContext= */ defaultTlsContext);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(/* upstreamTlsContext= */ null, /* overrideAuthority= */ null);
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
}
@Test @Test
public void nullFallbackCredentials_expectException() throws IOException, URISyntaxException { public void nullFallbackCredentials_expectException() throws IOException, URISyntaxException {
try { try {
@ -289,7 +276,7 @@ public class XdsSdsClientServerTest {
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE); BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE);
generateListenerUpdateToWatcher(downstreamTlsContext, listenerWatcher); generateListenerUpdateToWatcher(downstreamTlsContext, listenerWatcher, tlsContextManager);
try { try {
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); getBlockingStub(upstreamTlsContext, "foo.test.google.fr");
@ -356,8 +343,10 @@ public class XdsSdsClientServerTest {
} }
static void generateListenerUpdateToWatcher( static void generateListenerUpdateToWatcher(
DownstreamTlsContext tlsContext, XdsClient.LdsResourceWatcher registeredWatcher) { DownstreamTlsContext tlsContext, XdsClient.LdsResourceWatcher registeredWatcher,
EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", tlsContext); TlsContextManager tlsContextManager) {
EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", tlsContext,
tlsContextManager);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
} }
@ -371,12 +360,13 @@ public class XdsSdsClientServerTest {
XdsServerBuilder builder = XdsServerBuilder.forPort(port, serverCredentials) XdsServerBuilder builder = XdsServerBuilder.forPort(port, serverCredentials)
.addService(new SimpleServiceImpl()); .addService(new SimpleServiceImpl());
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
xdsClientWrapperForServerSds.getListenerWatcher(), downstreamTlsContext); xdsClientWrapperForServerSds.getListenerWatcher(), downstreamTlsContext, tlsContextManager);
cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)).start(); cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)).start();
} }
static EnvoyServerProtoData.Listener buildListener( static EnvoyServerProtoData.Listener buildListener(
String name, String address, DownstreamTlsContext tlsContext) { String name, String address, DownstreamTlsContext tlsContext,
TlsContextManager tlsContextManager) {
EnvoyServerProtoData.FilterChainMatch filterChainMatch = EnvoyServerProtoData.FilterChainMatch filterChainMatch =
new EnvoyServerProtoData.FilterChainMatch( new EnvoyServerProtoData.FilterChainMatch(
0, 0,
@ -386,7 +376,7 @@ public class XdsSdsClientServerTest {
null, null,
Arrays.<Integer>asList()); Arrays.<Integer>asList());
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext); new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext, tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener(name, address, Arrays.asList(defaultFilterChain), null); new EnvoyServerProtoData.Listener(name, address, Arrays.asList(defaultFilterChain), null);
return listener; return listener;

View File

@ -63,6 +63,7 @@ public class XdsServerBuilderTest {
private XdsClient.LdsResourceWatcher listenerWatcher; private XdsClient.LdsResourceWatcher listenerWatcher;
private int port; private int port;
private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; private XdsClientWrapperForServerSds xdsClientWrapperForServerSds;
private TlsContextManager tlsContextManager;
private void buildServer(XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener) private void buildServer(XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener)
throws IOException { throws IOException {
@ -79,8 +80,9 @@ public class XdsServerBuilderTest {
if (xdsServingStatusListener != null) { if (xdsServingStatusListener != null) {
builder = builder.xdsServingStatusListener(xdsServingStatusListener); builder = builder.xdsServingStatusListener(xdsServingStatusListener);
} }
tlsContextManager = mock(TlsContextManager.class);
xdsClientWrapperForServerSds = XdsServerTestHelper xdsClientWrapperForServerSds = XdsServerTestHelper
.createXdsClientWrapperForServerSds(port, null); .createXdsClientWrapperForServerSds(port, tlsContextManager);
listenerWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); listenerWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds);
} }
@ -150,8 +152,8 @@ public class XdsServerBuilderTest {
Future<Throwable> future = startServerAsync(); Future<Throwable> future = startServerAsync();
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"),
); tlsContextManager);
verifyServer(future, null, null); verifyServer(future, null, null);
verifyShutdown(); verifyShutdown();
} }
@ -162,8 +164,8 @@ public class XdsServerBuilderTest {
buildServer(null); buildServer(null);
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"),
); tlsContextManager);
xdsServer.start(); xdsServer.start();
try { try {
xdsServer.start(); xdsServer.start();
@ -183,8 +185,8 @@ public class XdsServerBuilderTest {
Future<Throwable> future = startServerAsync(); Future<Throwable> future = startServerAsync();
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"),
); tlsContextManager);
verifyServer(future, mockXdsServingStatusListener, null); verifyServer(future, mockXdsServingStatusListener, null);
} }
@ -224,8 +226,8 @@ public class XdsServerBuilderTest {
reset(mockXdsServingStatusListener); reset(mockXdsServingStatusListener);
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"),
); tlsContextManager);
verifyServer(future, mockXdsServingStatusListener, null); verifyServer(future, mockXdsServingStatusListener, null);
} }
@ -240,8 +242,8 @@ public class XdsServerBuilderTest {
ServerSocket serverSocket = new ServerSocket(port); ServerSocket serverSocket = new ServerSocket(port);
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"),
); tlsContextManager);
Throwable exception = future.get(5, TimeUnit.SECONDS); Throwable exception = future.get(5, TimeUnit.SECONDS);
assertThat(exception).isInstanceOf(IOException.class); assertThat(exception).isInstanceOf(IOException.class);
assertThat(exception).hasMessageThat().contains("Failed to bind"); assertThat(exception).hasMessageThat().contains("Failed to bind");
@ -258,12 +260,12 @@ public class XdsServerBuilderTest {
Future<Throwable> future = startServerAsync(); Future<Throwable> future = startServerAsync();
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"),
); tlsContextManager);
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
listenerWatcher, listenerWatcher,
CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"),
); tlsContextManager);
verify(mockXdsServingStatusListener, never()).onNotServing(any(Throwable.class)); verify(mockXdsServingStatusListener, never()).onNotServing(any(Throwable.class));
verifyServer(future, mockXdsServingStatusListener, null); verifyServer(future, mockXdsServingStatusListener, null);
listenerWatcher.onError(Status.ABORTED); listenerWatcher.onError(Status.ABORTED);

View File

@ -26,6 +26,7 @@ import io.grpc.internal.ObjectPool;
import java.io.IOException; import java.io.IOException;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.Map; import java.util.Map;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
@ -112,14 +113,25 @@ class XdsServerTestHelper {
* Creates a {@link XdsClient.LdsUpdate} with {@link * Creates a {@link XdsClient.LdsUpdate} with {@link
* io.grpc.xds.EnvoyServerProtoData.FilterChain} with a destination port and an optional {@link * io.grpc.xds.EnvoyServerProtoData.FilterChain} with a destination port and an optional {@link
* EnvoyServerProtoData.DownstreamTlsContext}. * EnvoyServerProtoData.DownstreamTlsContext}.
*
* @param registeredWatcher the watcher on which to generate the update * @param registeredWatcher the watcher on which to generate the update
* @param tlsContext if non-null, used to populate filterChain * @param tlsContext if non-null, used to populate filterChain
*/ */
static void generateListenerUpdate( static void generateListenerUpdate(
XdsClient.LdsResourceWatcher registeredWatcher, XdsClient.LdsResourceWatcher registeredWatcher,
EnvoyServerProtoData.DownstreamTlsContext tlsContext) { EnvoyServerProtoData.DownstreamTlsContext tlsContext, TlsContextManager tlsContextManager) {
EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", tlsContext); EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3",
Arrays.<Integer>asList(), tlsContext, null, tlsContextManager);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate);
}
static void generateListenerUpdate(
XdsClient.LdsResourceWatcher registeredWatcher, List<Integer> sourcePorts,
EnvoyServerProtoData.DownstreamTlsContext tlsContext,
EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain,
TlsContextManager tlsContextManager) {
EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", sourcePorts,
tlsContext, tlsContextForDefaultFilterChain, tlsContextManager);
XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener);
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
} }
@ -132,7 +144,10 @@ class XdsServerTestHelper {
} }
static EnvoyServerProtoData.Listener buildTestListener( static EnvoyServerProtoData.Listener buildTestListener(
String name, String address, EnvoyServerProtoData.DownstreamTlsContext tlsContext) { String name, String address, List<Integer> sourcePorts,
EnvoyServerProtoData.DownstreamTlsContext tlsContext,
EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain,
TlsContextManager tlsContextManager) {
EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = EnvoyServerProtoData.FilterChainMatch filterChainMatch1 =
new EnvoyServerProtoData.FilterChainMatch( new EnvoyServerProtoData.FilterChainMatch(
0, 0,
@ -140,11 +155,12 @@ class XdsServerTestHelper {
Arrays.<String>asList(), Arrays.<String>asList(),
Arrays.<EnvoyServerProtoData.CidrRange>asList(), Arrays.<EnvoyServerProtoData.CidrRange>asList(),
null, null,
Arrays.<Integer>asList()); sourcePorts);
EnvoyServerProtoData.FilterChain filterChain1 = EnvoyServerProtoData.FilterChain filterChain1 =
new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext); new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext, tlsContextManager);
EnvoyServerProtoData.FilterChain defaultFilterChain = EnvoyServerProtoData.FilterChain defaultFilterChain =
new EnvoyServerProtoData.FilterChain(null, null); new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain,
tlsContextManager);
EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener( new EnvoyServerProtoData.Listener(
name, address, Arrays.asList(filterChain1), defaultFilterChain); name, address, Arrays.asList(filterChain1), defaultFilterChain);

View File

@ -297,7 +297,7 @@ public class SdsProtocolNegotiatorsTest {
XdsClientWrapperForServerSds xdsClientWrapperForServerSds = XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
XdsClientWrapperForServerSdsTestMisc.createXdsClientWrapperForServerSds( XdsClientWrapperForServerSdsTestMisc.createXdsClientWrapperForServerSds(
80, downstreamTlsContext, null); 80, downstreamTlsContext, mock(TlsContextManager.class));
SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler =
new SdsProtocolNegotiators.HandlerPickerHandler( new SdsProtocolNegotiators.HandlerPickerHandler(
grpcHandler, xdsClientWrapperForServerSds, mockProtocolNegotiator); grpcHandler, xdsClientWrapperForServerSds, mockProtocolNegotiator);

View File

@ -26,9 +26,11 @@ import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.EnvoyServerProtoData;
import io.grpc.xds.TlsContextManager; import io.grpc.xds.TlsContextManager;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
@ -119,30 +121,44 @@ public class SslContextProviderSupplierTest {
callUpdateSslContext(); callUpdateSslContext();
supplier.close(); supplier.close();
verify(mockTlsContextManager, times(1)) verify(mockTlsContextManager, times(1))
.releaseClientSslContextProvider(eq(mockSslContextProvider)); .releaseClientSslContextProvider(eq(mockSslContextProvider));
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); SslContextProvider.Callback mockCallback = spy(
try { new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
supplier.updateSslContext(mockCallback); @Override
Assert.fail("no exception thrown"); public void updateSecret(SslContext sslContext) {
} catch (IllegalStateException expected) { Assert.fail("unexpected call");
assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!"); }
}
@Override
protected void onException(Throwable argument) {
assertThat(argument).isInstanceOf(IllegalStateException.class);
assertThat(argument).hasMessageThat().contains("Supplier is shutdown!");
}
});
supplier.updateSslContext(mockCallback);
} }
@Test @Test
public void testClose_nullSslContextProvider() { public void testClose_nullSslContextProvider() {
prepareSupplier(); prepareSupplier();
doThrow(new NullPointerException()).when(mockTlsContextManager) doThrow(new NullPointerException()).when(mockTlsContextManager)
.releaseClientSslContextProvider(null); .releaseClientSslContextProvider(null);
supplier.close(); supplier.close();
verify(mockTlsContextManager, never()) verify(mockTlsContextManager, never())
.releaseClientSslContextProvider(eq(mockSslContextProvider)); .releaseClientSslContextProvider(eq(mockSslContextProvider));
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); SslContextProvider.Callback mockCallback = spy(
try { new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
supplier.updateSslContext(mockCallback); @Override
Assert.fail("no exception thrown"); public void updateSecret(SslContext sslContext) {
} catch (IllegalStateException expected) { Assert.fail("unexpected call");
assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!"); }
}
@Override
protected void onException(Throwable argument) {
assertThat(argument).isInstanceOf(IllegalStateException.class);
assertThat(argument).hasMessageThat().contains("Supplier is shutdown!");
}
});
supplier.updateSslContext(mockCallback);
} }
} }