From 1a2285b52786ee190e4157cf0b005aabb7316f96 Mon Sep 17 00:00:00 2001 From: Sergii Tkachenko Date: Mon, 3 Mar 2025 17:28:36 -0500 Subject: [PATCH] xds: ensure server interceptors are created in a sync context (#11930) `XdsServerWrapper#generatePerRouteInterceptors` was always intended to be executed within a sync context. This PR ensures that by calling `syncContext.throwIfNotInThisSynchronizationContext()`. This change is needed for upcoming xDS filter state retention because the new tests in XdsServerWrapperTest flake with this NPE: > `Cannot invoke "io.grpc.xds.client.XdsClient$ResourceWatcher.onChanged(io.grpc.xds.client.XdsClient$ResourceUpdate)" because "this.ldsWatcher" is null` --- .../java/io/grpc/xds/XdsServerWrapper.java | 4 +- .../java/io/grpc/xds/XdsServerTestHelper.java | 83 +++++++++++++++---- .../io/grpc/xds/XdsServerWrapperTest.java | 28 +++---- 3 files changed, 79 insertions(+), 36 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index bbb17d9b61..e5b25ae458 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -524,9 +524,7 @@ final class XdsServerWrapper extends Server { private ImmutableMap generatePerRouteInterceptors( @Nullable List filterConfigs, List virtualHosts) { - // This should always be called from the sync context. - // Ideally we'd want to throw otherwise, but this breaks the tests now. - // syncContext.throwIfNotInThisSynchronizationContext(); + syncContext.throwIfNotInThisSynchronizationContext(); ImmutableMap.Builder perRouteInterceptors = new ImmutableMap.Builder<>(); diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index a27c291771..0508b11c20 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -38,6 +38,7 @@ import io.grpc.xds.client.EnvoyProtoData; import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsResourceType; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -45,7 +46,10 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; /** @@ -174,12 +178,18 @@ public class XdsServerTestHelper { } } + // Implementation details: + // 1. Use `synchronized` in methods where XdsClientImpl uses its own `syncContext`. + // 2. Use `serverExecutor` via `execute()` in methods where XdsClientImpl uses watcher's executor. static final class FakeXdsClient extends XdsClient { - boolean shutdown; - SettableFuture ldsResource = SettableFuture.create(); - ResourceWatcher ldsWatcher; - CountDownLatch rdsCount = new CountDownLatch(1); + public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5); + + private boolean shutdown; + @Nullable SettableFuture ldsResource = SettableFuture.create(); + @Nullable ResourceWatcher ldsWatcher; + private CountDownLatch rdsCount = new CountDownLatch(1); final Map> rdsWatchers = new HashMap<>(); + @Nullable private volatile Executor serverExecutor; @Override public TlsContextManager getSecurityConfig() { @@ -193,14 +203,20 @@ public class XdsServerTestHelper { @Override @SuppressWarnings("unchecked") - public void watchXdsResource(XdsResourceType resourceType, - String resourceName, - ResourceWatcher watcher, - Executor syncContext) { + public synchronized void watchXdsResource( + XdsResourceType resourceType, + String resourceName, + ResourceWatcher watcher, + Executor executor) { + if (serverExecutor != null) { + assertThat(executor).isEqualTo(serverExecutor); + } + switch (resourceType.typeName()) { case "LDS": assertThat(ldsWatcher).isNull(); ldsWatcher = (ResourceWatcher) watcher; + serverExecutor = executor; ldsResource.set(resourceName); break; case "RDS": @@ -213,14 +229,14 @@ public class XdsServerTestHelper { } @Override - public void cancelXdsResourceWatch(XdsResourceType type, - String resourceName, - ResourceWatcher watcher) { + public synchronized void cancelXdsResourceWatch( + XdsResourceType type, String resourceName, ResourceWatcher watcher) { switch (type.typeName()) { case "LDS": assertThat(ldsWatcher).isNotNull(); ldsResource = null; ldsWatcher = null; + serverExecutor = null; break; case "RDS": rdsWatchers.remove(resourceName); @@ -230,27 +246,58 @@ public class XdsServerTestHelper { } @Override - public void shutdown() { + public synchronized void shutdown() { shutdown = true; } @Override - public boolean isShutDown() { + public synchronized boolean isShutDown() { return shutdown; } + public void awaitRds(Duration timeout) throws InterruptedException, TimeoutException { + if (!rdsCount.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) { + throw new TimeoutException("Timeout " + timeout + " waiting for RDSs"); + } + } + + public void setExpectedRdsCount(int count) { + rdsCount = new CountDownLatch(count); + } + + private void execute(Runnable action) { + // This method ensures that all watcher updates: + // - Happen after the server started watching LDS. + // - Are executed within the sync context of the server. + // + // Note that this doesn't guarantee that any of the RDS watchers are created. + // Tests should use setExpectedRdsCount(int) and awaitRds() for that. + if (ldsResource == null) { + throw new IllegalStateException("xDS resource update after watcher cancel"); + } + try { + ldsResource.get(DEFAULT_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS); + } catch (ExecutionException | TimeoutException e) { + throw new RuntimeException("Can't resolve LDS resource name in " + DEFAULT_TIMEOUT, e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + serverExecutor.execute(action); + } + void deliverLdsUpdate(List filterChains, FilterChain defaultFilterChain) { - ldsWatcher.onChanged(LdsUpdate.forTcpListener(Listener.create( - "listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain))); + deliverLdsUpdate(LdsUpdate.forTcpListener(Listener.create( + "listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain))); } void deliverLdsUpdate(LdsUpdate ldsUpdate) { - ldsWatcher.onChanged(ldsUpdate); + execute(() -> ldsWatcher.onChanged(ldsUpdate)); } - void deliverRdsUpdate(String rdsName, List virtualHosts) { - rdsWatchers.get(rdsName).onChanged(new RdsUpdate(virtualHosts)); + void deliverRdsUpdate(String resourceName, List virtualHosts) { + execute(() -> rdsWatchers.get(resourceName).onChanged(new RdsUpdate(virtualHosts))); } } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index 41f005ba58..388052a3dc 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -74,7 +74,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -252,7 +251,7 @@ public class XdsServerWrapperTest { FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); FilterChain f1 = createFilterChain("filter-chain-1", createRds("rds")); xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); verify(listener, timeout(5000)).onServing(); @@ -261,7 +260,7 @@ public class XdsServerWrapperTest { xdsServerWrapper.shutdown(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.shutdown).isTrue(); + assertThat(xdsClient.isShutDown()).isTrue(); verify(mockServer).shutdown(); assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue(); @@ -303,7 +302,7 @@ public class XdsServerWrapperTest { verify(mockServer, never()).start(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.shutdown).isTrue(); + assertThat(xdsClient.isShutDown()).isTrue(); verify(mockServer).shutdown(); assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue(); @@ -342,7 +341,7 @@ public class XdsServerWrapperTest { xdsServerWrapper.shutdown(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.shutdown).isTrue(); + assertThat(xdsClient.isShutDown()).isTrue(); verify(mockBuilder, times(1)).build(); verify(mockServer, times(1)).shutdown(); xdsServerWrapper.awaitTermination(1, TimeUnit.SECONDS); @@ -367,7 +366,7 @@ public class XdsServerWrapperTest { FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); SslContextProviderSupplier sslSupplier = filterChain.sslContextProviderSupplier(); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); try { @@ -434,7 +433,7 @@ public class XdsServerWrapperTest { xdsClient.ldsResource.get(5, TimeUnit.SECONDS); FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); try { @@ -544,7 +543,7 @@ public class XdsServerWrapperTest { 0L, Collections.singletonList(virtualHost), new ArrayList()); EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); - xdsClient.rdsCount = new CountDownLatch(3); + xdsClient.setExpectedRdsCount(3); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); assertThat(start.isDone()).isFalse(); assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); @@ -556,7 +555,7 @@ public class XdsServerWrapperTest { xdsClient.deliverLdsUpdate(Arrays.asList(f0, f2), f3); verify(mockServer, never()).start(); verify(listener, never()).onServing(); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("r1", Collections.singletonList(createVirtualHost("virtual-host-1"))); @@ -602,12 +601,11 @@ public class XdsServerWrapperTest { EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r0")); - xdsClient.rdsCount = new CountDownLatch(1); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2); assertThat(start.isDone()).isFalse(); assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("r0", Collections.singletonList(createVirtualHost("virtual-host-0"))); start.get(5000, TimeUnit.MILLISECONDS); @@ -633,9 +631,9 @@ public class XdsServerWrapperTest { EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0")); EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1")); EnvoyServerProtoData.FilterChain f5 = createFilterChain("filter-chain-4", createRds("r1")); - xdsClient.rdsCount = new CountDownLatch(1); + xdsClient.setExpectedRdsCount(1); xdsClient.deliverLdsUpdate(Arrays.asList(f5, f3), f4); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("r1", Collections.singletonList(createVirtualHost("virtual-host-1"))); xdsClient.deliverRdsUpdate("r0", @@ -688,7 +686,7 @@ public class XdsServerWrapperTest { EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); - xdsClient.rdsCount.await(); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); start.get(5000, TimeUnit.MILLISECONDS); assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) @@ -1235,7 +1233,7 @@ public class XdsServerWrapperTest { VirtualHost virtualHost = VirtualHost.create( "v1", Collections.singletonList("foo.google.com"), Arrays.asList(route), ImmutableMap.of("filter-config-name-0", f0Override)); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("r0", Collections.singletonList(virtualHost)); start.get(5000, TimeUnit.MILLISECONDS); verify(mockServer).start();