xds: ensure server interceptors are created in a sync context (#11930)

`XdsServerWrapper#generatePerRouteInterceptors` was always intended
to be executed within a sync context. This PR ensures that by calling
`syncContext.throwIfNotInThisSynchronizationContext()`.

This change is needed for upcoming xDS filter state retention because
the new tests in XdsServerWrapperTest flake with this NPE:

> `Cannot invoke "io.grpc.xds.client.XdsClient$ResourceWatcher.onChanged(io.grpc.xds.client.XdsClient$ResourceUpdate)" because "this.ldsWatcher" is null`
This commit is contained in:
Sergii Tkachenko 2025-03-03 17:28:36 -05:00 committed by GitHub
parent cdab410b81
commit 1a2285b527
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 79 additions and 36 deletions

View File

@ -524,9 +524,7 @@ final class XdsServerWrapper extends Server {
private ImmutableMap<Route, ServerInterceptor> generatePerRouteInterceptors( private ImmutableMap<Route, ServerInterceptor> generatePerRouteInterceptors(
@Nullable List<NamedFilterConfig> filterConfigs, List<VirtualHost> virtualHosts) { @Nullable List<NamedFilterConfig> filterConfigs, List<VirtualHost> virtualHosts) {
// This should always be called from the sync context. syncContext.throwIfNotInThisSynchronizationContext();
// Ideally we'd want to throw otherwise, but this breaks the tests now.
// syncContext.throwIfNotInThisSynchronizationContext();
ImmutableMap.Builder<Route, ServerInterceptor> perRouteInterceptors = ImmutableMap.Builder<Route, ServerInterceptor> perRouteInterceptors =
new ImmutableMap.Builder<>(); new ImmutableMap.Builder<>();

View File

@ -38,6 +38,7 @@ import io.grpc.xds.client.EnvoyProtoData;
import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsClient;
import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsInitializationException;
import io.grpc.xds.client.XdsResourceType; import io.grpc.xds.client.XdsResourceType;
import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -45,7 +46,10 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.annotation.Nullable; import javax.annotation.Nullable;
/** /**
@ -174,12 +178,18 @@ public class XdsServerTestHelper {
} }
} }
// Implementation details:
// 1. Use `synchronized` in methods where XdsClientImpl uses its own `syncContext`.
// 2. Use `serverExecutor` via `execute()` in methods where XdsClientImpl uses watcher's executor.
static final class FakeXdsClient extends XdsClient { static final class FakeXdsClient extends XdsClient {
boolean shutdown; public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5);
SettableFuture<String> ldsResource = SettableFuture.create();
ResourceWatcher<LdsUpdate> ldsWatcher; private boolean shutdown;
CountDownLatch rdsCount = new CountDownLatch(1); @Nullable SettableFuture<String> ldsResource = SettableFuture.create();
@Nullable ResourceWatcher<LdsUpdate> ldsWatcher;
private CountDownLatch rdsCount = new CountDownLatch(1);
final Map<String, ResourceWatcher<RdsUpdate>> rdsWatchers = new HashMap<>(); final Map<String, ResourceWatcher<RdsUpdate>> rdsWatchers = new HashMap<>();
@Nullable private volatile Executor serverExecutor;
@Override @Override
public TlsContextManager getSecurityConfig() { public TlsContextManager getSecurityConfig() {
@ -193,14 +203,20 @@ public class XdsServerTestHelper {
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T extends ResourceUpdate> void watchXdsResource(XdsResourceType<T> resourceType, public synchronized <T extends ResourceUpdate> void watchXdsResource(
String resourceName, XdsResourceType<T> resourceType,
ResourceWatcher<T> watcher, String resourceName,
Executor syncContext) { ResourceWatcher<T> watcher,
Executor executor) {
if (serverExecutor != null) {
assertThat(executor).isEqualTo(serverExecutor);
}
switch (resourceType.typeName()) { switch (resourceType.typeName()) {
case "LDS": case "LDS":
assertThat(ldsWatcher).isNull(); assertThat(ldsWatcher).isNull();
ldsWatcher = (ResourceWatcher<LdsUpdate>) watcher; ldsWatcher = (ResourceWatcher<LdsUpdate>) watcher;
serverExecutor = executor;
ldsResource.set(resourceName); ldsResource.set(resourceName);
break; break;
case "RDS": case "RDS":
@ -213,14 +229,14 @@ public class XdsServerTestHelper {
} }
@Override @Override
public <T extends ResourceUpdate> void cancelXdsResourceWatch(XdsResourceType<T> type, public synchronized <T extends ResourceUpdate> void cancelXdsResourceWatch(
String resourceName, XdsResourceType<T> type, String resourceName, ResourceWatcher<T> watcher) {
ResourceWatcher<T> watcher) {
switch (type.typeName()) { switch (type.typeName()) {
case "LDS": case "LDS":
assertThat(ldsWatcher).isNotNull(); assertThat(ldsWatcher).isNotNull();
ldsResource = null; ldsResource = null;
ldsWatcher = null; ldsWatcher = null;
serverExecutor = null;
break; break;
case "RDS": case "RDS":
rdsWatchers.remove(resourceName); rdsWatchers.remove(resourceName);
@ -230,27 +246,58 @@ public class XdsServerTestHelper {
} }
@Override @Override
public void shutdown() { public synchronized void shutdown() {
shutdown = true; shutdown = true;
} }
@Override @Override
public boolean isShutDown() { public synchronized boolean isShutDown() {
return shutdown; return shutdown;
} }
public void awaitRds(Duration timeout) throws InterruptedException, TimeoutException {
if (!rdsCount.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) {
throw new TimeoutException("Timeout " + timeout + " waiting for RDSs");
}
}
public void setExpectedRdsCount(int count) {
rdsCount = new CountDownLatch(count);
}
private void execute(Runnable action) {
// This method ensures that all watcher updates:
// - Happen after the server started watching LDS.
// - Are executed within the sync context of the server.
//
// Note that this doesn't guarantee that any of the RDS watchers are created.
// Tests should use setExpectedRdsCount(int) and awaitRds() for that.
if (ldsResource == null) {
throw new IllegalStateException("xDS resource update after watcher cancel");
}
try {
ldsResource.get(DEFAULT_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS);
} catch (ExecutionException | TimeoutException e) {
throw new RuntimeException("Can't resolve LDS resource name in " + DEFAULT_TIMEOUT, e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
serverExecutor.execute(action);
}
void deliverLdsUpdate(List<FilterChain> filterChains, void deliverLdsUpdate(List<FilterChain> filterChains,
FilterChain defaultFilterChain) { FilterChain defaultFilterChain) {
ldsWatcher.onChanged(LdsUpdate.forTcpListener(Listener.create( deliverLdsUpdate(LdsUpdate.forTcpListener(Listener.create(
"listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain))); "listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain)));
} }
void deliverLdsUpdate(LdsUpdate ldsUpdate) { void deliverLdsUpdate(LdsUpdate ldsUpdate) {
ldsWatcher.onChanged(ldsUpdate); execute(() -> ldsWatcher.onChanged(ldsUpdate));
} }
void deliverRdsUpdate(String rdsName, List<VirtualHost> virtualHosts) { void deliverRdsUpdate(String resourceName, List<VirtualHost> virtualHosts) {
rdsWatchers.get(rdsName).onChanged(new RdsUpdate(virtualHosts)); execute(() -> rdsWatchers.get(resourceName).onChanged(new RdsUpdate(virtualHosts)));
} }
} }
} }

View File

@ -74,7 +74,6 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -252,7 +251,7 @@ public class XdsServerWrapperTest {
FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual);
FilterChain f1 = createFilterChain("filter-chain-1", createRds("rds")); FilterChain f1 = createFilterChain("filter-chain-1", createRds("rds"));
xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1); xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("rds", xdsClient.deliverRdsUpdate("rds",
Collections.singletonList(createVirtualHost("virtual-host-1"))); Collections.singletonList(createVirtualHost("virtual-host-1")));
verify(listener, timeout(5000)).onServing(); verify(listener, timeout(5000)).onServing();
@ -261,7 +260,7 @@ public class XdsServerWrapperTest {
xdsServerWrapper.shutdown(); xdsServerWrapper.shutdown();
assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsServerWrapper.isShutdown()).isTrue();
assertThat(xdsClient.ldsResource).isNull(); assertThat(xdsClient.ldsResource).isNull();
assertThat(xdsClient.shutdown).isTrue(); assertThat(xdsClient.isShutDown()).isTrue();
verify(mockServer).shutdown(); verify(mockServer).shutdown();
assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue();
assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue();
@ -303,7 +302,7 @@ public class XdsServerWrapperTest {
verify(mockServer, never()).start(); verify(mockServer, never()).start();
assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsServerWrapper.isShutdown()).isTrue();
assertThat(xdsClient.ldsResource).isNull(); assertThat(xdsClient.ldsResource).isNull();
assertThat(xdsClient.shutdown).isTrue(); assertThat(xdsClient.isShutDown()).isTrue();
verify(mockServer).shutdown(); verify(mockServer).shutdown();
assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue();
assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue();
@ -342,7 +341,7 @@ public class XdsServerWrapperTest {
xdsServerWrapper.shutdown(); xdsServerWrapper.shutdown();
assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsServerWrapper.isShutdown()).isTrue();
assertThat(xdsClient.ldsResource).isNull(); assertThat(xdsClient.ldsResource).isNull();
assertThat(xdsClient.shutdown).isTrue(); assertThat(xdsClient.isShutDown()).isTrue();
verify(mockBuilder, times(1)).build(); verify(mockBuilder, times(1)).build();
verify(mockServer, times(1)).shutdown(); verify(mockServer, times(1)).shutdown();
xdsServerWrapper.awaitTermination(1, TimeUnit.SECONDS); xdsServerWrapper.awaitTermination(1, TimeUnit.SECONDS);
@ -367,7 +366,7 @@ public class XdsServerWrapperTest {
FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds"));
SslContextProviderSupplier sslSupplier = filterChain.sslContextProviderSupplier(); SslContextProviderSupplier sslSupplier = filterChain.sslContextProviderSupplier();
xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("rds", xdsClient.deliverRdsUpdate("rds",
Collections.singletonList(createVirtualHost("virtual-host-1"))); Collections.singletonList(createVirtualHost("virtual-host-1")));
try { try {
@ -434,7 +433,7 @@ public class XdsServerWrapperTest {
xdsClient.ldsResource.get(5, TimeUnit.SECONDS); xdsClient.ldsResource.get(5, TimeUnit.SECONDS);
FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds"));
xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("rds", xdsClient.deliverRdsUpdate("rds",
Collections.singletonList(createVirtualHost("virtual-host-1"))); Collections.singletonList(createVirtualHost("virtual-host-1")));
try { try {
@ -544,7 +543,7 @@ public class XdsServerWrapperTest {
0L, Collections.singletonList(virtualHost), new ArrayList<NamedFilterConfig>()); 0L, Collections.singletonList(virtualHost), new ArrayList<NamedFilterConfig>());
EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual);
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
xdsClient.rdsCount = new CountDownLatch(3); xdsClient.setExpectedRdsCount(3);
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null);
assertThat(start.isDone()).isFalse(); assertThat(start.isDone()).isFalse();
assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); assertThat(selectorManager.getSelectorToUpdateSelector()).isNull();
@ -556,7 +555,7 @@ public class XdsServerWrapperTest {
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f2), f3); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f2), f3);
verify(mockServer, never()).start(); verify(mockServer, never()).start();
verify(listener, never()).onServing(); verify(listener, never()).onServing();
xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("r1", xdsClient.deliverRdsUpdate("r1",
Collections.singletonList(createVirtualHost("virtual-host-1"))); Collections.singletonList(createVirtualHost("virtual-host-1")));
@ -602,12 +601,11 @@ public class XdsServerWrapperTest {
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r0")); EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r0"));
xdsClient.rdsCount = new CountDownLatch(1);
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2);
assertThat(start.isDone()).isFalse(); assertThat(start.isDone()).isFalse();
assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); assertThat(selectorManager.getSelectorToUpdateSelector()).isNull();
xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("r0", xdsClient.deliverRdsUpdate("r0",
Collections.singletonList(createVirtualHost("virtual-host-0"))); Collections.singletonList(createVirtualHost("virtual-host-0")));
start.get(5000, TimeUnit.MILLISECONDS); start.get(5000, TimeUnit.MILLISECONDS);
@ -633,9 +631,9 @@ public class XdsServerWrapperTest {
EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0")); EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0"));
EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1")); EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1"));
EnvoyServerProtoData.FilterChain f5 = createFilterChain("filter-chain-4", createRds("r1")); EnvoyServerProtoData.FilterChain f5 = createFilterChain("filter-chain-4", createRds("r1"));
xdsClient.rdsCount = new CountDownLatch(1); xdsClient.setExpectedRdsCount(1);
xdsClient.deliverLdsUpdate(Arrays.asList(f5, f3), f4); xdsClient.deliverLdsUpdate(Arrays.asList(f5, f3), f4);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("r1", xdsClient.deliverRdsUpdate("r1",
Collections.singletonList(createVirtualHost("virtual-host-1"))); Collections.singletonList(createVirtualHost("virtual-host-1")));
xdsClient.deliverRdsUpdate("r0", xdsClient.deliverRdsUpdate("r0",
@ -688,7 +686,7 @@ public class XdsServerWrapperTest {
EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual);
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null);
xdsClient.rdsCount.await(); xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED);
start.get(5000, TimeUnit.MILLISECONDS); start.get(5000, TimeUnit.MILLISECONDS);
assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size())
@ -1235,7 +1233,7 @@ public class XdsServerWrapperTest {
VirtualHost virtualHost = VirtualHost.create( VirtualHost virtualHost = VirtualHost.create(
"v1", Collections.singletonList("foo.google.com"), Arrays.asList(route), "v1", Collections.singletonList("foo.google.com"), Arrays.asList(route),
ImmutableMap.of("filter-config-name-0", f0Override)); ImmutableMap.of("filter-config-name-0", f0Override));
xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("r0", Collections.singletonList(virtualHost)); xdsClient.deliverRdsUpdate("r0", Collections.singletonList(virtualHost));
start.get(5000, TimeUnit.MILLISECONDS); start.get(5000, TimeUnit.MILLISECONDS);
verify(mockServer).start(); verify(mockServer).start();