From 6755cfed34edfcd0fd700e65da4162b0367b6c35 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Mon, 26 Apr 2021 14:37:11 -0700 Subject: [PATCH] tsan, xds: fix XdsClientWrapperForServerSds data races (#8107) --- .../xds/XdsClientWrapperForServerSds.java | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java index 504165a81a..ee63da4872 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java @@ -54,6 +54,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -70,7 +71,7 @@ public final class XdsClientWrapperForServerSds { private static final TimeServiceResource timeServiceResource = new TimeServiceResource("GrpcServerXdsClient"); - private EnvoyServerProtoData.Listener curListener; + private AtomicReference curListener = new AtomicReference<>(); @SuppressWarnings("unused") @Nullable private XdsClient xdsClient; private final int port; @@ -137,14 +138,14 @@ public final class XdsClientWrapperForServerSds { new XdsClient.LdsResourceWatcher() { @Override public void onChanged(XdsClient.LdsUpdate update) { - curListener = update.listener; + curListener.set(update.listener); reportSuccess(); } @Override public void onResourceDoesNotExist(String resourceName) { logger.log(Level.WARNING, "Resource {0} is unavailable", resourceName); - curListener = null; + curListener.set(null); reportError(Status.NOT_FOUND.asException(), true); } @@ -180,7 +181,8 @@ public final class XdsClientWrapperForServerSds { */ @Nullable public DownstreamTlsContext getDownstreamTlsContext(Channel channel) { - if (curListener != null && channel != null) { + EnvoyServerProtoData.Listener copyListener = curListener.get(); + if (copyListener != null && channel != null) { SocketAddress localAddress = channel.localAddress(); SocketAddress remoteAddress = channel.remoteAddress(); if (localAddress instanceof InetSocketAddress && remoteAddress instanceof InetSocketAddress) { @@ -189,7 +191,7 @@ public final class XdsClientWrapperForServerSds { checkState( port == localInetAddr.getPort(), "Channel localAddress port does not match requested listener port"); - return getDownstreamTlsContext(localInetAddr, remoteInetAddr); + return getDownstreamTlsContext(localInetAddr, remoteInetAddr, copyListener); } } return null; @@ -204,9 +206,10 @@ public final class XdsClientWrapperForServerSds { * @param localInetAddr dest address of the inbound connection * @param remoteInetAddr source address of the inbound connection */ - private DownstreamTlsContext getDownstreamTlsContext( - InetSocketAddress localInetAddr, InetSocketAddress remoteInetAddr) { - List filterChains = curListener.getFilterChains(); + private static DownstreamTlsContext getDownstreamTlsContext( + InetSocketAddress localInetAddr, InetSocketAddress remoteInetAddr, + EnvoyServerProtoData.Listener listener) { + List filterChains = listener.getFilterChains(); filterChains = filterOnDestinationPort(filterChains); filterChains = filterOnIpAddress(filterChains, localInetAddr.getAddress(), true); @@ -221,7 +224,7 @@ public final class XdsClientWrapperForServerSds { } else if (filterChains.size() == 1) { return filterChains.get(0).getDownstreamTlsContext(); } - return curListener.getDefaultFilterChain().getDownstreamTlsContext(); + return listener.getDefaultFilterChain().getDownstreamTlsContext(); } // destination_port present => Always fail match @@ -255,7 +258,7 @@ public final class XdsClientWrapperForServerSds { return filteredOnMatch.isEmpty() ? filteredOnEmpty : filteredOnMatch; } - private List filterOnSourceType( + private static List filterOnSourceType( List filterChains, InetAddress sourceAddress, InetAddress destAddress) { ArrayList filtered = new ArrayList<>(filterChains.size()); for (FilterChain filterChain : filterChains) { @@ -350,7 +353,7 @@ public final class XdsClientWrapperForServerSds { } // use prefix_ranges (CIDR) and get the most specific matches - private List filterOnIpAddress( + private static List filterOnIpAddress( List filterChains, InetAddress address, boolean forDestination) { PriorityQueue heap = new PriorityQueue<>(10, new QueueElementComparator()); @@ -384,7 +387,8 @@ public final class XdsClientWrapperForServerSds { synchronized (serverWatchers) { serverWatchers.add(serverWatcher); } - if (curListener != null) { + EnvoyServerProtoData.Listener copyListener = curListener.get(); + if (copyListener != null) { serverWatcher.onListenerUpdate(); } }