diff --git a/core/src/test/java/io/grpc/ContextsTest.java b/core/src/test/java/io/grpc/ContextsTest.java index 7a4e73ff28..21ab32581b 100644 --- a/core/src/test/java/io/grpc/ContextsTest.java +++ b/core/src/test/java/io/grpc/ContextsTest.java @@ -31,6 +31,7 @@ package io.grpc; +import static io.grpc.Contexts.interceptCall; import static io.grpc.Contexts.statusFromCancelled; import static org.hamcrest.core.IsInstanceOf.instanceOf; import static org.junit.Assert.assertEquals; @@ -41,12 +42,16 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; import io.grpc.internal.FakeClock; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -55,6 +60,147 @@ import java.util.concurrent.TimeoutException; */ @RunWith(JUnit4.class) public class ContextsTest { + private static Context.Key contextKey = Context.key("key"); + /** For use in comparing context by reference. */ + private Context uniqueContext = Context.ROOT.withValue(contextKey, new Object()); + @SuppressWarnings("unchecked") + private MethodDescriptor method = mock(MethodDescriptor.class); + @SuppressWarnings("unchecked") + private ServerCall call = mock(ServerCall.class); + private Metadata headers = new Metadata(); + + @Test + public void interceptCall_basic() { + Context origContext = Context.current(); + final Object message = new Object(); + final List methodCalls = new ArrayList(); + final ServerCall.Listener listener = new ServerCall.Listener() { + @Override public void onMessage(Object messageIn) { + assertSame(message, messageIn); + assertSame(uniqueContext, Context.current()); + methodCalls.add(1); + } + + @Override public void onHalfClose() { + assertSame(uniqueContext, Context.current()); + methodCalls.add(2); + } + + @Override public void onCancel() { + assertSame(uniqueContext, Context.current()); + methodCalls.add(3); + } + + @Override public void onComplete() { + assertSame(uniqueContext, Context.current()); + methodCalls.add(4); + } + + @Override public void onReady() { + assertSame(uniqueContext, Context.current()); + methodCalls.add(5); + } + }; + ServerCall.Listener wrapped = interceptCall(uniqueContext, method, call, headers, + new ServerCallHandler() { + @Override + public ServerCall.Listener startCall(MethodDescriptor method, + ServerCall call, Metadata headers) { + assertSame(ContextsTest.this.method, method); + assertSame(ContextsTest.this.call, call); + assertSame(ContextsTest.this.headers, headers); + assertSame(uniqueContext, Context.current()); + return listener; + } + }); + assertSame(origContext, Context.current()); + + wrapped.onMessage(message); + wrapped.onHalfClose(); + wrapped.onCancel(); + wrapped.onComplete(); + wrapped.onReady(); + assertEquals(Arrays.asList(1, 2, 3, 4, 5), methodCalls); + assertSame(origContext, Context.current()); + } + + @Test + public void interceptCall_restoresIfNextThrows() { + Context origContext = Context.current(); + try { + interceptCall(uniqueContext, method, call, headers, new ServerCallHandler() { + @Override + public ServerCall.Listener startCall(MethodDescriptor method, + ServerCall call, Metadata headers) { + throw new RuntimeException(); + } + }); + fail("Expected exception"); + } catch (RuntimeException expected) { + } + assertSame(origContext, Context.current()); + } + + @Test + public void interceptCall_restoresIfListenerThrows() { + Context origContext = Context.current(); + final ServerCall.Listener listener = new ServerCall.Listener() { + @Override public void onMessage(Object messageIn) { + throw new RuntimeException(); + } + + @Override public void onHalfClose() { + throw new RuntimeException(); + } + + @Override public void onCancel() { + throw new RuntimeException(); + } + + @Override public void onComplete() { + throw new RuntimeException(); + } + + @Override public void onReady() { + throw new RuntimeException(); + } + }; + ServerCall.Listener wrapped = interceptCall(uniqueContext, method, call, headers, + new ServerCallHandler() { + @Override + public ServerCall.Listener startCall(MethodDescriptor method, + ServerCall call, Metadata headers) { + return listener; + } + }); + + try { + wrapped.onMessage(new Object()); + fail("Exception expected"); + } catch (RuntimeException expected) { + } + try { + wrapped.onHalfClose(); + fail("Exception expected"); + } catch (RuntimeException expected) { + } + try { + wrapped.onCancel(); + fail("Exception expected"); + } catch (RuntimeException expected) { + } + try { + wrapped.onComplete(); + fail("Exception expected"); + } catch (RuntimeException expected) { + } + try { + wrapped.onReady(); + fail("Exception expected"); + } catch (RuntimeException expected) { + } + assertSame(origContext, Context.current()); + } @Test public void statusFromCancelled_returnNullIfCtxNotCancelled() { @@ -134,5 +280,4 @@ public class ContextsTest { assertEquals("context must not be null", npe.getMessage()); } } - }