diff --git a/context/src/main/java/io/grpc/ThreadLocalContextStorage.java b/context/src/main/java/io/grpc/ThreadLocalContextStorage.java index bab00cd92e..6b19f00348 100644 --- a/context/src/main/java/io/grpc/ThreadLocalContextStorage.java +++ b/context/src/main/java/io/grpc/ThreadLocalContextStorage.java @@ -28,7 +28,8 @@ final class ThreadLocalContextStorage extends Context.Storage { /** * Currently bound context. */ - private static final ThreadLocal localContext = new ThreadLocal(); + // VisibleForTesting + static final ThreadLocal localContext = new ThreadLocal(); @Override public Context doAttach(Context toAttach) { @@ -63,6 +64,10 @@ final class ThreadLocalContextStorage extends Context.Storage { @Override public Context current() { - return localContext.get(); + Context current = localContext.get(); + if (current == null) { + return Context.ROOT; + } + return current; } } diff --git a/context/src/test/java/io/grpc/ThreadLocalContextStorageTest.java b/context/src/test/java/io/grpc/ThreadLocalContextStorageTest.java index 5e3f7ea26a..429dee6961 100644 --- a/context/src/test/java/io/grpc/ThreadLocalContextStorageTest.java +++ b/context/src/test/java/io/grpc/ThreadLocalContextStorageTest.java @@ -18,6 +18,13 @@ package io.grpc; import static com.google.common.truth.Truth.assertThat; +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Handler; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -25,16 +32,58 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public final class ThreadLocalContextStorageTest { private static final Context.Key KEY = Context.key("test-key"); - private final ThreadLocalContextStorage storage = new ThreadLocalContextStorage(); + private static final ThreadLocalContextStorage storage = new ThreadLocalContextStorage(); + + private Context contextBeforeTest; + + @Before public void saveContext() { + contextBeforeTest = storage.doAttach(Context.ROOT); + } + + @After public void restoreContext() { + storage.detach(Context.ROOT, contextBeforeTest); + } @Test public void detach_threadLocalClearedOnRoot() { Context context = Context.ROOT.withValue(KEY, new Object()); Context old = storage.doAttach(context); - assertThat(old).isNull(); assertThat(storage.current()).isSameAs(context); - // Users see nulls converted to ROOT, so they will pass non-null as the "old" value - storage.detach(context, Context.ROOT); - assertThat(storage.current()).isNull(); + assertThat(ThreadLocalContextStorage.localContext.get()).isSameAs(context); + storage.detach(context, old); + // thread local must contain null to avoid leaking our ClassLoader via ROOT + assertThat(ThreadLocalContextStorage.localContext.get()).isNull(); + } + + @Test + public void detach_detachRoot() { + final List logs = new ArrayList<>(); + Handler handler = new Handler() { + @Override public void publish(LogRecord record) { + logs.add(record); + } + + @Override public void flush() {} + + @Override public void close() {} + }; + + // Explicitly choose ROOT as the current context + Context context = Context.ROOT; + Context old = storage.doAttach(context); + + // Attach and detach a random context + Context innerContext = Context.ROOT.withValue(KEY, new Object()); + storage.detach(innerContext, storage.doAttach(innerContext)); + + Logger logger = Logger.getLogger(ThreadLocalContextStorage.class.getName()); + logger.addHandler(handler); + try { + // Make sure detaching ROOT doesn't log a warning + storage.detach(context, old); + } finally { + logger.removeHandler(handler); + } + assertThat(logs).isEmpty(); } }