xds: ensure server interceptors are created in a sync context (#11930)

`XdsServerWrapper#generatePerRouteInterceptors` was always intended
to be executed within a sync context. This PR ensures that by calling
`syncContext.throwIfNotInThisSynchronizationContext()`.

This change is needed for upcoming xDS filter state retention because
the new tests in XdsServerWrapperTest flake with this NPE:

> `Cannot invoke "io.grpc.xds.client.XdsClient$ResourceWatcher.onChanged(io.grpc.xds.client.XdsClient$ResourceUpdate)" because "this.ldsWatcher" is null`
This commit is contained in:
Sergii Tkachenko 2025-03-03 17:28:36 -05:00 committed by GitHub
parent cdab410b81
commit 1a2285b527
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 79 additions and 36 deletions

View File

@ -524,9 +524,7 @@ final class XdsServerWrapper extends Server {
private ImmutableMap<Route, ServerInterceptor> generatePerRouteInterceptors(
@Nullable List<NamedFilterConfig> filterConfigs, List<VirtualHost> virtualHosts) {
// This should always be called from the sync context.
// Ideally we'd want to throw otherwise, but this breaks the tests now.
// syncContext.throwIfNotInThisSynchronizationContext();
syncContext.throwIfNotInThisSynchronizationContext();
ImmutableMap.Builder<Route, ServerInterceptor> perRouteInterceptors =
new ImmutableMap.Builder<>();

View File

@ -38,6 +38,7 @@ import io.grpc.xds.client.EnvoyProtoData;
import io.grpc.xds.client.XdsClient;
import io.grpc.xds.client.XdsInitializationException;
import io.grpc.xds.client.XdsResourceType;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@ -45,7 +46,10 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.annotation.Nullable;
/**
@ -174,12 +178,18 @@ public class XdsServerTestHelper {
}
}
// Implementation details:
// 1. Use `synchronized` in methods where XdsClientImpl uses its own `syncContext`.
// 2. Use `serverExecutor` via `execute()` in methods where XdsClientImpl uses watcher's executor.
static final class FakeXdsClient extends XdsClient {
boolean shutdown;
SettableFuture<String> ldsResource = SettableFuture.create();
ResourceWatcher<LdsUpdate> ldsWatcher;
CountDownLatch rdsCount = new CountDownLatch(1);
public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5);
private boolean shutdown;
@Nullable SettableFuture<String> ldsResource = SettableFuture.create();
@Nullable ResourceWatcher<LdsUpdate> ldsWatcher;
private CountDownLatch rdsCount = new CountDownLatch(1);
final Map<String, ResourceWatcher<RdsUpdate>> rdsWatchers = new HashMap<>();
@Nullable private volatile Executor serverExecutor;
@Override
public TlsContextManager getSecurityConfig() {
@ -193,14 +203,20 @@ public class XdsServerTestHelper {
@Override
@SuppressWarnings("unchecked")
public <T extends ResourceUpdate> void watchXdsResource(XdsResourceType<T> resourceType,
String resourceName,
ResourceWatcher<T> watcher,
Executor syncContext) {
public synchronized <T extends ResourceUpdate> void watchXdsResource(
XdsResourceType<T> resourceType,
String resourceName,
ResourceWatcher<T> watcher,
Executor executor) {
if (serverExecutor != null) {
assertThat(executor).isEqualTo(serverExecutor);
}
switch (resourceType.typeName()) {
case "LDS":
assertThat(ldsWatcher).isNull();
ldsWatcher = (ResourceWatcher<LdsUpdate>) watcher;
serverExecutor = executor;
ldsResource.set(resourceName);
break;
case "RDS":
@ -213,14 +229,14 @@ public class XdsServerTestHelper {
}
@Override
public <T extends ResourceUpdate> void cancelXdsResourceWatch(XdsResourceType<T> type,
String resourceName,
ResourceWatcher<T> watcher) {
public synchronized <T extends ResourceUpdate> void cancelXdsResourceWatch(
XdsResourceType<T> type, String resourceName, ResourceWatcher<T> watcher) {
switch (type.typeName()) {
case "LDS":
assertThat(ldsWatcher).isNotNull();
ldsResource = null;
ldsWatcher = null;
serverExecutor = null;
break;
case "RDS":
rdsWatchers.remove(resourceName);
@ -230,27 +246,58 @@ public class XdsServerTestHelper {
}
@Override
public void shutdown() {
public synchronized void shutdown() {
shutdown = true;
}
@Override
public boolean isShutDown() {
public synchronized boolean isShutDown() {
return shutdown;
}
public void awaitRds(Duration timeout) throws InterruptedException, TimeoutException {
if (!rdsCount.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) {
throw new TimeoutException("Timeout " + timeout + " waiting for RDSs");
}
}
public void setExpectedRdsCount(int count) {
rdsCount = new CountDownLatch(count);
}
private void execute(Runnable action) {
// This method ensures that all watcher updates:
// - Happen after the server started watching LDS.
// - Are executed within the sync context of the server.
//
// Note that this doesn't guarantee that any of the RDS watchers are created.
// Tests should use setExpectedRdsCount(int) and awaitRds() for that.
if (ldsResource == null) {
throw new IllegalStateException("xDS resource update after watcher cancel");
}
try {
ldsResource.get(DEFAULT_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS);
} catch (ExecutionException | TimeoutException e) {
throw new RuntimeException("Can't resolve LDS resource name in " + DEFAULT_TIMEOUT, e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
serverExecutor.execute(action);
}
void deliverLdsUpdate(List<FilterChain> filterChains,
FilterChain defaultFilterChain) {
ldsWatcher.onChanged(LdsUpdate.forTcpListener(Listener.create(
"listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain)));
deliverLdsUpdate(LdsUpdate.forTcpListener(Listener.create(
"listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain)));
}
void deliverLdsUpdate(LdsUpdate ldsUpdate) {
ldsWatcher.onChanged(ldsUpdate);
execute(() -> ldsWatcher.onChanged(ldsUpdate));
}
void deliverRdsUpdate(String rdsName, List<VirtualHost> virtualHosts) {
rdsWatchers.get(rdsName).onChanged(new RdsUpdate(virtualHosts));
void deliverRdsUpdate(String resourceName, List<VirtualHost> virtualHosts) {
execute(() -> rdsWatchers.get(resourceName).onChanged(new RdsUpdate(virtualHosts)));
}
}
}

View File

@ -74,7 +74,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
@ -252,7 +251,7 @@ public class XdsServerWrapperTest {
FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual);
FilterChain f1 = createFilterChain("filter-chain-1", createRds("rds"));
xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("rds",
Collections.singletonList(createVirtualHost("virtual-host-1")));
verify(listener, timeout(5000)).onServing();
@ -261,7 +260,7 @@ public class XdsServerWrapperTest {
xdsServerWrapper.shutdown();
assertThat(xdsServerWrapper.isShutdown()).isTrue();
assertThat(xdsClient.ldsResource).isNull();
assertThat(xdsClient.shutdown).isTrue();
assertThat(xdsClient.isShutDown()).isTrue();
verify(mockServer).shutdown();
assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue();
assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue();
@ -303,7 +302,7 @@ public class XdsServerWrapperTest {
verify(mockServer, never()).start();
assertThat(xdsServerWrapper.isShutdown()).isTrue();
assertThat(xdsClient.ldsResource).isNull();
assertThat(xdsClient.shutdown).isTrue();
assertThat(xdsClient.isShutDown()).isTrue();
verify(mockServer).shutdown();
assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue();
assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue();
@ -342,7 +341,7 @@ public class XdsServerWrapperTest {
xdsServerWrapper.shutdown();
assertThat(xdsServerWrapper.isShutdown()).isTrue();
assertThat(xdsClient.ldsResource).isNull();
assertThat(xdsClient.shutdown).isTrue();
assertThat(xdsClient.isShutDown()).isTrue();
verify(mockBuilder, times(1)).build();
verify(mockServer, times(1)).shutdown();
xdsServerWrapper.awaitTermination(1, TimeUnit.SECONDS);
@ -367,7 +366,7 @@ public class XdsServerWrapperTest {
FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds"));
SslContextProviderSupplier sslSupplier = filterChain.sslContextProviderSupplier();
xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("rds",
Collections.singletonList(createVirtualHost("virtual-host-1")));
try {
@ -434,7 +433,7 @@ public class XdsServerWrapperTest {
xdsClient.ldsResource.get(5, TimeUnit.SECONDS);
FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds"));
xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("rds",
Collections.singletonList(createVirtualHost("virtual-host-1")));
try {
@ -544,7 +543,7 @@ public class XdsServerWrapperTest {
0L, Collections.singletonList(virtualHost), new ArrayList<NamedFilterConfig>());
EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual);
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
xdsClient.rdsCount = new CountDownLatch(3);
xdsClient.setExpectedRdsCount(3);
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null);
assertThat(start.isDone()).isFalse();
assertThat(selectorManager.getSelectorToUpdateSelector()).isNull();
@ -556,7 +555,7 @@ public class XdsServerWrapperTest {
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f2), f3);
verify(mockServer, never()).start();
verify(listener, never()).onServing();
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("r1",
Collections.singletonList(createVirtualHost("virtual-host-1")));
@ -602,12 +601,11 @@ public class XdsServerWrapperTest {
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r0"));
xdsClient.rdsCount = new CountDownLatch(1);
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2);
assertThat(start.isDone()).isFalse();
assertThat(selectorManager.getSelectorToUpdateSelector()).isNull();
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("r0",
Collections.singletonList(createVirtualHost("virtual-host-0")));
start.get(5000, TimeUnit.MILLISECONDS);
@ -633,9 +631,9 @@ public class XdsServerWrapperTest {
EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0"));
EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1"));
EnvoyServerProtoData.FilterChain f5 = createFilterChain("filter-chain-4", createRds("r1"));
xdsClient.rdsCount = new CountDownLatch(1);
xdsClient.setExpectedRdsCount(1);
xdsClient.deliverLdsUpdate(Arrays.asList(f5, f3), f4);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("r1",
Collections.singletonList(createVirtualHost("virtual-host-1")));
xdsClient.deliverRdsUpdate("r0",
@ -688,7 +686,7 @@ public class XdsServerWrapperTest {
EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual);
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null);
xdsClient.rdsCount.await();
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED);
start.get(5000, TimeUnit.MILLISECONDS);
assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size())
@ -1235,7 +1233,7 @@ public class XdsServerWrapperTest {
VirtualHost virtualHost = VirtualHost.create(
"v1", Collections.singletonList("foo.google.com"), Arrays.asList(route),
ImmutableMap.of("filter-config-name-0", f0Override));
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("r0", Collections.singletonList(virtualHost));
start.get(5000, TimeUnit.MILLISECONDS);
verify(mockServer).start();