diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 19b8d0c1bb..69a2d4f8a7 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -287,10 +287,10 @@ final class ClientXdsClient extends AbstractXdsClient { "HttpConnectionManager neither has inlined route_config nor RDS."); } - private static LdsUpdate processServerSideListener(Listener listener) + private LdsUpdate processServerSideListener(Listener listener) throws ResourceInvalidException { StructOrError convertedListener = - parseServerSideListener(listener); + parseServerSideListener(listener, tlsContextManager); if (convertedListener.getErrorDetail() != null) { throw new ResourceInvalidException(convertedListener.getErrorDetail()); } @@ -369,10 +369,10 @@ final class ClientXdsClient extends AbstractXdsClient { } @VisibleForTesting static StructOrError parseServerSideListener( - Listener listener) { + Listener listener, TlsContextManager tlsContextManager) { try { return StructOrError.fromStruct( - EnvoyServerProtoData.Listener.fromEnvoyProtoListener(listener)); + EnvoyServerProtoData.Listener.fromEnvoyProtoListener(listener, tlsContextManager)); } catch (InvalidProtocolBufferException e) { return StructOrError.fromError( "Failed to unpack Listener " + listener.getName() + ":" + e.getMessage()); diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index 22cd597db8..5ed0557df0 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -28,6 +28,7 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3 import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.Internal; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.ArrayList; @@ -354,17 +355,21 @@ public final class EnvoyServerProtoData { // TODO(sanjaypujare): flatten structure by moving FilterChainMatch class members here. private final FilterChainMatch filterChainMatch; @Nullable - private final DownstreamTlsContext downstreamTlsContext; + private final SslContextProviderSupplier sslContextProviderSupplier; @VisibleForTesting FilterChain( - FilterChainMatch filterChainMatch, @Nullable DownstreamTlsContext downstreamTlsContext) { + FilterChainMatch filterChainMatch, @Nullable DownstreamTlsContext downstreamTlsContext, + TlsContextManager tlsContextManager) { + SslContextProviderSupplier sslContextProviderSupplier1 = downstreamTlsContext == null ? null + : new SslContextProviderSupplier(downstreamTlsContext, tlsContextManager); this.filterChainMatch = filterChainMatch; - this.downstreamTlsContext = downstreamTlsContext; + this.sslContextProviderSupplier = sslContextProviderSupplier1; } static FilterChain fromEnvoyProtoFilterChain( - io.envoyproxy.envoy.config.listener.v3.FilterChain proto, boolean isDefaultFilterChain) + io.envoyproxy.envoy.config.listener.v3.FilterChain proto, + TlsContextManager tlsContextManager, boolean isDefaultFilterChain) throws InvalidProtocolBufferException { if (!isDefaultFilterChain && proto.getFiltersList().isEmpty()) { throw new IllegalArgumentException( @@ -380,7 +385,8 @@ public final class EnvoyServerProtoData { } return new FilterChain( FilterChainMatch.fromEnvoyProtoFilterChainMatch(proto.getFilterChainMatch()), - getTlsContextFromFilterChain(proto) + getTlsContextFromFilterChain(proto), + tlsContextManager ); } @@ -456,9 +462,8 @@ public final class EnvoyServerProtoData { return filterChainMatch; } - @Nullable - public DownstreamTlsContext getDownstreamTlsContext() { - return downstreamTlsContext; + public SslContextProviderSupplier getSslContextProviderSupplier() { + return sslContextProviderSupplier; } @Override @@ -471,19 +476,19 @@ public final class EnvoyServerProtoData { } FilterChain that = (FilterChain) o; return java.util.Objects.equals(filterChainMatch, that.filterChainMatch) - && java.util.Objects.equals(downstreamTlsContext, that.downstreamTlsContext); + && java.util.Objects.equals(sslContextProviderSupplier, that.sslContextProviderSupplier); } @Override public int hashCode() { - return java.util.Objects.hash(filterChainMatch, downstreamTlsContext); + return java.util.Objects.hash(filterChainMatch, sslContextProviderSupplier); } @Override public String toString() { return "FilterChain{" + "filterChainMatch=" + filterChainMatch - + ", downstreamTlsContext=" + downstreamTlsContext + + ", sslContextProviderSupplier=" + sslContextProviderSupplier + '}'; } } @@ -524,7 +529,8 @@ public final class EnvoyServerProtoData { return null; } - static Listener fromEnvoyProtoListener(io.envoyproxy.envoy.config.listener.v3.Listener proto) + static Listener fromEnvoyProtoListener(io.envoyproxy.envoy.config.listener.v3.Listener proto, + TlsContextManager tlsContextManager) throws InvalidProtocolBufferException { if (!proto.getTrafficDirection().equals(TrafficDirection.INBOUND)) { throw new IllegalArgumentException("Listener " + proto.getName() + " is not INBOUND"); @@ -537,21 +543,25 @@ public final class EnvoyServerProtoData { throw new IllegalArgumentException( "Listener " + proto.getName() + " cannot have use_original_dst set to true"); } - List filterChains = validateAndSelectFilterChains(proto.getFilterChainsList()); + List filterChains = validateAndSelectFilterChains(proto.getFilterChainsList(), + tlsContextManager); return new Listener( proto.getName(), convertEnvoyAddressToString(proto.getAddress()), - filterChains, FilterChain.fromEnvoyProtoFilterChain(proto.getDefaultFilterChain(), true)); + filterChains, FilterChain + .fromEnvoyProtoFilterChain(proto.getDefaultFilterChain(), tlsContextManager, true)); } private static List validateAndSelectFilterChains( - List inputFilterChains) + List inputFilterChains, + TlsContextManager tlsContextManager) throws InvalidProtocolBufferException { List filterChains = new ArrayList<>(inputFilterChains.size()); for (io.envoyproxy.envoy.config.listener.v3.FilterChain filterChain : inputFilterChains) { if (isAcceptable(filterChain.getFilterChainMatch())) { - filterChains.add(FilterChain.fromEnvoyProtoFilterChain(filterChain, false)); + filterChains + .add(FilterChain.fromEnvoyProtoFilterChain(filterChain, tlsContextManager, false)); } } return filterChains; diff --git a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java index 448332af7c..275c6a288e 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java @@ -27,9 +27,9 @@ import io.grpc.Status; import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourceHolder; import io.grpc.xds.EnvoyServerProtoData.CidrRange; -import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.FilterChain; import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; import io.netty.channel.Channel; import io.netty.channel.epoll.Epoll; import io.netty.channel.epoll.EpollEventLoopGroup; @@ -74,6 +74,7 @@ public final class XdsClientWrapperForServerSds { private ScheduledExecutorService timeService; private XdsClient.LdsResourceWatcher listenerWatcher; private boolean newServerApi; + private String grpcServerResourceId; @VisibleForTesting final Set serverWatchers = new HashSet<>(); /** @@ -114,14 +115,14 @@ public final class XdsClientWrapperForServerSds { new XdsClient.LdsResourceWatcher() { @Override public void onChanged(XdsClient.LdsUpdate update) { - curListener.set(update.listener); + releaseOldSuppliers(curListener.getAndSet(update.listener)); reportSuccess(); } @Override public void onResourceDoesNotExist(String resourceName) { logger.log(Level.WARNING, "Resource {0} is unavailable", resourceName); - curListener.set(null); + releaseOldSuppliers(curListener.getAndSet(null)); reportError(Status.NOT_FOUND.asException(), true); } @@ -129,10 +130,15 @@ public final class XdsClientWrapperForServerSds { public void onError(Status error) { logger.log( Level.WARNING, "LdsResourceWatcher in XdsClientWrapperForServerSds: {0}", error); - reportError(error.asException(), isResourceAbsent(error)); + if (isResourceAbsent(error)) { + releaseOldSuppliers(curListener.getAndSet(null)); + reportError(error.asException(), true); + } else { + reportError(error.asException(), false); + } } }; - String grpcServerResourceId = xdsClient.getBootstrapInfo() + grpcServerResourceId = xdsClient.getBootstrapInfo() .getServerListenerResourceNameTemplate(); newServerApi = xdsClient.getBootstrapInfo().getServers().get(0).isUseProtocolV3(); if (newServerApi && grpcServerResourceId == null) { @@ -145,6 +151,27 @@ public final class XdsClientWrapperForServerSds { xdsClient.watchLdsResource(grpcServerResourceId, listenerWatcher); } + // go thru the old listener and release all the old SslContextProviderSupplier + private void releaseOldSuppliers(EnvoyServerProtoData.Listener oldListener) { + if (oldListener != null) { + List filterChains = oldListener.getFilterChains(); + for (FilterChain filterChain : filterChains) { + releaseSupplier(filterChain); + } + releaseSupplier(oldListener.getDefaultFilterChain()); + } + } + + private static void releaseSupplier(FilterChain filterChain) { + if (filterChain != null) { + SslContextProviderSupplier sslContextProviderSupplier = + filterChain.getSslContextProviderSupplier(); + if (sslContextProviderSupplier != null) { + sslContextProviderSupplier.close(); + } + } + } + /** Whether the throwable indicates our listener resource is absent/deleted. */ private static boolean isResourceAbsent(Status status) { Status.Code code = status.getCode(); @@ -162,10 +189,10 @@ public final class XdsClientWrapperForServerSds { /** * Locates the best matching FilterChain to the channel from the current listener and if found - * returns the DownstreamTlsContext from that FilterChain, else null. + * returns the SslContextProviderSupplier from that FilterChain, else null. */ @Nullable - public DownstreamTlsContext getDownstreamTlsContext(Channel channel) { + public SslContextProviderSupplier getSslContextProviderSupplier(Channel channel) { EnvoyServerProtoData.Listener copyListener = curListener.get(); if (copyListener != null && channel != null) { SocketAddress localAddress = channel.localAddress(); @@ -176,7 +203,7 @@ public final class XdsClientWrapperForServerSds { checkState( port == localInetAddr.getPort(), "Channel localAddress port does not match requested listener port"); - return getDownstreamTlsContext(localInetAddr, remoteInetAddr, copyListener); + return getSslContextProviderSupplier(localInetAddr, remoteInetAddr, copyListener); } } return null; @@ -185,13 +212,13 @@ public final class XdsClientWrapperForServerSds { /** * Using the logic specified at * https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/listener/listener_components.proto.html?highlight=filter%20chain#listener-filterchainmatch - * locate a matching filter and return the corresponding DownstreamTlsContext or else return one - * from default filter chain. + * locate a matching filter and return the corresponding SslContextProviderSupplier or else + * return one from default filter chain. * * @param localInetAddr dest address of the inbound connection * @param remoteInetAddr source address of the inbound connection */ - private static DownstreamTlsContext getDownstreamTlsContext( + private static SslContextProviderSupplier getSslContextProviderSupplier( InetSocketAddress localInetAddr, InetSocketAddress remoteInetAddr, EnvoyServerProtoData.Listener listener) { List filterChains = listener.getFilterChains(); @@ -207,9 +234,9 @@ public final class XdsClientWrapperForServerSds { // close the connection throw new IllegalStateException("Found 2 matching filter-chains"); } else if (filterChains.size() == 1) { - return filterChains.get(0).getDownstreamTlsContext(); + return filterChains.get(0).getSslContextProviderSupplier(); } - return listener.getDefaultFilterChain().getDownstreamTlsContext(); + return listener.getDefaultFilterChain().getSslContextProviderSupplier(); } // destination_port present => Always fail match @@ -423,8 +450,10 @@ public final class XdsClientWrapperForServerSds { public void shutdown() { logger.log(Level.FINER, "Shutdown"); if (xdsClient != null) { + xdsClient.cancelLdsResourceWatch(grpcServerResourceId, listenerWatcher); xdsClient = xdsClientPool.returnObject(xdsClient); } + releaseOldSuppliers(curListener.getAndSet(null)); if (timeService != null) { timeService = SharedResourceHolder.release(timeServiceResource, timeService); } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java index d5ab62c2e5..ef7de2569e 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java @@ -31,9 +31,7 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.ProtocolNegotiationEvent; -import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.InternalXdsAttributes; -import io.grpc.xds.TlsContextManager; import io.grpc.xds.XdsClientWrapperForServerSds; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; @@ -345,11 +343,11 @@ public final class SdsProtocolNegotiators { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof ProtocolNegotiationEvent) { - DownstreamTlsContext downstreamTlsContext = + SslContextProviderSupplier sslContextProviderSupplier = xdsClientWrapperForServerSds == null ? null - : xdsClientWrapperForServerSds.getDownstreamTlsContext(ctx.channel()); - if (downstreamTlsContext == null) { + : xdsClientWrapperForServerSds.getSslContextProviderSupplier(ctx.channel()); + if (sslContextProviderSupplier == null) { if (fallbackProtocolNegotiator == null) { ctx.fireExceptionCaught(new CertStoreException("No certificate source found!")); return; @@ -369,8 +367,7 @@ public final class SdsProtocolNegotiators { this, null, new ServerSdsHandler( - grpcHandler, downstreamTlsContext, fallbackProtocolNegotiator, - xdsClientWrapperForServerSds.getTlsContextManager())); + grpcHandler, sslContextProviderSupplier)); ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault(); ctx.fireUserEventTriggered(pne); return; @@ -385,14 +382,11 @@ public final class SdsProtocolNegotiators { static final class ServerSdsHandler extends InternalProtocolNegotiators.ProtocolNegotiationHandler { private final GrpcHttp2ConnectionHandler grpcHandler; - private final DownstreamTlsContext downstreamTlsContext; - private final TlsContextManager tlsContextManager; - @Nullable private final ProtocolNegotiator fallbackProtocolNegotiator; + private final SslContextProviderSupplier sslContextProviderSupplier; ServerSdsHandler( GrpcHttp2ConnectionHandler grpcHandler, - DownstreamTlsContext downstreamTlsContext, - ProtocolNegotiator fallbackProtocolNegotiator, TlsContextManager tlsContextManager) { + SslContextProviderSupplier sslContextProviderSupplier) { super( // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior @@ -405,9 +399,7 @@ public final class SdsProtocolNegotiators { }, grpcHandler.getNegotiationLogger()); checkNotNull(grpcHandler, "grpcHandler"); this.grpcHandler = grpcHandler; - this.downstreamTlsContext = downstreamTlsContext; - this.fallbackProtocolNegotiator = fallbackProtocolNegotiator; - this.tlsContextManager = tlsContextManager; + this.sslContextProviderSupplier = sslContextProviderSupplier; } @Override @@ -415,23 +407,7 @@ public final class SdsProtocolNegotiators { final BufferReadsHandler bufferReads = new BufferReadsHandler(); ctx.pipeline().addBefore(ctx.name(), null, bufferReads); - SslContextProvider sslContextProviderTemp = null; - try { - sslContextProviderTemp = - tlsContextManager.findOrCreateServerSslContextProvider(downstreamTlsContext); - } catch (Exception e) { - if (fallbackProtocolNegotiator == null) { - ctx.fireExceptionCaught(new CertStoreException("No certificate source found!", e)); - return; - } - logger.log(Level.INFO, "Using fallback for {0}", ctx.channel().localAddress()); - // Delegate rest of handshake to fallback handler - ctx.pipeline().replace(this, null, fallbackProtocolNegotiator.newHandler(grpcHandler)); - ctx.pipeline().remove(bufferReads); - return; - } - final SslContextProvider sslContextProvider = sslContextProviderTemp; - sslContextProvider.addCallback( + sslContextProviderSupplier.updateSslContext( new SslContextProvider.Callback(ctx.executor()) { @Override @@ -445,7 +421,6 @@ public final class SdsProtocolNegotiators { fireProtocolNegotiationEvent(ctx); ctx.pipeline().remove(bufferReads); } - tlsContextManager.releaseServerSslContextProvider(sslContextProvider); } @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java index b319bb8311..900d3b9f14 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java @@ -19,11 +19,13 @@ package io.grpc.xds.internal.sds; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.common.base.MoreObjects; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; +import java.util.Objects; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -52,27 +54,36 @@ public final class SslContextProviderSupplier implements Closeable { /** Updates SslContext via the passed callback. */ public synchronized void updateSslContext(final SslContextProvider.Callback callback) { checkNotNull(callback, "callback"); - checkState(!shutdown, "Supplier is shutdown!"); - if (sslContextProvider == null) { - sslContextProvider = getSslContextProvider(); + try { + checkState(!shutdown, "Supplier is shutdown!"); + if (sslContextProvider == null) { + sslContextProvider = getSslContextProvider(); + } + // we want to increment the ref-count so call findOrCreate again... + final SslContextProvider toRelease = getSslContextProvider(); + sslContextProvider.addCallback( + new SslContextProvider.Callback(callback.getExecutor()) { + + @Override + public void updateSecret(SslContext sslContext) { + callback.updateSecret(sslContext); + releaseSslContextProvider(toRelease); + } + + @Override + public void onException(Throwable throwable) { + callback.onException(throwable); + releaseSslContextProvider(toRelease); + } + }); + } catch (final Throwable throwable) { + callback.getExecutor().execute(new Runnable() { + @Override + public void run() { + callback.onException(throwable); + } + }); } - // we want to increment the ref-count so call findOrCreate again... - final SslContextProvider toRelease = getSslContextProvider(); - sslContextProvider.addCallback( - new SslContextProvider.Callback(callback.getExecutor()) { - - @Override - public void updateSecret(SslContext sslContext) { - callback.updateSecret(sslContext); - releaseSslContextProvider(toRelease); - } - - @Override - public void onException(Throwable throwable) { - callback.onException(throwable); - releaseSslContextProvider(toRelease); - } - }); } private void releaseSslContextProvider(SslContextProvider toRelease) { @@ -101,4 +112,34 @@ public final class SslContextProviderSupplier implements Closeable { } shutdown = true; } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SslContextProviderSupplier that = (SslContextProviderSupplier) o; + return shutdown == that.shutdown + && Objects.equals(tlsContext, that.tlsContext) + && Objects.equals(tlsContextManager, that.tlsContextManager) + && Objects.equals(sslContextProvider, that.sslContextProvider); + } + + @Override + public int hashCode() { + return Objects.hash(tlsContext, tlsContextManager, sslContextProvider, shutdown); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("tlsContext", tlsContext) + .add("tlsContextManager", tlsContextManager) + .add("sslContextProvider", sslContextProvider) + .add("shutdown", shutdown) + .toString(); + } } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index e092fd759c..cef630846f 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -688,7 +688,7 @@ public class ClientXdsClientDataTest { .setTrafficDirection(TrafficDirection.OUTBOUND) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()).isEqualTo("Listener listener1 is not INBOUND"); } @@ -701,7 +701,7 @@ public class ClientXdsClientDataTest { .addListenerFilters(ListenerFilter.newBuilder().build()) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("Listener listener1 cannot have listener_filters"); } @@ -715,7 +715,7 @@ public class ClientXdsClientDataTest { .setUseOriginalDst(BoolValue.of(true)) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("Listener listener1 cannot have use_original_dst set to true"); } @@ -729,7 +729,7 @@ public class ClientXdsClientDataTest { .addFilterChains(FilterChain.newBuilder().build()) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("filerChain has to have envoy.http_connection_manager"); } @@ -753,7 +753,7 @@ public class ClientXdsClientDataTest { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("filerChain has non-unique filter name:envoy.http_connection_manager"); } @@ -773,7 +773,7 @@ public class ClientXdsClientDataTest { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("filter envoy.http_connection_manager with config_discovery not supported"); } @@ -789,7 +789,7 @@ public class ClientXdsClientDataTest { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("filter envoy.http_connection_manager expected to have typed_config"); } @@ -809,7 +809,7 @@ public class ClientXdsClientDataTest { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo( "filter envoy.http_connection_manager with unsupported typed_config type:badTypeUrl"); @@ -830,7 +830,7 @@ public class ClientXdsClientDataTest { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("http-connection-manager has non-unique http-filter name:hf"); } @@ -852,7 +852,7 @@ public class ClientXdsClientDataTest { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo( "http-connection-manager http-filter envoy.router uses " @@ -877,7 +877,7 @@ public class ClientXdsClientDataTest { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo( "http-connection-manager http-filter envoy.router has unsupported typed-config type:" @@ -898,7 +898,7 @@ public class ClientXdsClientDataTest { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo( "http-connection-manager http-filter envoy.filters.http.router should have " diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java index ff7c18fba4..8536711884 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java @@ -234,6 +234,8 @@ public abstract class ClientXdsClientTestBase { private CdsResourceWatcher cdsResourceWatcher; @Mock private EdsResourceWatcher edsResourceWatcher; + @Mock + private TlsContextManager tlsContextManager; private ManagedChannel channel; private ClientXdsClient xdsClient; @@ -279,7 +281,7 @@ public abstract class ClientXdsClientTestBase { backoffPolicyProvider, fakeClock.getStopwatchSupplier(), timeProvider, - mock(TlsContextManager.class)); + tlsContextManager); assertThat(resourceDiscoveryCalls).isEmpty(); assertThat(loadReportCalls).isEmpty(); @@ -2021,7 +2023,7 @@ public abstract class ClientXdsClientTestBase { ClientXdsClientTestBase.DiscoveryRpcCall call = startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); Message listener = - mf.buildListenerWithFilterChain( + mf.buildListenerWithFilterChain( LISTENER_RESOURCE, 7000, "0.0.0.0", "google-sds-config-default", "ROOTCA"); List listeners = ImmutableList.of(Any.pack(listener)); call.sendResponse(ResourceType.LDS, listeners, "0", "0000"); @@ -2030,10 +2032,11 @@ public abstract class ClientXdsClientTestBase { ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); assertThat(ldsUpdateCaptor.getValue().listener) - .isEqualTo(EnvoyServerProtoData.Listener.fromEnvoyProtoListener((Listener)listener)); + .isEqualTo(EnvoyServerProtoData.Listener + .fromEnvoyProtoListener((Listener) listener, tlsContextManager)); listener = - mf.buildListenerWithFilterChain( + mf.buildListenerWithFilterChain( LISTENER_RESOURCE, 7000, "0.0.0.0", "CERT2", "ROOTCA2"); listeners = ImmutableList.of(Any.pack(listener)); call.sendResponse(ResourceType.LDS, listeners, "1", "0001"); @@ -2043,7 +2046,8 @@ public abstract class ClientXdsClientTestBase { ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "1", "0001", NODE); verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture()); assertThat(ldsUpdateCaptor.getValue().listener) - .isEqualTo(EnvoyServerProtoData.Listener.fromEnvoyProtoListener((Listener)listener)); + .isEqualTo(EnvoyServerProtoData.Listener + .fromEnvoyProtoListener((Listener) listener, tlsContextManager)); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); } diff --git a/xds/src/test/java/io/grpc/xds/EnvoyServerProtoDataTest.java b/xds/src/test/java/io/grpc/xds/EnvoyServerProtoDataTest.java index d2c0ca39aa..00dfa50061 100644 --- a/xds/src/test/java/io/grpc/xds/EnvoyServerProtoDataTest.java +++ b/xds/src/test/java/io/grpc/xds/EnvoyServerProtoDataTest.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; @@ -34,6 +35,7 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.SdsSecretConfig; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.Listener; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -61,7 +63,7 @@ public class EnvoyServerProtoDataTest { .setTrafficDirection(TrafficDirection.INBOUND) .build(); - Listener xdsListener = Listener.fromEnvoyProtoListener(listener); + Listener xdsListener = Listener.fromEnvoyProtoListener(listener, mock(TlsContextManager.class)); assertThat(xdsListener.getName()).isEqualTo("8000"); assertThat(xdsListener.getAddress()).isEqualTo("10.2.1.34:8000"); List filterChains = xdsListener.getFilterChains(); @@ -81,7 +83,11 @@ public class EnvoyServerProtoDataTest { assertThat(inFilterChainMatch.getConnectionSourceType()) .isEqualTo(EnvoyServerProtoData.ConnectionSourceType.EXTERNAL); assertThat(inFilterChainMatch.getSourcePorts()).containsExactly(200, 300); - DownstreamTlsContext inFilterTlsContext = inFilter.getDownstreamTlsContext(); + SslContextProviderSupplier sslContextProviderSupplier = inFilter + .getSslContextProviderSupplier(); + assertThat(sslContextProviderSupplier.getTlsContext()).isInstanceOf(DownstreamTlsContext.class); + DownstreamTlsContext inFilterTlsContext = (DownstreamTlsContext) sslContextProviderSupplier + .getTlsContext(); assertThat(inFilterTlsContext.getCommonTlsContext()).isNotNull(); CommonTlsContext commonTlsContext = inFilterTlsContext.getCommonTlsContext(); List tlsCertSdsConfigs = commonTlsContext diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java index 6e14bd27a8..36152c342a 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java @@ -23,6 +23,7 @@ import static org.mockito.Mockito.when; import com.google.protobuf.InvalidProtocolBufferException; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; import io.netty.channel.Channel; import java.io.IOException; import java.net.InetAddress; @@ -46,6 +47,7 @@ public class FilterChainMatchTest { private static final String REMOTE_IP = "10.4.2.3"; // source @Mock private Channel channel; + @Mock private TlsContextManager tlsContextManager; private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; private XdsClient.LdsResourceWatcher registeredWatcher; @@ -54,7 +56,7 @@ public class FilterChainMatchTest { public void setUp() throws IOException { MockitoAnnotations.initMocks(this); xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(PORT, null); + .createXdsClientWrapperForServerSds(PORT, tlsContextManager); registeredWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); } @@ -64,6 +66,17 @@ public class FilterChainMatchTest { xdsClientWrapperForServerSds.shutdown(); } + private DownstreamTlsContext getDownstreamTlsContext() { + SslContextProviderSupplier sslContextProviderSupplier = + xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); + if (sslContextProviderSupplier != null) { + EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); + assertThat(tlsContext).isInstanceOf(DownstreamTlsContext.class); + return (DownstreamTlsContext) tlsContext; + } + return null; + } + @Test public void singleFilterChainWithoutAlpn() throws UnknownHostException { setupChannel(LOCAL_IP, REMOTE_IP, 15000); @@ -78,13 +91,12 @@ public class FilterChainMatchTest { DownstreamTlsContext tlsContext = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); EnvoyServerProtoData.FilterChain filterChain = - new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext); + new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener("listener1", LOCAL_IP, Arrays.asList(filterChain), null); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContext); } @@ -102,13 +114,12 @@ public class FilterChainMatchTest { DownstreamTlsContext tlsContext = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); EnvoyServerProtoData.FilterChain filterChain = - new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext); + new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener("listener1", LOCAL_IP, Arrays.asList(filterChain), null); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContext); } @@ -118,14 +129,13 @@ public class FilterChainMatchTest { DownstreamTlsContext tlsContext = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); EnvoyServerProtoData.FilterChain filterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContext); + new EnvoyServerProtoData.FilterChain(null, tlsContext, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(), filterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContext); } @@ -143,18 +153,19 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainWithDestPort = - new EnvoyServerProtoData.FilterChain(filterChainMatchWithDestPort, tlsContextWithDestPort); + new EnvoyServerProtoData.FilterChain(filterChainMatchWithDestPort, tlsContextWithDestPort, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChainWithDestPort), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); } @@ -172,18 +183,19 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainWithMatch = - new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch); + new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch); } @@ -203,24 +215,25 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainWithMismatch = - new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch); + new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); } @Test public void dest0LengthPrefixRange() - throws UnknownHostException, InvalidProtocolBufferException { + throws UnknownHostException, InvalidProtocolBufferException { setupChannel(LOCAL_IP, REMOTE_IP, 15000); DownstreamTlsContext tlsContext0Length = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); @@ -234,18 +247,19 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChain0Length = - new EnvoyServerProtoData.FilterChain(filterChainMatch0Length, tlsContext0Length); + new EnvoyServerProtoData.FilterChain(filterChainMatch0Length, tlsContext0Length, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChain0Length), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContext0Length); } @@ -264,7 +278,8 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific, + tlsContextManager); DownstreamTlsContext tlsContextMoreSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -277,9 +292,10 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific, + tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -288,14 +304,13 @@ public class FilterChainMatchTest { defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); } @Test public void destPrefixRange_emptyListLessSpecific() - throws UnknownHostException, InvalidProtocolBufferException { + throws UnknownHostException, InvalidProtocolBufferException { setupChannel(LOCAL_IP, REMOTE_IP, 15000); DownstreamTlsContext tlsContextLessSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); @@ -308,7 +323,8 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific, + tlsContextManager); DownstreamTlsContext tlsContextMoreSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -321,9 +337,10 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific, + tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -332,8 +349,7 @@ public class FilterChainMatchTest { defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); } @@ -352,7 +368,8 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific, + tlsContextManager); DownstreamTlsContext tlsContextMoreSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -365,9 +382,10 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific, + tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -376,8 +394,7 @@ public class FilterChainMatchTest { defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); } @@ -399,7 +416,7 @@ public class FilterChainMatchTest { Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = new EnvoyServerProtoData.FilterChain( - filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2); + filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2, tlsContextManager); DownstreamTlsContext tlsContextLessSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -412,9 +429,10 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific, + tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -423,8 +441,7 @@ public class FilterChainMatchTest { defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2); } @@ -442,18 +459,19 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainWithMismatch = - new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch); + new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); } @@ -471,18 +489,19 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainWithMatch = - new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch); + new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch); } @@ -504,7 +523,7 @@ public class FilterChainMatchTest { Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = new EnvoyServerProtoData.FilterChain( - filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2); + filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2, tlsContextManager); DownstreamTlsContext tlsContextLessSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -517,9 +536,10 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific, + tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -528,14 +548,13 @@ public class FilterChainMatchTest { defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2); } @Test public void sourcePrefixRange_2Matchers_expectException() - throws UnknownHostException, InvalidProtocolBufferException { + throws UnknownHostException, InvalidProtocolBufferException { setupChannel(LOCAL_IP, REMOTE_IP, 15000); DownstreamTlsContext tlsContext1 = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); @@ -550,7 +569,7 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChain1 = - new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1); + new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1, tlsContextManager); DownstreamTlsContext tlsContext2 = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -563,16 +582,16 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChain2 = - new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2); + new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, null); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChain1, filterChain2), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); try { - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); fail("expect exception!"); } catch (IllegalStateException ise) { assertThat(ise).hasMessageThat().isEqualTo("Found 2 matching filter-chains"); @@ -597,7 +616,7 @@ public class FilterChainMatchTest { Arrays.asList()); EnvoyServerProtoData.FilterChain filterChainEmptySourcePorts = new EnvoyServerProtoData.FilterChain( - filterChainMatchEmptySourcePorts, tlsContextEmptySourcePorts); + filterChainMatchEmptySourcePorts, tlsContextEmptySourcePorts, tlsContextManager); DownstreamTlsContext tlsContextSourcePortMatch = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -611,9 +630,9 @@ public class FilterChainMatchTest { Arrays.asList(7000, 15000)); EnvoyServerProtoData.FilterChain filterChainSourcePortMatch = new EnvoyServerProtoData.FilterChain( - filterChainMatchSourcePortMatch, tlsContextSourcePortMatch); + filterChainMatchSourcePortMatch, tlsContextSourcePortMatch, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -622,8 +641,7 @@ public class FilterChainMatchTest { defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextSourcePortMatch); } @@ -660,7 +678,7 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChain1 = - new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1); + new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1, tlsContextManager); // next 5 use prefix range: 4 with prefixLen of 30 and last one with 29 @@ -674,7 +692,7 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChain2 = - new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2); + new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2, tlsContextManager); // has prefix ranges with one not matching and source type local: gets eliminated in step 3 EnvoyServerProtoData.FilterChainMatch filterChainMatch3 = @@ -688,7 +706,7 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChain3 = - new EnvoyServerProtoData.FilterChain(filterChainMatch3, tlsContext3); + new EnvoyServerProtoData.FilterChain(filterChainMatch3, tlsContext3, tlsContextManager); // has prefix ranges with both matching and source type external but non matching source port: // gets eliminated in step 5 @@ -703,7 +721,7 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.EXTERNAL, Arrays.asList(16000, 9000)); EnvoyServerProtoData.FilterChain filterChain4 = - new EnvoyServerProtoData.FilterChain(filterChainMatch4, tlsContext4); + new EnvoyServerProtoData.FilterChain(filterChainMatch4, tlsContext4, tlsContextManager); // has prefix ranges with both matching and source type external and matching source port: this // gets selected @@ -720,7 +738,7 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList(15000, 8000)); EnvoyServerProtoData.FilterChain filterChain5 = - new EnvoyServerProtoData.FilterChain(filterChainMatch5, tlsContext5); + new EnvoyServerProtoData.FilterChain(filterChainMatch5, tlsContext5, tlsContextManager); // has prefix range with prefixLen of 29: gets eliminated in step 2 EnvoyServerProtoData.FilterChainMatch filterChainMatch6 = @@ -732,10 +750,10 @@ public class FilterChainMatchTest { EnvoyServerProtoData.ConnectionSourceType.ANY, Arrays.asList()); EnvoyServerProtoData.FilterChain filterChain6 = - new EnvoyServerProtoData.FilterChain(filterChainMatch6, tlsContext6); + new EnvoyServerProtoData.FilterChain(filterChainMatch6, tlsContext6, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -745,8 +763,7 @@ public class FilterChainMatchTest { defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContextPicked = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContextPicked = getDownstreamTlsContext(); assertThat(tlsContextPicked).isSameInstanceAs(tlsContext5); } diff --git a/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java b/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java index 7289a7ba8d..c4e888f543 100644 --- a/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java +++ b/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java @@ -66,13 +66,15 @@ public class ServerWrapperForXdsTest { private XdsServerBuilder.XdsServingStatusListener mockXdsServingStatusListener; private XdsClient.LdsResourceWatcher listenerWatcher; private Server mockServer; + private TlsContextManager tlsContextManager; @Before public void setUp() throws IOException { port = XdsServerTestHelper.findFreePort(); mockDelegateBuilder = mock(ServerBuilder.class); + tlsContextManager = mock(TlsContextManager.class); xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(port, null); + .createXdsClientWrapperForServerSds(port, tlsContextManager); mockXdsServingStatusListener = mock(XdsServerBuilder.XdsServingStatusListener.class); listenerWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); @@ -117,8 +119,8 @@ public class ServerWrapperForXdsTest { verifyCapturedCodeAndNotServing(Status.Code.ABORTED, ServerWrapperForXds.ServingState.STARTING); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); Throwable exception = future.get(2, TimeUnit.SECONDS); assertThat(exception).isNull(); assertThat(serverWrapperForXds.getCurrentServingState()) @@ -163,8 +165,8 @@ public class ServerWrapperForXdsTest { public void run() { XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), + tlsContextManager); } }).start(); assertThat(settableFutureToSignalStart.get()).isNull(); @@ -197,9 +199,9 @@ public class ServerWrapperForXdsTest { @Override public void run() { XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") - ); + listenerWatcher, + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), + tlsContextManager); } }).start(); Throwable exception = future.get(2, TimeUnit.SECONDS); @@ -242,8 +244,8 @@ public class ServerWrapperForXdsTest { Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), + tlsContextManager); Throwable exception = future.get(2, TimeUnit.SECONDS); assertThat(exception).isNull(); assertThat(serverWrapperForXds.getCurrentServingState()) @@ -256,8 +258,8 @@ public class ServerWrapperForXdsTest { when(mockDelegateBuilder.build()).thenReturn(mockServer); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"), + tlsContextManager); Thread.sleep(100L); assertThat(serverWrapperForXds.getCurrentServingState()) .isEqualTo(ServerWrapperForXds.ServingState.SHUTDOWN); @@ -269,8 +271,8 @@ public class ServerWrapperForXdsTest { Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), + tlsContextManager); Throwable exception = future.get(2, TimeUnit.SECONDS); assertThat(exception).isNull(); assertThat(serverWrapperForXds.getCurrentServingState()) @@ -302,8 +304,8 @@ public class ServerWrapperForXdsTest { public void run() { XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), + tlsContextManager); } }).start(); assertThat(settableFutureToSignalStart.get()).isNull(); diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index 13ec57cef1..75d7b76dbb 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -19,9 +19,11 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -30,12 +32,15 @@ import io.grpc.StatusException; import io.grpc.inprocess.InProcessSocketAddress; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.SslContextProvider; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; import io.netty.channel.Channel; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.UnknownHostException; +import java.util.Arrays; import java.util.Collections; import org.junit.After; import org.junit.Before; @@ -53,15 +58,27 @@ public class XdsClientWrapperForServerSdsTestMisc { private static final int PORT = 7000; @Mock private Channel channel; + @Mock private TlsContextManager tlsContextManager; + @Mock private XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher; private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; private XdsClient.LdsResourceWatcher registeredWatcher; + private InetSocketAddress localAddress; + private DownstreamTlsContext tlsContext1; + private DownstreamTlsContext tlsContext2; + private DownstreamTlsContext tlsContext3; @Before public void setUp() throws IOException { MockitoAnnotations.initMocks(this); + tlsContext1 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + tlsContext2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + tlsContext3 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"); xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(PORT, null); + .createXdsClientWrapperForServerSds(PORT, tlsContextManager); } @After @@ -73,7 +90,9 @@ public class XdsClientWrapperForServerSdsTestMisc { public void nonInetSocketAddress_expectNull() throws UnknownHostException { registeredWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - assertThat(sendListenerUpdate(new InProcessSocketAddress("test1"), null)).isNull(); + assertThat( + sendListenerUpdate(new InProcessSocketAddress("test1"), null, null, tlsContextManager)) + .isNull(); } @Test @@ -83,7 +102,7 @@ public class XdsClientWrapperForServerSdsTestMisc { try { InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT + 1); - DownstreamTlsContext unused = sendListenerUpdate(localAddress, null); + sendListenerUpdate(localAddress, null, null, tlsContextManager); fail("exception expected"); } catch (IllegalStateException expected) { assertThat(expected) @@ -114,86 +133,170 @@ public class XdsClientWrapperForServerSdsTestMisc { null); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext = xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext = getDownstreamTlsContext(); assertThat(tlsContext).isNull(); } - @Test - public void registerServerWatcher() throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher = - mock(XdsClientWrapperForServerSds.ServerWatcher.class); - xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher); - InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - verify(mockServerWatcher, never()) - .onListenerUpdate(); - DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext); - assertThat(returnedTlsContext).isSameInstanceAs(tlsContext); - verify(mockServerWatcher).onListenerUpdate(); - xdsClientWrapperForServerSds.removeServerWatcher(mockServerWatcher); - } - @Test public void registerServerWatcher_afterListenerUpdate() throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext); - assertThat(returnedTlsContext).isSameInstanceAs(tlsContext); - XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher = - mock(XdsClientWrapperForServerSds.ServerWatcher.class); - xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher); + registerWatcherAndCreateListenerUpdate(tlsContext1); verify(mockServerWatcher).onListenerUpdate(); } @Test - public void registerServerWatcher_notifyError() throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher = - mock(XdsClientWrapperForServerSds.ServerWatcher.class); - xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher); + public void registerServerWatcher_notifyNotFound() throws UnknownHostException { + commonErrorCheck(true, Status.NOT_FOUND, true); + } + + @Test + public void registerServerWatcher_notifyInternalError() throws UnknownHostException { + commonErrorCheck(false, Status.INTERNAL, false); + } + + @Test + public void registerServerWatcher_notifyPermDeniedError() throws UnknownHostException { + commonErrorCheck(false, Status.PERMISSION_DENIED, true); + } + + @Test + public void releaseOldSupplierOnChanged_noCloseDueToLazyLoading() throws UnknownHostException { + registerWatcherAndCreateListenerUpdate(tlsContext1); + XdsServerTestHelper.generateListenerUpdate(registeredWatcher, tlsContext2, tlsContextManager); + verify(tlsContextManager, never()) + .findOrCreateServerSslContextProvider(any(DownstreamTlsContext.class)); + } + + @Test + public void releaseOldSupplierOnChangedOnShutdown_verifyClose() throws UnknownHostException { + SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) + .thenReturn(sslContextProvider1); + registerWatcherAndCreateListenerUpdate(tlsContext1); + callUpdateSslContext(channel); + XdsServerTestHelper + .generateListenerUpdate(registeredWatcher, Arrays.asList(1234), tlsContext2, + tlsContext3, tlsContextManager); + verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); + reset(tlsContextManager); + SslContextProvider sslContextProvider2 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext2))) + .thenReturn(sslContextProvider2); + SslContextProvider sslContextProvider3 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext3))) + .thenReturn(sslContextProvider3); + callUpdateSslContext(channel); + InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); + InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1111); + when(channel.remoteAddress()).thenReturn(remoteAddress); + callUpdateSslContext(channel); + XdsClient mockXdsClient = xdsClientWrapperForServerSds.getXdsClient(); + xdsClientWrapperForServerSds.shutdown(); + verify(mockXdsClient, times(1)) + .cancelLdsResourceWatch(eq("grpc/server?udpa.resource.listening_address=0.0.0.0:" + PORT), + eq(registeredWatcher)); + verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1)); + verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider2)); + verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider3)); + } + + @Test + public void releaseOldSupplierOnNotFound_verifyClose() throws UnknownHostException { + SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) + .thenReturn(sslContextProvider1); + registerWatcherAndCreateListenerUpdate(tlsContext1); + callUpdateSslContext(channel); + registeredWatcher.onResourceDoesNotExist("not-found Error"); + verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); + } + + @Test + public void releaseOldSupplierOnPermDeniedError_verifyClose() throws UnknownHostException { + SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) + .thenReturn(sslContextProvider1); + registerWatcherAndCreateListenerUpdate(tlsContext1); + callUpdateSslContext(channel); + registeredWatcher.onError(Status.PERMISSION_DENIED); + verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); + } + + @Test + public void releaseOldSupplierOnInternalError_noClose() throws UnknownHostException { + SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) + .thenReturn(sslContextProvider1); + registerWatcherAndCreateListenerUpdate(tlsContext1); + callUpdateSslContext(channel); registeredWatcher.onError(Status.INTERNAL); + verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1)); + } + + private void callUpdateSslContext(Channel channel) { + SslContextProviderSupplier sslContextProviderSupplier = + xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); + assertThat(sslContextProviderSupplier).isNotNull(); + SslContextProvider.Callback callback = mock(SslContextProvider.Callback.class); + sslContextProviderSupplier.updateSslContext(callback); + } + + private void registerWatcherAndCreateListenerUpdate(DownstreamTlsContext tlsContext) + throws UnknownHostException { + registeredWatcher = + XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); + xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher); + DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext, null, + tlsContextManager); + assertThat(returnedTlsContext).isSameInstanceAs(tlsContext); + } + + private void commonErrorCheck(boolean generateResourceDoesNotExist, Status status, + boolean isAbsent) throws UnknownHostException { + registerWatcherAndCreateListenerUpdate(tlsContext1); + reset(mockServerWatcher); + if (generateResourceDoesNotExist) { + registeredWatcher.onResourceDoesNotExist("not-found Error"); + } else { + registeredWatcher.onError(status); + } ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); - verify(mockServerWatcher).onError(argCaptor.capture(), eq(false)); + verify(mockServerWatcher).onError(argCaptor.capture(), eq(isAbsent)); Throwable throwable = argCaptor.getValue(); assertThat(throwable).isInstanceOf(StatusException.class); - Status captured = ((StatusException)throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(Status.Code.INTERNAL); - reset(mockServerWatcher); - registeredWatcher.onResourceDoesNotExist("not-found Error"); - ArgumentCaptor argCaptor1 = ArgumentCaptor.forClass(null); - verify(mockServerWatcher).onError(argCaptor1.capture(), eq(true)); - throwable = argCaptor1.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - captured = ((StatusException)throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(Status.Code.NOT_FOUND); - InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - verify(mockServerWatcher, never()) - .onListenerUpdate(); - DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext); - assertThat(returnedTlsContext).isSameInstanceAs(tlsContext); - verify(mockServerWatcher).onListenerUpdate(); + Status captured = ((StatusException) throwable).getStatus(); + assertThat(captured.getCode()).isEqualTo(status.getCode()); + if (isAbsent) { + assertThat(xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel)).isNull(); + } else { + assertThat(xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel)).isNotNull(); + } } private DownstreamTlsContext sendListenerUpdate( - SocketAddress localAddress, DownstreamTlsContext tlsContext) throws UnknownHostException { + SocketAddress localAddress, DownstreamTlsContext tlsContext, + DownstreamTlsContext tlsContextForDefaultFilterChain, TlsContextManager tlsContextManager) + throws UnknownHostException { when(channel.localAddress()).thenReturn(localAddress); InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1234); when(channel.remoteAddress()).thenReturn(remoteAddress); - XdsServerTestHelper.generateListenerUpdate(registeredWatcher, tlsContext); - return xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + XdsServerTestHelper + .generateListenerUpdate(registeredWatcher, Arrays.asList(), tlsContext, + tlsContextForDefaultFilterChain, tlsContextManager); + return getDownstreamTlsContext(); + } + + private DownstreamTlsContext getDownstreamTlsContext() { + SslContextProviderSupplier sslContextProviderSupplier = + xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); + if (sslContextProviderSupplier != null) { + EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); + assertThat(tlsContext).isInstanceOf(DownstreamTlsContext.class); + return (DownstreamTlsContext)tlsContext; + } + return null; } /** Creates XdsClientWrapperForServerSds: also used by other classes. */ @@ -203,7 +306,7 @@ public class XdsClientWrapperForServerSdsTestMisc { XdsServerTestHelper.createXdsClientWrapperForServerSds(port, tlsContextManager); xdsClientWrapperForServerSds.start(); XdsSdsClientServerTest.generateListenerUpdateToWatcher( - downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher()); + downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher(), tlsContextManager); return xdsClientWrapperForServerSds; } } diff --git a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java index 863bf3addc..efa59b69c0 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java @@ -113,19 +113,6 @@ public class XdsSdsClientServerTest { assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); } - @Test - public void plaintextClientServer_withDefaultTlsContext() throws IOException, URISyntaxException { - DownstreamTlsContext defaultTlsContext = - EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext( - io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext - .getDefaultInstance()); - buildServerWithTlsContext(/* downstreamTlsContext= */ defaultTlsContext); - - SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(/* upstreamTlsContext= */ null, /* overrideAuthority= */ null); - assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); - } - @Test public void nullFallbackCredentials_expectException() throws IOException, URISyntaxException { try { @@ -289,7 +276,7 @@ public class XdsSdsClientServerTest { DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE); - generateListenerUpdateToWatcher(downstreamTlsContext, listenerWatcher); + generateListenerUpdateToWatcher(downstreamTlsContext, listenerWatcher, tlsContextManager); try { SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); @@ -356,8 +343,10 @@ public class XdsSdsClientServerTest { } static void generateListenerUpdateToWatcher( - DownstreamTlsContext tlsContext, XdsClient.LdsResourceWatcher registeredWatcher) { - EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", tlsContext); + DownstreamTlsContext tlsContext, XdsClient.LdsResourceWatcher registeredWatcher, + TlsContextManager tlsContextManager) { + EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", tlsContext, + tlsContextManager); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); } @@ -371,12 +360,13 @@ public class XdsSdsClientServerTest { XdsServerBuilder builder = XdsServerBuilder.forPort(port, serverCredentials) .addService(new SimpleServiceImpl()); XdsServerTestHelper.generateListenerUpdate( - xdsClientWrapperForServerSds.getListenerWatcher(), downstreamTlsContext); + xdsClientWrapperForServerSds.getListenerWatcher(), downstreamTlsContext, tlsContextManager); cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)).start(); } static EnvoyServerProtoData.Listener buildListener( - String name, String address, DownstreamTlsContext tlsContext) { + String name, String address, DownstreamTlsContext tlsContext, + TlsContextManager tlsContextManager) { EnvoyServerProtoData.FilterChainMatch filterChainMatch = new EnvoyServerProtoData.FilterChainMatch( 0, @@ -386,7 +376,7 @@ public class XdsSdsClientServerTest { null, Arrays.asList()); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext); + new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener(name, address, Arrays.asList(defaultFilterChain), null); return listener; diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index 7e14f9d62a..0b174a4a31 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -63,6 +63,7 @@ public class XdsServerBuilderTest { private XdsClient.LdsResourceWatcher listenerWatcher; private int port; private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; + private TlsContextManager tlsContextManager; private void buildServer(XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener) throws IOException { @@ -79,8 +80,9 @@ public class XdsServerBuilderTest { if (xdsServingStatusListener != null) { builder = builder.xdsServingStatusListener(xdsServingStatusListener); } + tlsContextManager = mock(TlsContextManager.class); xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(port, null); + .createXdsClientWrapperForServerSds(port, tlsContextManager); listenerWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); } @@ -150,8 +152,8 @@ public class XdsServerBuilderTest { Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); verifyServer(future, null, null); verifyShutdown(); } @@ -162,8 +164,8 @@ public class XdsServerBuilderTest { buildServer(null); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); xdsServer.start(); try { xdsServer.start(); @@ -183,8 +185,8 @@ public class XdsServerBuilderTest { Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); verifyServer(future, mockXdsServingStatusListener, null); } @@ -224,8 +226,8 @@ public class XdsServerBuilderTest { reset(mockXdsServingStatusListener); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); verifyServer(future, mockXdsServingStatusListener, null); } @@ -240,8 +242,8 @@ public class XdsServerBuilderTest { ServerSocket serverSocket = new ServerSocket(port); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); Throwable exception = future.get(5, TimeUnit.SECONDS); assertThat(exception).isInstanceOf(IOException.class); assertThat(exception).hasMessageThat().contains("Failed to bind"); @@ -258,12 +260,12 @@ public class XdsServerBuilderTest { Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); verify(mockXdsServingStatusListener, never()).onNotServing(any(Throwable.class)); verifyServer(future, mockXdsServingStatusListener, null); listenerWatcher.onError(Status.ABORTED); diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index 97bd01d872..ef74584fb1 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -26,6 +26,7 @@ import io.grpc.internal.ObjectPool; import java.io.IOException; import java.net.ServerSocket; import java.util.Arrays; +import java.util.List; import java.util.Map; import javax.annotation.Nullable; import org.mockito.ArgumentCaptor; @@ -112,14 +113,25 @@ class XdsServerTestHelper { * Creates a {@link XdsClient.LdsUpdate} with {@link * io.grpc.xds.EnvoyServerProtoData.FilterChain} with a destination port and an optional {@link * EnvoyServerProtoData.DownstreamTlsContext}. - * * @param registeredWatcher the watcher on which to generate the update * @param tlsContext if non-null, used to populate filterChain */ static void generateListenerUpdate( XdsClient.LdsResourceWatcher registeredWatcher, - EnvoyServerProtoData.DownstreamTlsContext tlsContext) { - EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", tlsContext); + EnvoyServerProtoData.DownstreamTlsContext tlsContext, TlsContextManager tlsContextManager) { + EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", + Arrays.asList(), tlsContext, null, tlsContextManager); + XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); + registeredWatcher.onChanged(listenerUpdate); + } + + static void generateListenerUpdate( + XdsClient.LdsResourceWatcher registeredWatcher, List sourcePorts, + EnvoyServerProtoData.DownstreamTlsContext tlsContext, + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, + TlsContextManager tlsContextManager) { + EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", sourcePorts, + tlsContext, tlsContextForDefaultFilterChain, tlsContextManager); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); } @@ -132,7 +144,10 @@ class XdsServerTestHelper { } static EnvoyServerProtoData.Listener buildTestListener( - String name, String address, EnvoyServerProtoData.DownstreamTlsContext tlsContext) { + String name, String address, List sourcePorts, + EnvoyServerProtoData.DownstreamTlsContext tlsContext, + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, + TlsContextManager tlsContextManager) { EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = new EnvoyServerProtoData.FilterChainMatch( 0, @@ -140,11 +155,12 @@ class XdsServerTestHelper { Arrays.asList(), Arrays.asList(), null, - Arrays.asList()); + sourcePorts); EnvoyServerProtoData.FilterChain filterChain1 = - new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext); + new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( name, address, Arrays.asList(filterChain1), defaultFilterChain); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java index 4f89d8a61d..1f1d32644e 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java @@ -297,7 +297,7 @@ public class SdsProtocolNegotiatorsTest { XdsClientWrapperForServerSds xdsClientWrapperForServerSds = XdsClientWrapperForServerSdsTestMisc.createXdsClientWrapperForServerSds( - 80, downstreamTlsContext, null); + 80, downstreamTlsContext, mock(TlsContextManager.class)); SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = new SdsProtocolNegotiators.HandlerPickerHandler( grpcHandler, xdsClientWrapperForServerSds, mockProtocolNegotiator); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java index 8c5922b7fd..0395f3055e 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java @@ -26,9 +26,11 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import com.google.common.util.concurrent.MoreExecutors; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; @@ -119,30 +121,44 @@ public class SslContextProviderSupplierTest { callUpdateSslContext(); supplier.close(); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); - SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - try { - supplier.updateSslContext(mockCallback); - Assert.fail("no exception thrown"); - } catch (IllegalStateException expected) { - assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!"); - } + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + SslContextProvider.Callback mockCallback = spy( + new SslContextProvider.Callback(MoreExecutors.directExecutor()) { + @Override + public void updateSecret(SslContext sslContext) { + Assert.fail("unexpected call"); + } + + @Override + protected void onException(Throwable argument) { + assertThat(argument).isInstanceOf(IllegalStateException.class); + assertThat(argument).hasMessageThat().contains("Supplier is shutdown!"); + } + }); + supplier.updateSslContext(mockCallback); } @Test public void testClose_nullSslContextProvider() { prepareSupplier(); doThrow(new NullPointerException()).when(mockTlsContextManager) - .releaseClientSslContextProvider(null); + .releaseClientSslContextProvider(null); supplier.close(); verify(mockTlsContextManager, never()) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); - SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - try { - supplier.updateSslContext(mockCallback); - Assert.fail("no exception thrown"); - } catch (IllegalStateException expected) { - assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!"); - } + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + SslContextProvider.Callback mockCallback = spy( + new SslContextProvider.Callback(MoreExecutors.directExecutor()) { + @Override + public void updateSecret(SslContext sslContext) { + Assert.fail("unexpected call"); + } + + @Override + protected void onException(Throwable argument) { + assertThat(argument).isInstanceOf(IllegalStateException.class); + assertThat(argument).hasMessageThat().contains("Supplier is shutdown!"); + } + }); + supplier.updateSslContext(mockCallback); } }