rls: fix child lb leak when client channel is shutdown (#8750)

When client channel is shutting down, the RlsLoadBalancer is shutting down. However, the child loadbalancers of RlsLoadBalancer are not shut down. This is causing the issue b/209831670
This commit is contained in:
ZHANG Dapeng 2022-01-12 14:58:44 -08:00 committed by GitHub
parent 26f0d611db
commit 7a23fb27fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 147 additions and 74 deletions

View File

@ -30,17 +30,14 @@ import com.google.common.util.concurrent.SettableFuture;
import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.ConnectivityState;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.ResolvedAddresses;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancerProvider;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.SynchronizationContext.ScheduledHandle;
@ -51,7 +48,6 @@ import io.grpc.lookup.v1.RouteLookupServiceGrpc;
import io.grpc.lookup.v1.RouteLookupServiceGrpc.RouteLookupServiceStub;
import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider;
import io.grpc.rls.LbPolicyConfiguration.ChildLbStatusListener;
import io.grpc.rls.LbPolicyConfiguration.ChildLoadBalancingPolicy;
import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper;
import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory;
import io.grpc.rls.LruCache.EvictionListener;
@ -138,7 +134,8 @@ final class CachingRlsLbClient {
rlsConfig.getCacheSizeBytes(),
builder.evictionListener,
scheduledExecutorService,
timeProvider);
timeProvider,
lock);
logger = helper.getChannelLogger();
String serverHost = null;
try {
@ -181,7 +178,9 @@ final class CachingRlsLbClient {
new ChildLoadBalancerHelperProvider(helper, new SubchannelStateManagerImpl(), rlsPicker);
refCountedChildPolicyWrapperFactory =
new RefCountedChildPolicyWrapperFactory(
childLbHelperProvider, new BackoffRefreshListener());
lbPolicyConfig.getLoadBalancingPolicy(), childLbResolvedAddressFactory,
childLbHelperProvider,
new BackoffRefreshListener());
logger.log(ChannelLogLevel.DEBUG, "CachingRlsLbClient created");
}
@ -536,6 +535,7 @@ final class CachingRlsLbClient {
private final long staleTime;
private final ChildPolicyWrapper childPolicyWrapper;
// GuardedBy CachingRlsLbClient.lock
DataCacheEntry(RouteLookupRequest request, final RouteLookupResponse response) {
super(request);
this.response = checkNotNull(response, "response");
@ -546,29 +546,6 @@ final class CachingRlsLbClient {
long now = timeProvider.currentTimeNanos();
expireTime = now + maxAgeNanos;
staleTime = now + staleAgeNanos;
if (childPolicyWrapper.getPicker() != null) {
childPolicyWrapper.refreshState();
} else {
createChildLbPolicy();
}
}
private void createChildLbPolicy() {
ChildLoadBalancingPolicy childPolicy = lbPolicyConfig.getLoadBalancingPolicy();
LoadBalancerProvider lbProvider = childPolicy.getEffectiveLbProvider();
ConfigOrError lbConfig =
lbProvider
.parseLoadBalancingPolicyConfig(
childPolicy.getEffectiveChildPolicy(childPolicyWrapper.getTarget()));
LoadBalancer lb = lbProvider.newLoadBalancer(childPolicyWrapper.getHelper());
logger.log(
ChannelLogLevel.DEBUG,
"RLS child lb created. config: {0}",
lbConfig.getConfig());
lb.handleResolvedAddresses(childLbResolvedAddressFactory.create(lbConfig.getConfig()));
lb.requestConnection();
}
/**
@ -637,7 +614,9 @@ final class CachingRlsLbClient {
@Override
void cleanup() {
refCountedChildPolicyWrapperFactory.release(childPolicyWrapper);
synchronized (lock) {
refCountedChildPolicyWrapperFactory.release(childPolicyWrapper);
}
}
@Override
@ -856,14 +835,15 @@ final class CachingRlsLbClient {
RlsAsyncLruCache(long maxEstimatedSizeBytes,
@Nullable EvictionListener<RouteLookupRequest, CacheEntry> evictionListener,
ScheduledExecutorService ses, TimeProvider timeProvider) {
ScheduledExecutorService ses, TimeProvider timeProvider, Object lock) {
super(
maxEstimatedSizeBytes,
new AutoCleaningEvictionListener(evictionListener),
1,
TimeUnit.MINUTES,
ses,
timeProvider);
timeProvider,
lock);
}
@Override
@ -985,27 +965,9 @@ final class CachingRlsLbClient {
}
fallbackChildPolicyWrapper = refCountedChildPolicyWrapperFactory.createOrGet(defaultTarget);
}
LoadBalancerProvider lbProvider =
lbPolicyConfig.getLoadBalancingPolicy().getEffectiveLbProvider();
final LoadBalancer lb =
lbProvider.newLoadBalancer(fallbackChildPolicyWrapper.getHelper());
final ConfigOrError lbConfig =
lbProvider
.parseLoadBalancingPolicyConfig(
lbPolicyConfig
.getLoadBalancingPolicy()
.getEffectiveChildPolicy(defaultTarget));
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
lb.handleResolvedAddresses(
childLbResolvedAddressFactory.create(lbConfig.getConfig()));
lb.requestConnection();
}
});
}
// GuardedBy CachingRlsLbClient.lock
void close() {
if (fallbackChildPolicyWrapper != null) {
refCountedChildPolicyWrapperFactory.release(fallbackChildPolicyWrapper);

View File

@ -22,12 +22,15 @@ import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.ConnectivityState;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.internal.ObjectPool;
import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider;
import io.grpc.rls.RlsProtoData.RouteLookupConfig;
@ -191,33 +194,49 @@ final class LbPolicyConfiguration {
/** Factory for {@link ChildPolicyWrapper}. */
static final class RefCountedChildPolicyWrapperFactory {
// GuardedBy CachingRlsLbClient.lock
@VisibleForTesting
final Map<String /* target */, RefCountedChildPolicyWrapper> childPolicyMap =
new HashMap<>();
private final ChildLoadBalancerHelperProvider childLbHelperProvider;
private final ChildLbStatusListener childLbStatusListener;
private final ChildLoadBalancingPolicy childPolicy;
private final ResolvedAddressFactory childLbResolvedAddressFactory;
public RefCountedChildPolicyWrapperFactory(
ChildLoadBalancingPolicy childPolicy,
ResolvedAddressFactory childLbResolvedAddressFactory,
ChildLoadBalancerHelperProvider childLbHelperProvider,
ChildLbStatusListener childLbStatusListener) {
this.childPolicy = checkNotNull(childPolicy, "childPolicy");
this.childLbResolvedAddressFactory =
checkNotNull(childLbResolvedAddressFactory, "childLbResolvedAddressFactory");
this.childLbHelperProvider = checkNotNull(childLbHelperProvider, "childLbHelperProvider");
this.childLbStatusListener = checkNotNull(childLbStatusListener, "childLbStatusListener");
}
// GuardedBy CachingRlsLbClient.lock
ChildPolicyWrapper createOrGet(String target) {
// TODO(creamsoup) check if the target is valid or not
RefCountedChildPolicyWrapper pooledChildPolicyWrapper = childPolicyMap.get(target);
if (pooledChildPolicyWrapper == null) {
ChildPolicyWrapper childPolicyWrapper =
new ChildPolicyWrapper(target, childLbHelperProvider, childLbStatusListener);
ChildPolicyWrapper childPolicyWrapper = new ChildPolicyWrapper(
target, childPolicy, childLbResolvedAddressFactory, childLbHelperProvider,
childLbStatusListener);
pooledChildPolicyWrapper = RefCountedChildPolicyWrapper.of(childPolicyWrapper);
childPolicyMap.put(target, pooledChildPolicyWrapper);
return pooledChildPolicyWrapper.getObject();
} else {
ChildPolicyWrapper childPolicyWrapper = pooledChildPolicyWrapper.getObject();
if (childPolicyWrapper.getPicker() != null) {
childPolicyWrapper.refreshState();
}
return childPolicyWrapper;
}
return pooledChildPolicyWrapper.getObject();
}
// GuardedBy CachingRlsLbClient.lock
void release(ChildPolicyWrapper childPolicyWrapper) {
checkNotNull(childPolicyWrapper, "childPolicyWrapper");
String target = childPolicyWrapper.getTarget();
@ -238,16 +257,36 @@ final class LbPolicyConfiguration {
private final String target;
private final ChildPolicyReportingHelper helper;
private final LoadBalancer lb;
private volatile SubchannelPicker picker;
private ConnectivityState state;
public ChildPolicyWrapper(
String target,
ChildLoadBalancingPolicy childPolicy,
final ResolvedAddressFactory childLbResolvedAddressFactory,
ChildLoadBalancerHelperProvider childLbHelperProvider,
ChildLbStatusListener childLbStatusListener) {
this.target = target;
this.helper =
new ChildPolicyReportingHelper(childLbHelperProvider, childLbStatusListener);
LoadBalancerProvider lbProvider = childPolicy.getEffectiveLbProvider();
final ConfigOrError lbConfig =
lbProvider
.parseLoadBalancingPolicyConfig(
childPolicy.getEffectiveChildPolicy(target));
this.lb = lbProvider.newLoadBalancer(helper);
helper.getChannelLogger().log(
ChannelLogLevel.DEBUG, "RLS child lb created. config: {0}", lbConfig.getConfig());
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
lb.handleResolvedAddresses(
childLbResolvedAddressFactory.create(lbConfig.getConfig()));
lb.requestConnection();
}
});
}
String getTarget() {
@ -263,7 +302,25 @@ final class LbPolicyConfiguration {
}
void refreshState() {
helper.updateBalancingState(state, picker);
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
helper.updateBalancingState(state, picker);
}
}
);
}
void shutdown() {
helper.getSynchronizationContext().execute(
new Runnable() {
@Override
public void run() {
lb.shutdown();
}
}
);
}
@Override
@ -346,6 +403,7 @@ final class LbPolicyConfiguration {
long newCnt = refCnt.decrementAndGet();
checkState(newCnt != -1, "Cannot return never pooled childPolicyWrapper");
if (newCnt == 0) {
childPolicyWrapper.shutdown();
childPolicyWrapper = null;
}
return null;

View File

@ -48,7 +48,7 @@ import javax.annotation.concurrent.ThreadSafe;
@ThreadSafe
abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> {
private final Object lock = new Object();
private final Object lock;
@GuardedBy("lock")
private final LinkedHashMap<K, SizedValue> delegate;
@ -64,9 +64,11 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> {
int cleaningInterval,
TimeUnit cleaningIntervalUnit,
ScheduledExecutorService ses,
final TimeProvider timeProvider) {
final TimeProvider timeProvider,
Object lock) {
checkState(estimatedMaxSizeBytes > 0, "max estimated cache size should be positive");
this.estimatedMaxSizeBytes = estimatedMaxSizeBytes;
this.lock = checkNotNull(lock, "lock");
this.evictionListener = new SizeHandlingEvictionListener(evictionListener);
this.timeProvider = checkNotNull(timeProvider, "timeProvider");
delegate = new LinkedHashMap<K, SizedValue>(
@ -200,14 +202,15 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> {
}
@Override
public final void invalidateAll(Iterable<K> keys) {
checkNotNull(keys, "keys");
public final void invalidateAll() {
synchronized (lock) {
for (K key : keys) {
SizedValue existing = delegate.remove(key);
if (existing != null) {
evictionListener.onEviction(key, existing, EvictionType.EXPLICIT);
Iterator<Map.Entry<K, SizedValue>> iterator = delegate.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<K, SizedValue> entry = iterator.next();
if (entry.getValue() != null) {
evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.EXPLICIT);
}
iterator.remove();
}
}
}
@ -291,13 +294,10 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> {
public final void close() {
synchronized (lock) {
periodicCleaner.stop();
doClose();
delegate.clear();
invalidateAll();
}
}
protected void doClose() {}
/** Periodically cleans up the AsyncRequestCache. */
private final class PeriodicCleaner {

View File

@ -49,10 +49,10 @@ interface LruCache<K, V> {
V invalidate(K key);
/**
* Invalidates cache entries for given keys. This operation will trigger {@link EvictionListener}
* Invalidates cache entries for all keys. This operation will trigger {@link EvictionListener}
* with {@link EvictionType#EXPLICIT}.
*/
void invalidateAll(Iterable<K> keys);
void invalidateAll();
/** Returns {@code true} if given key is cached. */
@CheckReturnValue

View File

@ -19,6 +19,7 @@ package io.grpc.rls;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static io.grpc.rls.CachingRlsLbClient.RLS_DATA_KEY;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
@ -81,6 +82,7 @@ import java.net.SocketAddress;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
@ -172,6 +174,9 @@ public class CachingRlsLbClientTest {
public void tearDown() throws Exception {
rlsLbClient.close();
CachingRlsLbClient.enableOobChannelDirectPath = existingEnableOobChannelDirectPath;
assertWithMessage(
"On client shut down, RlsLoadBalancer must shut down with all its child loadbalancers.")
.that(lbProvider.loadBalancers).isEmpty();
}
private CachedRouteLookupResponse getInSyncContext(
@ -462,6 +467,7 @@ public class CachingRlsLbClientTest {
* immediately fails when using the fallback target.
*/
private static final class TestLoadBalancerProvider extends LoadBalancerProvider {
final Set<LoadBalancer> loadBalancers = new HashSet<>();
@Override
public boolean isAvailable() {
@ -486,7 +492,7 @@ public class CachingRlsLbClientTest {
@Override
public LoadBalancer newLoadBalancer(final Helper helper) {
return new LoadBalancer() {
LoadBalancer loadBalancer = new LoadBalancer() {
@Override
public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
@ -527,8 +533,12 @@ public class CachingRlsLbClientTest {
@Override
public void shutdown() {
loadBalancers.remove(this);
}
};
loadBalancers.add(loadBalancer);
return loadBalancer;
}
}

View File

@ -18,17 +18,25 @@ package io.grpc.rls;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.grpc.ChannelLogger;
import io.grpc.ConnectivityState;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.ResolvedAddresses;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.SynchronizationContext;
import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider;
import io.grpc.rls.LbPolicyConfiguration.ChildLbStatusListener;
import io.grpc.rls.LbPolicyConfiguration.ChildLoadBalancingPolicy;
@ -36,23 +44,58 @@ import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper;
import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper.ChildPolicyReportingHelper;
import io.grpc.rls.LbPolicyConfiguration.InvalidChildPolicyConfigException;
import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory;
import java.lang.Thread.UncaughtExceptionHandler;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentMatchers;
@RunWith(JUnit4.class)
public class LbPolicyConfigurationTest {
private final Helper helper = mock(Helper.class);
private final LoadBalancerProvider lbProvider = mock(LoadBalancerProvider.class);
private final SubchannelStateManager subchannelStateManager = new SubchannelStateManagerImpl();
private final SubchannelPicker picker = mock(SubchannelPicker.class);
private final ChildLbStatusListener childLbStatusListener = mock(ChildLbStatusListener.class);
private final ResolvedAddressFactory resolvedAddressFactory =
new ResolvedAddressFactory() {
@Override
public ResolvedAddresses create(Object childLbConfig) {
return ResolvedAddresses.newBuilder()
.setAddresses(ImmutableList.<EquivalentAddressGroup>of())
.build();
}
};
private final RefCountedChildPolicyWrapperFactory factory =
new RefCountedChildPolicyWrapperFactory(
new ChildLoadBalancingPolicy(
"targetFieldName",
ImmutableMap.<String, Object>of("foo", "bar"),
lbProvider),
resolvedAddressFactory,
new ChildLoadBalancerHelperProvider(helper, subchannelStateManager, picker),
childLbStatusListener);
@Before
public void setUp() {
doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger();
doReturn(
new SynchronizationContext(
new UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
throw new AssertionError(e);
}
}))
.when(helper).getSynchronizationContext();
doReturn(mock(LoadBalancer.class)).when(lbProvider).newLoadBalancer(any(Helper.class));
doReturn(ConfigOrError.fromConfig(new Object()))
.when(lbProvider).parseLoadBalancingPolicyConfig(ArgumentMatchers.<Map<String, ?>>any());
}
@Test
public void childPolicyWrapper_refCounted() {
String target = "target";

View File

@ -23,7 +23,6 @@ import static org.mockito.Mockito.CALLS_REAL_METHODS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import com.google.common.collect.ImmutableList;
import io.grpc.rls.DoNotUseDirectScheduledExecutorService.FakeTimeProvider;
import io.grpc.rls.LruCache.EvictionListener;
import io.grpc.rls.LruCache.EvictionType;
@ -62,7 +61,8 @@ public class LinkedHashLruCacheTest {
10,
TimeUnit.NANOSECONDS,
fakeScheduledService,
timeProvider) {
timeProvider,
new Object()) {
@Override
protected boolean isExpired(Integer key, Entry value, long nowNanos) {
return value.expireTime <= nowNanos;
@ -210,7 +210,7 @@ public class LinkedHashLruCacheTest {
assertThat(cache.estimatedSize()).isEqualTo(2);
cache.invalidateAll(ImmutableList.of(1, 2));
cache.invalidateAll();
assertThat(cache.estimatedSize()).isEqualTo(0);
}