core: remove channel reference from ManagedChannelWrapper (#5283)

This avoids a memory leak when the channel itself participates in a
reference cycle (e.g., when an interceptor retains a reference to an
Android app's context). With the current implementation, the static
`ManagedChannelOrphanWrapper.refs` map will keep the channel reachable
and prevent the ref cycle from being GCed.
This commit is contained in:
Eric Gribkoff 2019-01-25 16:36:04 -08:00 committed by GitHub
parent f973bbc06f
commit ce2ae1fb6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 44 deletions

View File

@ -24,7 +24,6 @@ import java.lang.ref.SoftReference;
import java.lang.ref.WeakReference; import java.lang.ref.WeakReference;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.LogRecord; import java.util.logging.LogRecord;
import java.util.logging.Logger; import java.util.logging.Logger;
@ -56,24 +55,17 @@ final class ManagedChannelOrphanWrapper extends ForwardingManagedChannel {
@Override @Override
public ManagedChannel shutdown() { public ManagedChannel shutdown() {
phantom.shutdown = true; phantom.shutdown = true;
phantom.clear();
return super.shutdown(); return super.shutdown();
} }
@Override @Override
public ManagedChannel shutdownNow() { public ManagedChannel shutdownNow() {
phantom.shutdownNow = true; phantom.shutdown = true;
phantom.clear();
return super.shutdownNow(); return super.shutdownNow();
} }
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
boolean ret = super.awaitTermination(timeout, unit);
if (ret) {
phantom.clear();
}
return ret;
}
@VisibleForTesting @VisibleForTesting
static final class ManagedChannelReference extends WeakReference<ManagedChannelOrphanWrapper> { static final class ManagedChannelReference extends WeakReference<ManagedChannelOrphanWrapper> {
@ -87,10 +79,9 @@ final class ManagedChannelOrphanWrapper extends ForwardingManagedChannel {
private final ReferenceQueue<ManagedChannelOrphanWrapper> refqueue; private final ReferenceQueue<ManagedChannelOrphanWrapper> refqueue;
private final ConcurrentMap<ManagedChannelReference, ManagedChannelReference> refs; private final ConcurrentMap<ManagedChannelReference, ManagedChannelReference> refs;
private final ManagedChannel channel; private final String channelStr;
private final Reference<RuntimeException> allocationSite; private final Reference<RuntimeException> allocationSite;
private volatile boolean shutdown; private volatile boolean shutdown;
private volatile boolean shutdownNow;
ManagedChannelReference( ManagedChannelReference(
ManagedChannelOrphanWrapper orphanable, ManagedChannelOrphanWrapper orphanable,
@ -102,7 +93,7 @@ final class ManagedChannelOrphanWrapper extends ForwardingManagedChannel {
ENABLE_ALLOCATION_TRACKING ENABLE_ALLOCATION_TRACKING
? new RuntimeException("ManagedChannel allocation site") ? new RuntimeException("ManagedChannel allocation site")
: missingCallSite); : missingCallSite);
this.channel = channel; this.channelStr = channel.toString();
this.refqueue = refqueue; this.refqueue = refqueue;
this.refs = refs; this.refs = refs;
this.refs.put(this, this); this.refs.put(this, this);
@ -144,21 +135,18 @@ final class ManagedChannelOrphanWrapper extends ForwardingManagedChannel {
while ((ref = (ManagedChannelReference) refqueue.poll()) != null) { while ((ref = (ManagedChannelReference) refqueue.poll()) != null) {
RuntimeException maybeAllocationSite = ref.allocationSite.get(); RuntimeException maybeAllocationSite = ref.allocationSite.get();
ref.clearInternal(); // technically the reference is gone already. ref.clearInternal(); // technically the reference is gone already.
if (!(ref.shutdown && ref.channel.isTerminated())) { if (!ref.shutdown) {
orphanedChannels++; orphanedChannels++;
Level level = ref.shutdownNow ? Level.FINE : Level.SEVERE; Level level = Level.SEVERE;
if (logger.isLoggable(level)) { if (logger.isLoggable(level)) {
String fmt = String fmt =
"*~*~*~ Channel {0} was not " "*~*~*~ Channel {0} was not shutdown properly!!! ~*~*~*"
// Prefer to complain about shutdown if neither has been called. + System.getProperty("line.separator")
+ (!ref.shutdown ? "shutdown" : "terminated") + " Make sure to call shutdown()/shutdownNow() and wait "
+ " properly!!! ~*~*~*" + "until awaitTermination() returns true.";
+ System.getProperty("line.separator")
+ " Make sure to call shutdown()/shutdownNow() and wait "
+ "until awaitTermination() returns true.";
LogRecord lr = new LogRecord(level, fmt); LogRecord lr = new LogRecord(level, fmt);
lr.setLoggerName(logger.getName()); lr.setLoggerName(logger.getName());
lr.setParameters(new Object[]{ref.channel.toString()}); lr.setParameters(new Object[] {ref.channelStr});
lr.setThrown(maybeAllocationSite); lr.setThrown(maybeAllocationSite);
logger.log(lr); logger.log(lr);
} }

View File

@ -18,20 +18,22 @@ package io.grpc.internal;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import com.google.common.testing.GcFinalization;
import com.google.common.testing.GcFinalization.FinalizationPredicate;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.ClientCall; import io.grpc.ClientCall;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.internal.ManagedChannelOrphanWrapper.ManagedChannelReference; import io.grpc.internal.ManagedChannelOrphanWrapper.ManagedChannelReference;
import java.lang.ref.ReferenceQueue; import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Filter; import java.util.logging.Filter;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.LogRecord; import java.util.logging.LogRecord;
@ -43,10 +45,10 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public final class ManagedChannelOrphanWrapperTest { public final class ManagedChannelOrphanWrapperTest {
@Test @Test
public void orphanedChannelsAreLogged() throws Exception { public void orphanedChannelsAreLogged() {
ManagedChannel mc = mock(ManagedChannel.class); ManagedChannel mc = new TestManagedChannel();
String channelString = mc.toString(); String channelString = mc.toString();
ReferenceQueue<ManagedChannelOrphanWrapper> refqueue = final ReferenceQueue<ManagedChannelOrphanWrapper> refqueue =
new ReferenceQueue<ManagedChannelOrphanWrapper>(); new ReferenceQueue<ManagedChannelOrphanWrapper>();
ConcurrentMap<ManagedChannelReference, ManagedChannelReference> refs = ConcurrentMap<ManagedChannelReference, ManagedChannelReference> refs =
new ConcurrentHashMap<ManagedChannelReference, ManagedChannelReference>(); new ConcurrentHashMap<ManagedChannelReference, ManagedChannelReference>();
@ -71,22 +73,18 @@ public final class ManagedChannelOrphanWrapperTest {
} }
}); });
// TODO(carl-mastrangelo): consider using com.google.common.testing.GcFinalization instead.
try { try {
channel = null; channel = null;
boolean success = false; final AtomicInteger numOrphans = new AtomicInteger();
for (int retry = 0; retry < 3; retry++) { GcFinalization.awaitDone(
System.gc(); new FinalizationPredicate() {
System.runFinalization(); @Override
int orphans = ManagedChannelReference.cleanQueue(refqueue); public boolean isDone() {
if (orphans == 1) { numOrphans.getAndAdd(ManagedChannelReference.cleanQueue(refqueue));
success = true; return numOrphans.get() > 0;
break; }
} });
assertEquals("unexpected extra orphans", 0, orphans); assertEquals("unexpected extra orphans", 1, numOrphans.get());
Thread.sleep(100L * (1L << retry));
}
assertTrue("Channel was not garbage collected", success);
LogRecord lr; LogRecord lr;
synchronized (records) { synchronized (records) {
@ -102,7 +100,32 @@ public final class ManagedChannelOrphanWrapperTest {
} }
} }
private static final class TestManagedChannel extends ManagedChannel { @Test
public void refCycleIsGCed() {
ReferenceQueue<ManagedChannelOrphanWrapper> refqueue =
new ReferenceQueue<ManagedChannelOrphanWrapper>();
ConcurrentMap<ManagedChannelReference, ManagedChannelReference> refs =
new ConcurrentHashMap<ManagedChannelReference, ManagedChannelReference>();
ApplicationWithChannelRef app = new ApplicationWithChannelRef();
ChannelWithApplicationRef channelImpl = new ChannelWithApplicationRef();
ManagedChannelOrphanWrapper channel =
new ManagedChannelOrphanWrapper(channelImpl, refqueue, refs);
app.channel = channel;
channelImpl.application = app;
WeakReference<ApplicationWithChannelRef> appWeakRef =
new WeakReference<ApplicationWithChannelRef>(app);
// Simulate the application and channel going out of scope. A ref cycle between app and
// channel remains, so ensure that our tracking of orphaned channels does not prevent this
// reference cycle from being GCed.
channel = null;
app = null;
channelImpl = null;
GcFinalization.awaitClear(appWeakRef);
}
private static class TestManagedChannel extends ManagedChannel {
@Override @Override
public ManagedChannel shutdown() { public ManagedChannel shutdown() {
return null; return null;
@ -139,4 +162,12 @@ public final class ManagedChannelOrphanWrapperTest {
return null; return null;
} }
} }
private static final class ApplicationWithChannelRef {
private ManagedChannel channel;
}
private static final class ChannelWithApplicationRef extends TestManagedChannel {
private ApplicationWithChannelRef application;
}
} }