diff --git a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java index 7e1ebbd091..e4e0cc10ea 100644 --- a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java @@ -27,6 +27,7 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.junit.Assume.assumeTrue; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyString; @@ -37,7 +38,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import com.google.common.base.Objects; import com.google.common.collect.Iterables; @@ -47,6 +47,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.Grpc; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalChannelz.TransportStats; @@ -62,6 +63,7 @@ import java.io.IOException; import java.io.InputStream; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.ArrayDeque; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; @@ -81,7 +83,6 @@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; import org.mockito.Matchers; -import org.mockito.stubbing.OngoingStubbing; /** Standard unit tests for {@link ClientTransport}s and {@link ServerTransport}s. */ @RunWith(JUnit4.class) @@ -160,20 +161,50 @@ public abstract class AbstractTransportTest { "ascii-key", Metadata.ASCII_STRING_MARSHALLER); private Metadata.Key binaryKey = Metadata.Key.of( "key-bin", StringBinaryMarshaller.INSTANCE); + private final Metadata.Key tracerHeaderKey = Metadata.Key.of( + "tracer-key", Metadata.ASCII_STRING_MARSHALLER); + private final String tracerKeyValue = "tracer-key-value"; private ManagedClientTransport.Listener mockClientTransportListener = mock(ManagedClientTransport.Listener.class); private MockServerListener serverListener = new MockServerListener(); private ArgumentCaptor throwableCaptor = ArgumentCaptor.forClass(Throwable.class); - private final ClientStreamTracer.Factory clientStreamTracerFactory = - mock(ClientStreamTracer.Factory.class); - private final TestClientStreamTracer clientStreamTracer1 = new TestClientStreamTracer(); private final TestClientStreamTracer clientStreamTracer2 = new TestClientStreamTracer(); - private final ServerStreamTracer.Factory serverStreamTracerFactory = - mock(ServerStreamTracer.Factory.class); + private final ClientStreamTracer.Factory clientStreamTracerFactory = mock( + ClientStreamTracer.Factory.class, + delegatesTo(new ClientStreamTracer.Factory() { + final ArrayDeque tracers = + new ArrayDeque<>(Arrays.asList(clientStreamTracer1, clientStreamTracer2)); + + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metadata) { + metadata.put(tracerHeaderKey, tracerKeyValue); + TestClientStreamTracer tracer = tracers.poll(); + if (tracer != null) { + return tracer; + } + return new TestClientStreamTracer(); + } + })); + private final TestServerStreamTracer serverStreamTracer1 = new TestServerStreamTracer(); private final TestServerStreamTracer serverStreamTracer2 = new TestServerStreamTracer(); + private final ServerStreamTracer.Factory serverStreamTracerFactory = mock( + ServerStreamTracer.Factory.class, + delegatesTo(new ServerStreamTracer.Factory() { + final ArrayDeque tracers = + new ArrayDeque<>(Arrays.asList(serverStreamTracer1, serverStreamTracer2)); + + @Override + public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { + TestServerStreamTracer tracer = tracers.poll(); + if (tracer != null) { + return tracer; + } + return new TestServerStreamTracer(); + } + })); @Rule public ExpectedException thrown = ExpectedException.none(); @@ -181,21 +212,6 @@ public abstract class AbstractTransportTest { @Before public void setUp() { server = Iterables.getOnlyElement(newServer(Arrays.asList(serverStreamTracerFactory))); - OngoingStubbing clientStubbing = - when(clientStreamTracerFactory - .newClientStreamTracer(any(ClientStreamTracer.StreamInfo.class), any(Metadata.class))) - .thenReturn(clientStreamTracer1) - .thenReturn(clientStreamTracer2); - OngoingStubbing serverStubbing = - when(serverStreamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class))) - .thenReturn(serverStreamTracer1) - .thenReturn(serverStreamTracer2); - for (int i = 0; i < 5; i++) { - // flowControlPushBack() creates quite a few streams. We need to make sure tracers are not - // shared among them, or assertion in TestClientStreamTracer will fail. - clientStubbing.thenReturn(new TestClientStreamTracer()); - serverStubbing.thenReturn(new TestServerStreamTracer()); - } callOptions = CallOptions.DEFAULT.withStreamTracerFactory(clientStreamTracerFactory); } @@ -728,7 +744,6 @@ public abstract class AbstractTransportTest { } @Test - @SuppressWarnings("deprecation") public void basicStream() throws Exception { InOrder clientInOrder = inOrder(clientStreamTracerFactory); InOrder serverInOrder = inOrder(serverStreamTracerFactory); @@ -773,6 +788,7 @@ public abstract class AbstractTransportTest { Lists.newArrayList(serverStreamCreation.headers.getAll(asciiKey))); assertEquals(Lists.newArrayList(clientHeadersCopy.getAll(binaryKey)), Lists.newArrayList(serverStreamCreation.headers.getAll(binaryKey))); + assertEquals(tracerKeyValue, serverStreamCreation.headers.get(tracerHeaderKey)); ServerStream serverStream = serverStreamCreation.stream; ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener;