testing,core: don't use mocks for stream tracers (#3305)

This is a big, but mostly mechanical change.  The newly added Test*StreamTracer classes are designed to be extended which is why they are non final and have protected fields.  There are a few notable things in this:

1.  verifyNoMoreInteractions is gone.   The API for StreamTracers doesn't make this guarantee.  I have recovered this behavior by failing duplicate calls.  This has resulted in a few bugs in the test code being fixed.

2.  StreamTracers cannot be mocked anymore.  Tracers need to be thread safe, which mocks simply are not.  This leads to a HUGE number of reports when trying to find real races in gRPC.

3.  If these classes are useful, we can promote them out of internal.  I just put them here out of convenience.
This commit is contained in:
Carl Mastrangelo 2017-08-03 16:59:06 -07:00 committed by GitHub
parent c4f91272d2
commit 02cb718767
11 changed files with 592 additions and 167 deletions

View File

@ -16,6 +16,7 @@
package io.grpc; package io.grpc;
import com.google.errorprone.annotations.DoNotMock;
import javax.annotation.concurrent.ThreadSafe; import javax.annotation.concurrent.ThreadSafe;
/** /**
@ -23,6 +24,7 @@ import javax.annotation.concurrent.ThreadSafe;
*/ */
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/2861") @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2861")
@ThreadSafe @ThreadSafe
@DoNotMock
public abstract class StreamTracer { public abstract class StreamTracer {
/** /**
* Stream is closed. This will be called exactly once. * Stream is closed. This will be called exactly once.

View File

@ -17,14 +17,13 @@
package io.grpc.internal; package io.grpc.internal;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.CallOptions; import io.grpc.CallOptions;
@ -36,6 +35,7 @@ import io.grpc.Status.Code;
import io.grpc.StreamTracer; import io.grpc.StreamTracer;
import io.grpc.internal.AbstractClientStream.TransportState; import io.grpc.internal.AbstractClientStream.TransportState;
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer; import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
import io.grpc.internal.testing.TestClientStreamTracer;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
@ -213,7 +213,7 @@ public class AbstractClientStreamTest {
@Test @Test
public void getRequest() { public void getRequest() {
AbstractClientStream.Sink sink = mock(AbstractClientStream.Sink.class); AbstractClientStream.Sink sink = mock(AbstractClientStream.Sink.class);
final ClientStreamTracer tracer = spy(new ClientStreamTracer() {}); final TestClientStreamTracer tracer = new TestClientStreamTracer();
ClientStreamTracer.Factory tracerFactory = ClientStreamTracer.Factory tracerFactory =
new ClientStreamTracer.Factory() { new ClientStreamTracer.Factory() {
@Override @Override
@ -237,10 +237,9 @@ public class AbstractClientStreamTest {
// GET requests don't have BODY. // GET requests don't have BODY.
verify(sink, never()) verify(sink, never())
.writeFrame(any(WritableBuffer.class), any(Boolean.class), any(Boolean.class)); .writeFrame(any(WritableBuffer.class), any(Boolean.class), any(Boolean.class));
verify(tracer).outboundMessage(); assertEquals(1, tracer.getOutboundMessageCount());
verify(tracer).outboundWireSize(1); assertEquals(1, tracer.getOutboundWireSize());
verify(tracer).outboundUncompressedSize(1); assertEquals(1, tracer.getOutboundUncompressedSize());
verifyNoMoreInteractions(tracer);
} }
/** /**

View File

@ -20,7 +20,6 @@ import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -36,6 +35,7 @@ import io.grpc.StatusRuntimeException;
import io.grpc.StreamTracer; import io.grpc.StreamTracer;
import io.grpc.internal.MessageDeframer.Listener; import io.grpc.internal.MessageDeframer.Listener;
import io.grpc.internal.MessageDeframer.SizeEnforcingInputStream; import io.grpc.internal.MessageDeframer.SizeEnforcingInputStream;
import io.grpc.internal.testing.TestStreamTracer.TestBaseStreamTracer;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
@ -60,7 +60,7 @@ public class MessageDeframerTest {
@Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final ExpectedException thrown = ExpectedException.none();
private Listener listener = mock(Listener.class); private Listener listener = mock(Listener.class);
private StreamTracer tracer = mock(StreamTracer.class); private TestBaseStreamTracer tracer = new TestBaseStreamTracer();
private StatsTraceContext statsTraceCtx = new StatsTraceContext(new StreamTracer[]{tracer}); private StatsTraceContext statsTraceCtx = new StatsTraceContext(new StreamTracer[]{tracer});
private ArgumentCaptor<Long> wireSizeCaptor = ArgumentCaptor.forClass(Long.class); private ArgumentCaptor<Long> wireSizeCaptor = ArgumentCaptor.forClass(Long.class);
private ArgumentCaptor<Long> uncompressedSizeCaptor = ArgumentCaptor.forClass(Long.class); private ArgumentCaptor<Long> uncompressedSizeCaptor = ArgumentCaptor.forClass(Long.class);
@ -374,23 +374,9 @@ public class MessageDeframerTest {
private void checkStats( private void checkStats(
int messagesReceived, long wireBytesReceived, long uncompressedBytesReceived) { int messagesReceived, long wireBytesReceived, long uncompressedBytesReceived) {
long actualWireSize = 0; assertEquals(messagesReceived, tracer.getInboundMessageCount());
long actualUncompressedSize = 0; assertEquals(wireBytesReceived, tracer.getInboundWireSize());
assertEquals(uncompressedBytesReceived, tracer.getInboundUncompressedSize());
verify(tracer, times(messagesReceived)).inboundMessage();
verify(tracer, atLeast(0)).inboundWireSize(wireSizeCaptor.capture());
for (Long portion : wireSizeCaptor.getAllValues()) {
actualWireSize += portion;
}
verify(tracer, atLeast(0)).inboundUncompressedSize(uncompressedSizeCaptor.capture());
for (Long portion : uncompressedSizeCaptor.getAllValues()) {
actualUncompressedSize += portion;
}
verifyNoMoreInteractions(tracer);
assertEquals(wireBytesReceived, actualWireSize);
assertEquals(uncompressedBytesReceived, actualUncompressedSize);
} }
private static List<Byte> bytes(ArgumentCaptor<InputStream> captor) { private static List<Byte> bytes(ArgumentCaptor<InputStream> captor) {

View File

@ -20,7 +20,6 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -28,6 +27,7 @@ import static org.mockito.Mockito.verifyZeroInteractions;
import io.grpc.Codec; import io.grpc.Codec;
import io.grpc.StreamTracer; import io.grpc.StreamTracer;
import io.grpc.internal.testing.TestStreamTracer.TestBaseStreamTracer;
import java.io.BufferedInputStream; import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -49,8 +49,8 @@ import org.mockito.MockitoAnnotations;
public class MessageFramerTest { public class MessageFramerTest {
@Mock @Mock
private MessageFramer.Sink sink; private MessageFramer.Sink sink;
@Mock
private StreamTracer tracer; private final TestBaseStreamTracer tracer = new TestBaseStreamTracer();
private MessageFramer framer; private MessageFramer framer;
@Captor @Captor
@ -371,20 +371,9 @@ public class MessageFramerTest {
long actualWireSize = 0; long actualWireSize = 0;
long actualUncompressedSize = 0; long actualUncompressedSize = 0;
verify(tracer, times(messagesSent)).outboundMessage(); assertEquals(messagesSent, tracer.getOutboundMessageCount());
verify(tracer, atLeast(0)).outboundWireSize(wireSizeCaptor.capture()); assertEquals(uncompressedBytesSent, tracer.getOutboundUncompressedSize());
for (Long portion : wireSizeCaptor.getAllValues()) { assertEquals(wireBytesSent, tracer.getOutboundWireSize());
actualWireSize += portion;
}
verify(tracer, atLeast(0)).outboundUncompressedSize(uncompressedSizeCaptor.capture());
for (Long portion : uncompressedSizeCaptor.getAllValues()) {
actualUncompressedSize += portion;
}
verifyNoMoreInteractions(tracer);
assertEquals(wireBytesSent, actualWireSize);
assertEquals(uncompressedBytesSent, actualUncompressedSize);
} }
static class ByteWritableBuffer implements WritableBuffer { static class ByteWritableBuffer implements WritableBuffer {

View File

@ -36,7 +36,6 @@ import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -62,6 +61,7 @@ import io.grpc.ServiceDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StringMarshaller; import io.grpc.StringMarshaller;
import io.grpc.internal.ServerImpl.JumpToApplicationThreadServerStreamListener; import io.grpc.internal.ServerImpl.JumpToApplicationThreadServerStreamListener;
import io.grpc.internal.testing.TestServerStreamTracer;
import io.grpc.util.MutableHandlerRegistry; import io.grpc.util.MutableHandlerRegistry;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.File; import java.io.File;
@ -129,12 +129,13 @@ public class ServerImplTest {
@Mock @Mock
private ServerStreamTracer.Factory streamTracerFactory; private ServerStreamTracer.Factory streamTracerFactory;
private List<ServerStreamTracer.Factory> streamTracerFactories; private List<ServerStreamTracer.Factory> streamTracerFactories;
private final ServerStreamTracer streamTracer = spy(new ServerStreamTracer() { private final TestServerStreamTracer streamTracer = new TestServerStreamTracer() {
@Override @Override
public <ReqT, RespT> Context filterContext(Context context) { public <ReqT, RespT> Context filterContext(Context context) {
return context.withValue(SERVER_TRACER_ADDED_KEY, "context added by tracer"); Context newCtx = super.filterContext(context);
return newCtx.withValue(SERVER_TRACER_ADDED_KEY, "context added by tracer");
} }
}); };
@Mock @Mock
private ObjectPool<Executor> executorPool; private ObjectPool<Executor> executorPool;
private Builder builder = new Builder(); private Builder builder = new Builder();
@ -365,7 +366,7 @@ public class ServerImplTest {
assertEquals("Method not found: Waiter/nonexist", status.getDescription()); assertEquals("Method not found: Waiter/nonexist", status.getDescription());
verify(streamTracerFactory).newServerStreamTracer(eq("Waiter/nonexist"), same(requestHeaders)); verify(streamTracerFactory).newServerStreamTracer(eq("Waiter/nonexist"), same(requestHeaders));
verify(streamTracer, never()).serverCallStarted(any(ServerCall.class)); assertNull(streamTracer.getServerCall());
assertEquals(Status.Code.UNIMPLEMENTED, statusCaptor.getValue().getCode()); assertEquals(Status.Code.UNIMPLEMENTED, statusCaptor.getValue().getCode());
} }
@ -435,7 +436,7 @@ public class ServerImplTest {
assertEquals(1, executor.runDueTasks()); assertEquals(1, executor.runDueTasks());
ServerCall<String, Integer> call = callReference.get(); ServerCall<String, Integer> call = callReference.get();
assertNotNull(call); assertNotNull(call);
verify(streamTracer).serverCallStarted(same(call)); assertSame(call, streamTracer.getServerCall());
verify(stream).getAuthority(); verify(stream).getAuthority();
Context callContext = callContextReference.get(); Context callContext = callContextReference.get();
assertNotNull(callContext); assertNotNull(callContext);

View File

@ -25,10 +25,7 @@ import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
@ -66,12 +63,14 @@ import io.grpc.ServerInterceptors;
import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusRuntimeException; import io.grpc.StatusRuntimeException;
import io.grpc.StreamTracer;
import io.grpc.auth.MoreCallCredentials; import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.AbstractServerImplBuilder; import io.grpc.internal.AbstractServerImplBuilder;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.testing.StatsTestUtils.FakeStatsContextFactory; import io.grpc.internal.testing.StatsTestUtils.FakeStatsContextFactory;
import io.grpc.internal.testing.StatsTestUtils.MetricsRecord; import io.grpc.internal.testing.StatsTestUtils.MetricsRecord;
import io.grpc.internal.testing.TestClientStreamTracer;
import io.grpc.internal.testing.TestServerStreamTracer;
import io.grpc.internal.testing.TestStreamTracer;
import io.grpc.protobuf.ProtoUtils; import io.grpc.protobuf.ProtoUtils;
import io.grpc.stub.ClientCallStreamObserver; import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientCalls; import io.grpc.stub.ClientCalls;
@ -139,21 +138,32 @@ public abstract class AbstractInteropTest {
private static final LinkedBlockingQueue<ServerStreamTracerInfo> serverStreamTracers = private static final LinkedBlockingQueue<ServerStreamTracerInfo> serverStreamTracers =
new LinkedBlockingQueue<ServerStreamTracerInfo>(); new LinkedBlockingQueue<ServerStreamTracerInfo>();
private static class ServerStreamTracerInfo { private static final class ServerStreamTracerInfo {
final String fullMethodName; final String fullMethodName;
final ServerStreamTracer tracer; final InteropServerStreamTracer tracer;
ServerStreamTracerInfo(String fullMethodName, ServerStreamTracer tracer) { ServerStreamTracerInfo(String fullMethodName, InteropServerStreamTracer tracer) {
this.fullMethodName = fullMethodName; this.fullMethodName = fullMethodName;
this.tracer = tracer; this.tracer = tracer;
} }
private static final class InteropServerStreamTracer extends TestServerStreamTracer {
private volatile Context contextCapture;
@Override
public <ReqT, RespT> Context filterContext(Context context) {
contextCapture = context;
return super.filterContext(context);
}
}
} }
private static final ServerStreamTracer.Factory serverStreamTracerFactory = private static final ServerStreamTracer.Factory serverStreamTracerFactory =
new ServerStreamTracer.Factory() { new ServerStreamTracer.Factory() {
@Override @Override
public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) {
ServerStreamTracer tracer = spy(new ServerStreamTracer() {}); ServerStreamTracerInfo.InteropServerStreamTracer tracer
= new ServerStreamTracerInfo.InteropServerStreamTracer();
serverStreamTracers.add(new ServerStreamTracerInfo(fullMethodName, tracer)); serverStreamTracers.add(new ServerStreamTracerInfo(fullMethodName, tracer));
return tracer; return tracer;
} }
@ -200,14 +210,14 @@ public abstract class AbstractInteropTest {
protected TestServiceGrpc.TestServiceBlockingStub blockingStub; protected TestServiceGrpc.TestServiceBlockingStub blockingStub;
protected TestServiceGrpc.TestServiceStub asyncStub; protected TestServiceGrpc.TestServiceStub asyncStub;
private final LinkedBlockingQueue<ClientStreamTracer> clientStreamTracers = private final LinkedBlockingQueue<TestClientStreamTracer> clientStreamTracers =
new LinkedBlockingQueue<ClientStreamTracer>(); new LinkedBlockingQueue<TestClientStreamTracer>();
private final ClientStreamTracer.Factory clientStreamTracerFactory = private final ClientStreamTracer.Factory clientStreamTracerFactory =
new ClientStreamTracer.Factory() { new ClientStreamTracer.Factory() {
@Override @Override
public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) { public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) {
ClientStreamTracer tracer = spy(new ClientStreamTracer() {}); TestClientStreamTracer tracer = new TestClientStreamTracer();
clientStreamTracers.add(tracer); clientStreamTracers.add(tracer);
return tracer; return tracer;
} }
@ -1655,22 +1665,26 @@ public abstract class AbstractInteropTest {
assertMetrics(method, status, null, null); assertMetrics(method, status, null, null);
} }
private void assertClientMetrics(String method, Status.Code status, private void assertClientMetrics(String method, Status.Code code,
Collection<? extends MessageLite> requests, Collection<? extends MessageLite> responses) { Collection<? extends MessageLite> requests, Collection<? extends MessageLite> responses) {
// Tracer-based stats // Tracer-based stats
ClientStreamTracer tracer = clientStreamTracers.poll(); TestClientStreamTracer tracer = clientStreamTracers.poll();
assertNotNull(tracer); assertNotNull(tracer);
verify(tracer).outboundHeaders(); assertTrue(tracer.getOutboundHeaders());
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
// assertClientMetrics() is called right after application receives status, // assertClientMetrics() is called right after application receives status,
// but streamClosed() may be called slightly later than that. So we need a timeout. // but streamClosed() may be called slightly later than that. So we need a timeout.
verify(tracer, timeout(5000)).streamClosed(statusCaptor.capture()); try {
assertEquals(status, statusCaptor.getValue().getCode()); assertTrue(tracer.await(5, TimeUnit.SECONDS));
} catch (InterruptedException e) {
throw new AssertionError(e);
}
assertEquals(code, tracer.getStatus().getCode());
// CensusStreamTracerModule records final status in interceptor, which is guaranteed to be done // CensusStreamTracerModule records final status in interceptor, which is guaranteed to be done
// before application receives status. // before application receives status.
MetricsRecord clientRecord = clientStatsCtxFactory.pollRecord(); MetricsRecord clientRecord = clientStatsCtxFactory.pollRecord();
checkTags(clientRecord, false, method, status); checkTags(clientRecord, false, method, code);
if (requests != null && responses != null) { if (requests != null && responses != null) {
checkTracerMetrics(tracer, requests, responses); checkTracerMetrics(tracer, requests, responses);
@ -1682,7 +1696,7 @@ public abstract class AbstractInteropTest {
assertClientMetrics(method, status, null, null); assertClientMetrics(method, status, null, null);
} }
private void assertServerMetrics(String method, Status.Code status, private void assertServerMetrics(String method, Status.Code code,
Collection<? extends MessageLite> requests, Collection<? extends MessageLite> responses) { Collection<? extends MessageLite> requests, Collection<? extends MessageLite> responses) {
AssertionError checkFailure = null; AssertionError checkFailure = null;
boolean passed = false; boolean passed = false;
@ -1703,7 +1717,7 @@ public abstract class AbstractInteropTest {
break; break;
} }
try { try {
checkTags(serverRecord, true, method, status); checkTags(serverRecord, true, method, code);
if (requests != null && responses != null) { if (requests != null && responses != null) {
checkCensusMetrics(serverRecord, true, requests, responses); checkCensusMetrics(serverRecord, true, requests, responses);
} }
@ -1731,12 +1745,16 @@ public abstract class AbstractInteropTest {
} }
try { try {
assertEquals(method, tracerInfo.fullMethodName); assertEquals(method, tracerInfo.fullMethodName);
verify(tracerInfo.tracer).filterContext(any(Context.class)); assertNotNull(tracerInfo.tracer.contextCapture);
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
// On the server, streamClosed() may be called after the client receives the final status. // On the server, streamClosed() may be called after the client receives the final status.
// So we use a timeout. // So we use a timeout.
verify(tracerInfo.tracer, timeout(1000)).streamClosed(statusCaptor.capture()); try {
assertEquals(status, statusCaptor.getValue().getCode()); assertTrue(tracerInfo.tracer.await(1, TimeUnit.SECONDS));
} catch (InterruptedException e) {
throw new AssertionError(e);
}
assertEquals(code, tracerInfo.tracer.getStatus().getCode());
if (requests != null && responses != null) { if (requests != null && responses != null) {
checkTracerMetrics(tracerInfo.tracer, responses, requests); checkTracerMetrics(tracerInfo.tracer, responses, requests);
} }
@ -1768,11 +1786,11 @@ public abstract class AbstractInteropTest {
} }
private static void checkTracerMetrics( private static void checkTracerMetrics(
StreamTracer tracer, TestStreamTracer tracer,
Collection<? extends MessageLite> sentMessages, Collection<? extends MessageLite> sentMessages,
Collection<? extends MessageLite> receivedMessages) { Collection<? extends MessageLite> receivedMessages) {
verify(tracer, times(sentMessages.size())).outboundMessage(); assertEquals(sentMessages.size(), tracer.getOutboundMessageCount());
verify(tracer, times(receivedMessages.size())).inboundMessage(); assertEquals(receivedMessages.size(), tracer.getInboundMessageCount());
long uncompressedSentSize = 0; long uncompressedSentSize = 0;
for (MessageLite msg : sentMessages) { for (MessageLite msg : sentMessages) {
@ -1782,20 +1800,9 @@ public abstract class AbstractInteropTest {
for (MessageLite msg : receivedMessages) { for (MessageLite msg : receivedMessages) {
uncompressedReceivedSize += msg.getSerializedSize(); uncompressedReceivedSize += msg.getSerializedSize();
} }
ArgumentCaptor<Long> outboundSizeCaptor = ArgumentCaptor.forClass(Long.class);
ArgumentCaptor<Long> inboundSizeCaptor = ArgumentCaptor.forClass(Long.class); assertEquals(uncompressedSentSize, tracer.getOutboundUncompressedSize());
verify(tracer, atLeast(0)).outboundUncompressedSize(outboundSizeCaptor.capture()); assertEquals(uncompressedReceivedSize, tracer.getInboundUncompressedSize());
verify(tracer, atLeast(0)).inboundUncompressedSize(inboundSizeCaptor.capture());
long recordedUncompressedOutboundSize = 0;
for (Long size : outboundSizeCaptor.getAllValues()) {
recordedUncompressedOutboundSize += size;
}
long recordedUncompressedInboundSize = 0;
for (Long size : inboundSizeCaptor.getAllValues()) {
recordedUncompressedInboundSize += size;
}
assertEquals(uncompressedSentSize, recordedUncompressedOutboundSize);
assertEquals(uncompressedReceivedSize, recordedUncompressedInboundSize);
} }
private static void checkCensusMetrics(MetricsRecord record, boolean server, private static void checkCensusMetrics(MetricsRecord record, boolean server,

View File

@ -64,6 +64,7 @@ import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransportListener; import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.StatsTraceContext; import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.testing.TestServerStreamTracer;
import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil; import io.netty.buffer.ByteBufUtil;
@ -106,7 +107,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
private ServerStreamTracer.Factory streamTracerFactory; private ServerStreamTracer.Factory streamTracerFactory;
private final ServerTransportListener transportListener = spy(new ServerTransportListenerImpl()); private final ServerTransportListener transportListener = spy(new ServerTransportListenerImpl());
private final ServerStreamTracer streamTracer = spy(new ServerStreamTracer() {}); private final TestServerStreamTracer streamTracer = new TestServerStreamTracer();
private NettyServerStream stream; private NettyServerStream stream;
private KeepAliveManager spyKeepAliveManager; private KeepAliveManager spyKeepAliveManager;

View File

@ -17,30 +17,28 @@
package io.grpc.internal.testing; package io.grpc.internal.testing;
import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Charsets.UTF_8;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue; import static org.junit.Assume.assumeTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -161,10 +159,12 @@ public abstract class AbstractTransportTest {
= ArgumentCaptor.forClass(InputStream.class); = ArgumentCaptor.forClass(InputStream.class);
private final ClientStreamTracer.Factory clientStreamTracerFactory = private final ClientStreamTracer.Factory clientStreamTracerFactory =
mock(ClientStreamTracer.Factory.class); mock(ClientStreamTracer.Factory.class);
private final ClientStreamTracer clientStreamTracer = spy(new ClientStreamTracer() {}); private final TestClientStreamTracer clientStreamTracer1 = new TestClientStreamTracer();
private final TestClientStreamTracer clientStreamTracer2 = new TestClientStreamTracer();
private final ServerStreamTracer.Factory serverStreamTracerFactory = private final ServerStreamTracer.Factory serverStreamTracerFactory =
mock(ServerStreamTracer.Factory.class); mock(ServerStreamTracer.Factory.class);
private final ServerStreamTracer serverStreamTracer = spy(new ServerStreamTracer() {}); private final TestServerStreamTracer serverStreamTracer1 = new TestServerStreamTracer();
private final TestServerStreamTracer serverStreamTracer2 = new TestServerStreamTracer();
@Rule @Rule
public ExpectedException thrown = ExpectedException.none(); public ExpectedException thrown = ExpectedException.none();
@ -174,9 +174,11 @@ public abstract class AbstractTransportTest {
server = newServer(Arrays.asList(serverStreamTracerFactory)); server = newServer(Arrays.asList(serverStreamTracerFactory));
when(clientStreamTracerFactory when(clientStreamTracerFactory
.newClientStreamTracer(any(CallOptions.class), any(Metadata.class))) .newClientStreamTracer(any(CallOptions.class), any(Metadata.class)))
.thenReturn(clientStreamTracer); .thenReturn(clientStreamTracer1)
.thenReturn(clientStreamTracer2);
when(serverStreamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class))) when(serverStreamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class)))
.thenReturn(serverStreamTracer); .thenReturn(serverStreamTracer1)
.thenReturn(serverStreamTracer2);
callOptions = CallOptions.DEFAULT.withStreamTracerFactory(clientStreamTracerFactory); callOptions = CallOptions.DEFAULT.withStreamTracerFactory(clientStreamTracerFactory);
} }
@ -402,8 +404,10 @@ public abstract class AbstractTransportTest {
verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture()); verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture());
assertFalse(statusCaptor.getValue().isOk()); assertFalse(statusCaptor.getValue().isOk());
if (metricsExpected()) { if (metricsExpected()) {
verify(clientStreamTracer, timeout(TIMEOUT_MS)).streamClosed(same(status)); assertTrue(clientStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
verify(serverStreamTracer, timeout(TIMEOUT_MS)).streamClosed(same(statusCaptor.getValue())); assertSame(status, clientStreamTracer1.getStatus());
assertTrue(serverStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
assertSame(statusCaptor.getValue(), serverStreamTracer1.getStatus());
} }
} }
@ -437,8 +441,10 @@ public abstract class AbstractTransportTest {
.closed(statusCaptor.capture(), any(Metadata.class)); .closed(statusCaptor.capture(), any(Metadata.class));
assertFalse(statusCaptor.getValue().isOk()); assertFalse(statusCaptor.getValue().isOk());
if (metricsExpected()) { if (metricsExpected()) {
verify(clientStreamTracer, timeout(TIMEOUT_MS)).streamClosed(same(statusCaptor.getValue())); assertTrue(clientStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
verify(serverStreamTracer, timeout(TIMEOUT_MS)).streamClosed(same(shutdownStatus)); assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus());
assertTrue(serverStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
assertSame(shutdownStatus, serverStreamTracer1.getStatus());
} }
// Generally will be same status provided to shutdownNow, but InProcessTransport can't // Generally will be same status provided to shutdownNow, but InProcessTransport can't
@ -506,7 +512,7 @@ public abstract class AbstractTransportTest {
@Test @Test
public void newStream_duringShutdown() throws Exception { public void newStream_duringShutdown() throws Exception {
InOrder inOrder = inOrder(clientStreamTracerFactory, clientStreamTracer, serverStreamTracer); InOrder inOrder = inOrder(clientStreamTracerFactory);
server.start(serverListener); server.start(serverListener);
client = newClientTransport(server); client = newClientTransport(server);
runIfNotNull(client.start(mockClientTransportListener)); runIfNotNull(client.start(mockClientTransportListener));
@ -531,7 +537,7 @@ public abstract class AbstractTransportTest {
.closed(statusCaptor.capture(), any(Metadata.class)); .closed(statusCaptor.capture(), any(Metadata.class));
assertCodeEquals(Status.UNAVAILABLE, statusCaptor.getValue()); assertCodeEquals(Status.UNAVAILABLE, statusCaptor.getValue());
if (metricsExpected()) { if (metricsExpected()) {
inOrder.verify(clientStreamTracer).streamClosed(same(statusCaptor.getValue())); assertSame(statusCaptor.getValue(), clientStreamTracer2.getStatus());
} }
// Make sure earlier stream works. // Make sure earlier stream works.
@ -568,8 +574,9 @@ public abstract class AbstractTransportTest {
if (metricsExpected()) { if (metricsExpected()) {
verify(clientStreamTracerFactory).newClientStreamTracer( verify(clientStreamTracerFactory).newClientStreamTracer(
any(CallOptions.class), any(Metadata.class)); any(CallOptions.class), any(Metadata.class));
verify(clientStreamTracer).streamClosed(same(statusCaptor.getValue())); assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus());
verifyZeroInteractions(serverStreamTracerFactory); // Assert no interactions
assertNull(serverStreamTracer1.getServerCall());
} }
} }
@ -624,8 +631,8 @@ public abstract class AbstractTransportTest {
@Test @Test
@SuppressWarnings("deprecation") @SuppressWarnings("deprecation")
public void basicStream() throws Exception { public void basicStream() throws Exception {
InOrder clientInOrder = inOrder(clientStreamTracerFactory, clientStreamTracer); InOrder clientInOrder = inOrder(clientStreamTracerFactory);
InOrder serverInOrder = inOrder(serverStreamTracerFactory, serverStreamTracer); InOrder serverInOrder = inOrder(serverStreamTracerFactory);
server.start(serverListener); server.start(serverListener);
client = newClientTransport(server); client = newClientTransport(server);
runIfNotNull(client.start(mockClientTransportListener)); runIfNotNull(client.start(mockClientTransportListener));
@ -651,7 +658,7 @@ public abstract class AbstractTransportTest {
StreamCreation serverStreamCreation StreamCreation serverStreamCreation
= serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS);
if (metricsExpected()) { if (metricsExpected()) {
verify(clientStreamTracer, timeout(TIMEOUT_MS)).outboundHeaders(); assertTrue(clientStreamTracer1.getOutboundHeaders());
} }
assertEquals(methodDescriptor.getFullMethodName(), serverStreamCreation.method); assertEquals(methodDescriptor.getFullMethodName(), serverStreamCreation.method);
assertEquals(Lists.newArrayList(clientHeadersCopy.getAll(asciiKey)), assertEquals(Lists.newArrayList(clientHeadersCopy.getAll(asciiKey)),
@ -675,15 +682,16 @@ public abstract class AbstractTransportTest {
assertTrue(clientStream.isReady()); assertTrue(clientStream.isReady());
clientStream.writeMessage(methodDescriptor.streamRequest("Hello!")); clientStream.writeMessage(methodDescriptor.streamRequest("Hello!"));
if (metricsExpected()) { if (metricsExpected()) {
clientInOrder.verify(clientStreamTracer).outboundMessage(); assertThat(clientStreamTracer1.getOutboundMessageCount()).isGreaterThan(0);
} }
clientStream.flush(); clientStream.flush();
verify(mockServerStreamListener, timeout(TIMEOUT_MS)).messageRead(inputStreamCaptor.capture()); verify(mockServerStreamListener, timeout(TIMEOUT_MS)).messageRead(inputStreamCaptor.capture());
if (metricsExpected()) { if (metricsExpected()) {
clientInOrder.verify(clientStreamTracer).outboundWireSize(anyLong()); assertThat(clientStreamTracer1.getOutboundMessageCount()).isGreaterThan(0);
clientInOrder.verify(clientStreamTracer).outboundUncompressedSize(anyLong()); assertThat(clientStreamTracer1.getOutboundWireSize()).isGreaterThan(0L);
serverInOrder.verify(serverStreamTracer).inboundMessage(); assertThat(clientStreamTracer1.getOutboundUncompressedSize()).isGreaterThan(0L);
assertEquals(1, serverStreamTracer1.getInboundMessageCount());
} }
assertEquals("Hello!", methodDescriptor.parseRequest(inputStreamCaptor.getValue())); assertEquals("Hello!", methodDescriptor.parseRequest(inputStreamCaptor.getValue()));
inputStreamCaptor.getValue().close(); inputStreamCaptor.getValue().close();
@ -692,8 +700,8 @@ public abstract class AbstractTransportTest {
verify(mockServerStreamListener, timeout(TIMEOUT_MS)).halfClosed(); verify(mockServerStreamListener, timeout(TIMEOUT_MS)).halfClosed();
if (metricsExpected()) { if (metricsExpected()) {
serverInOrder.verify(serverStreamTracer).inboundWireSize(anyLong()); assertThat(serverStreamTracer1.getInboundWireSize()).isGreaterThan(0L);
serverInOrder.verify(serverStreamTracer, atLeast(1)).inboundUncompressedSize(anyLong()); assertThat(serverStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L);
} }
Metadata serverHeaders = new Metadata(); Metadata serverHeaders = new Metadata();
@ -715,21 +723,21 @@ public abstract class AbstractTransportTest {
assertTrue(serverStream.isReady()); assertTrue(serverStream.isReady());
serverStream.writeMessage(methodDescriptor.streamResponse("Hi. Who are you?")); serverStream.writeMessage(methodDescriptor.streamResponse("Hi. Who are you?"));
if (metricsExpected()) { if (metricsExpected()) {
serverInOrder.verify(serverStreamTracer).outboundMessage(); assertEquals(1, serverStreamTracer1.getOutboundMessageCount());
} }
serverStream.flush(); serverStream.flush();
verify(mockClientStreamListener, timeout(TIMEOUT_MS)).messageRead(inputStreamCaptor.capture()); verify(mockClientStreamListener, timeout(TIMEOUT_MS)).messageRead(inputStreamCaptor.capture());
if (metricsExpected()) { if (metricsExpected()) {
serverInOrder.verify(serverStreamTracer).outboundWireSize(anyLong()); assertThat(serverStreamTracer1.getOutboundWireSize()).isGreaterThan(0L);
serverInOrder.verify(serverStreamTracer, atLeast(1)).outboundUncompressedSize(anyLong()); assertThat(serverStreamTracer1.getOutboundUncompressedSize()).isGreaterThan(0L);
clientInOrder.verify(clientStreamTracer).inboundHeaders(); assertTrue(clientStreamTracer1.getInboundHeaders());
clientInOrder.verify(clientStreamTracer).inboundMessage(); assertThat(clientStreamTracer1.getInboundMessageCount()).isGreaterThan(0);
} }
assertEquals("Hi. Who are you?", methodDescriptor.parseResponse(inputStreamCaptor.getValue())); assertEquals("Hi. Who are you?", methodDescriptor.parseResponse(inputStreamCaptor.getValue()));
if (metricsExpected()) { if (metricsExpected()) {
clientInOrder.verify(clientStreamTracer).inboundWireSize(anyLong()); assertThat(clientStreamTracer1.getInboundWireSize()).isGreaterThan(0L);
clientInOrder.verify(clientStreamTracer, atLeast(1)).inboundUncompressedSize(anyLong()); assertThat(clientStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L);
} }
inputStreamCaptor.getValue().close(); inputStreamCaptor.getValue().close();
@ -741,14 +749,14 @@ public abstract class AbstractTransportTest {
trailers.put(binaryKey, "äbinarytrailers"); trailers.put(binaryKey, "äbinarytrailers");
serverStream.close(status, trailers); serverStream.close(status, trailers);
if (metricsExpected()) { if (metricsExpected()) {
serverInOrder.verify(serverStreamTracer).streamClosed(same(status)); assertSame(status, serverStreamTracer1.getStatus());
} }
verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture()); verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture());
assertCodeEquals(Status.OK, statusCaptor.getValue()); assertCodeEquals(Status.OK, statusCaptor.getValue());
verify(mockClientStreamListener, timeout(TIMEOUT_MS)) verify(mockClientStreamListener, timeout(TIMEOUT_MS))
.closed(statusCaptor.capture(), metadataCaptor.capture()); .closed(statusCaptor.capture(), metadataCaptor.capture());
if (metricsExpected()) { if (metricsExpected()) {
clientInOrder.verify(clientStreamTracer).streamClosed(same(statusCaptor.getValue())); assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus());
} }
assertEquals(status.getCode(), statusCaptor.getValue().getCode()); assertEquals(status.getCode(), statusCaptor.getValue().getCode());
assertEquals(status.getDescription(), statusCaptor.getValue().getDescription()); assertEquals(status.getDescription(), statusCaptor.getValue().getDescription());
@ -808,12 +816,10 @@ public abstract class AbstractTransportTest {
assertEquals(status.getCode(), statusCaptor.getValue().getCode()); assertEquals(status.getCode(), statusCaptor.getValue().getCode());
assertEquals(status.getDescription(), statusCaptor.getValue().getDescription()); assertEquals(status.getDescription(), statusCaptor.getValue().getDescription());
if (metricsExpected()) { if (metricsExpected()) {
verify(clientStreamTracer).outboundHeaders(); assertTrue(clientStreamTracer1.getOutboundHeaders());
verify(clientStreamTracer).inboundHeaders(); assertTrue(clientStreamTracer1.getInboundHeaders());
verify(clientStreamTracer).streamClosed(same(statusCaptor.getValue())); assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus());
verify(serverStreamTracer).streamClosed(same(status)); assertSame(status, serverStreamTracer1.getStatus());
verifyNoMoreInteractions(clientStreamTracer);
verifyNoMoreInteractions(serverStreamTracer);
} }
} }
@ -846,12 +852,10 @@ public abstract class AbstractTransportTest {
assertEquals("Hello. Goodbye.", statusCaptor.getValue().getDescription()); assertEquals("Hello. Goodbye.", statusCaptor.getValue().getDescription());
assertNull(statusCaptor.getValue().getCause()); assertNull(statusCaptor.getValue().getCause());
if (metricsExpected()) { if (metricsExpected()) {
verify(clientStreamTracer).outboundHeaders(); assertTrue(clientStreamTracer1.getOutboundHeaders());
verify(clientStreamTracer).inboundHeaders(); assertTrue(clientStreamTracer1.getInboundHeaders());
verify(clientStreamTracer).streamClosed(same(statusCaptor.getValue())); assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus());
verify(serverStreamTracer).streamClosed(same(status)); assertSame(status, serverStreamTracer1.getStatus());
verifyNoMoreInteractions(clientStreamTracer);
verifyNoMoreInteractions(serverStreamTracer);
} }
} }
@ -891,11 +895,9 @@ public abstract class AbstractTransportTest {
assertEquals(Lists.newArrayList(trailers.getAll(binaryKey)), assertEquals(Lists.newArrayList(trailers.getAll(binaryKey)),
Lists.newArrayList(metadataCaptor.getValue().getAll(binaryKey))); Lists.newArrayList(metadataCaptor.getValue().getAll(binaryKey)));
if (metricsExpected()) { if (metricsExpected()) {
verify(clientStreamTracer).outboundHeaders(); assertTrue(clientStreamTracer1.getOutboundHeaders());
verify(clientStreamTracer).streamClosed(same(statusCaptor.getValue())); assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus());
verify(serverStreamTracer).streamClosed(same(status)); assertSame(status, serverStreamTracer1.getStatus());
verifyNoMoreInteractions(clientStreamTracer);
verifyNoMoreInteractions(serverStreamTracer);
} }
} }
@ -926,11 +928,9 @@ public abstract class AbstractTransportTest {
assertEquals(status.getDescription(), statusCaptor.getValue().getDescription()); assertEquals(status.getDescription(), statusCaptor.getValue().getDescription());
assertNull(statusCaptor.getValue().getCause()); assertNull(statusCaptor.getValue().getCause());
if (metricsExpected()) { if (metricsExpected()) {
verify(clientStreamTracer).outboundHeaders(); assertTrue(clientStreamTracer1.getOutboundHeaders());
verify(clientStreamTracer).streamClosed(same(statusCaptor.getValue())); assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus());
verify(serverStreamTracer).streamClosed(same(status)); assertSame(status, serverStreamTracer1.getStatus());
verifyNoMoreInteractions(clientStreamTracer);
verifyNoMoreInteractions(serverStreamTracer);
} }
} }
@ -964,11 +964,9 @@ public abstract class AbstractTransportTest {
verify(mockServerStreamListener, never()).closed(any(Status.class)); verify(mockServerStreamListener, never()).closed(any(Status.class));
verify(mockClientStreamListener, never()).closed(any(Status.class), any(Metadata.class)); verify(mockClientStreamListener, never()).closed(any(Status.class), any(Metadata.class));
if (metricsExpected()) { if (metricsExpected()) {
verify(clientStreamTracer).outboundHeaders(); assertTrue(clientStreamTracer1.getOutboundHeaders());
verify(clientStreamTracer).streamClosed(same(status)); assertSame(status, clientStreamTracer1.getStatus());
verify(serverStreamTracer).streamClosed(same(statusCaptor.getValue())); assertSame(statusCaptor.getValue(), serverStreamTracer1.getStatus());
verifyNoMoreInteractions(clientStreamTracer);
verifyNoMoreInteractions(serverStreamTracer);
} }
} }
@ -1027,20 +1025,19 @@ public abstract class AbstractTransportTest {
serverStream.close(Status.OK, new Metadata()); serverStream.close(Status.OK, new Metadata());
if (metricsExpected()) { if (metricsExpected()) {
verify(clientStreamTracer).outboundHeaders(); assertTrue(clientStreamTracer1.getOutboundHeaders());
verify(clientStreamTracer).inboundHeaders(); assertTrue(clientStreamTracer1.getInboundHeaders());
verify(clientStreamTracer).inboundMessage(); assertEquals(1, clientStreamTracer1.getInboundMessageCount());
verify(clientStreamTracer).inboundWireSize(anyLong()); assertThat(clientStreamTracer1.getInboundWireSize()).isGreaterThan(0L);
verify(clientStreamTracer, atLeast(1)).inboundUncompressedSize(anyLong()); assertThat(clientStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L);
verify(clientStreamTracer).streamClosed(same(status)); assertSame(status, clientStreamTracer1.getStatus());
verify(serverStreamTracer).outboundMessage(); assertEquals(1, serverStreamTracer1.getOutboundMessageCount());
verify(serverStreamTracer).outboundWireSize(anyLong()); assertThat(serverStreamTracer1.getOutboundWireSize()).isGreaterThan(0L);
verify(serverStreamTracer, atLeast(1)).outboundUncompressedSize(anyLong()); assertThat(serverStreamTracer1.getOutboundUncompressedSize()).isGreaterThan(0L);
// There is a race between client cancelling and server closing. The final status seen by the // There is a race between client cancelling and server closing. The final status seen by the
// server is non-deterministic. // server is non-deterministic.
verify(serverStreamTracer, timeout(TIMEOUT_MS)).streamClosed(any(Status.class)); assertTrue(serverStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
verifyNoMoreInteractions(clientStreamTracer); assertNotNull(serverStreamTracer1.getStatus());
verifyNoMoreInteractions(serverStreamTracer);
} }
} }
@ -1075,12 +1072,10 @@ public abstract class AbstractTransportTest {
if (metricsExpected()) { if (metricsExpected()) {
verify(clientStreamTracerFactory).newClientStreamTracer( verify(clientStreamTracerFactory).newClientStreamTracer(
any(CallOptions.class), any(Metadata.class)); any(CallOptions.class), any(Metadata.class));
verify(clientStreamTracer).outboundHeaders(); assertTrue(clientStreamTracer1.getOutboundHeaders());
verify(clientStreamTracer).streamClosed(same(statusCaptor.getValue())); assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus());
verify(serverStreamTracerFactory).newServerStreamTracer(anyString(), any(Metadata.class)); verify(serverStreamTracerFactory).newServerStreamTracer(anyString(), any(Metadata.class));
verify(serverStreamTracer).streamClosed(same(status)); assertSame(status, serverStreamTracer1.getStatus());
verifyNoMoreInteractions(clientStreamTracer);
verifyNoMoreInteractions(serverStreamTracer);
} }
// Second cancellation shouldn't trigger additional callbacks // Second cancellation shouldn't trigger additional callbacks

View File

@ -0,0 +1,140 @@
/*
* Copyright 2017, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal.testing;
import io.grpc.ClientStreamTracer;
import io.grpc.Status;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* A {@link ClientStreamTracer} suitable for testing.
*/
public class TestClientStreamTracer extends ClientStreamTracer implements TestStreamTracer {
private final TestBaseStreamTracer delegate = new TestBaseStreamTracer();
protected final AtomicBoolean outboundHeadersCalled = new AtomicBoolean();
protected final AtomicBoolean inboundHeadersCalled = new AtomicBoolean();
@Override
public void await() throws InterruptedException {
delegate.await();
}
@Override
public boolean await(long timeout, TimeUnit timeUnit) throws InterruptedException {
return delegate.await(timeout, timeUnit);
}
/**
* Returns if {@link ClientStreamTracer#inboundHeaders} has been called.
*/
public boolean getInboundHeaders() {
return inboundHeadersCalled.get();
}
/**
* Returns if {@link ClientStreamTracer#outboundHeaders} has been called.
*/
public boolean getOutboundHeaders() {
return outboundHeadersCalled.get();
}
@Override
public int getInboundMessageCount() {
return delegate.getInboundMessageCount();
}
@Override
public Status getStatus() {
return delegate.getStatus();
}
@Override
public long getInboundWireSize() {
return delegate.getInboundWireSize();
}
@Override
public long getInboundUncompressedSize() {
return delegate.getInboundUncompressedSize();
}
@Override
public int getOutboundMessageCount() {
return delegate.getOutboundMessageCount();
}
@Override
public long getOutboundWireSize() {
return delegate.getOutboundWireSize();
}
@Override
public long getOutboundUncompressedSize() {
return delegate.getOutboundUncompressedSize();
}
@Override
public void outboundWireSize(long bytes) {
delegate.outboundWireSize(bytes);
}
@Override
public void inboundWireSize(long bytes) {
delegate.inboundWireSize(bytes);
}
@Override
public void outboundUncompressedSize(long bytes) {
delegate.outboundUncompressedSize(bytes);
}
@Override
public void inboundUncompressedSize(long bytes) {
delegate.inboundUncompressedSize(bytes);
}
@Override
public void streamClosed(Status status) {
delegate.streamClosed(status);
}
@Override
public void inboundMessage() {
delegate.inboundMessage();
}
@Override
public void outboundMessage() {
delegate.outboundMessage();
}
@Override
public void outboundHeaders() {
if (!outboundHeadersCalled.compareAndSet(false, true)
&& delegate.failDuplicateCallbacks.get()) {
throw new AssertionError("outboundHeaders called more than once");
}
}
@Override
public void inboundHeaders() {
if (!inboundHeadersCalled.compareAndSet(false, true) && delegate.failDuplicateCallbacks.get()) {
throw new AssertionError("inboundHeaders called more than once");
}
}
}

View File

@ -0,0 +1,126 @@
/*
* Copyright 2017, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal.testing;
import io.grpc.ServerCall;
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
/**
* A {@link ServerStreamTracer} suitable for testing.
*/
public class TestServerStreamTracer extends ServerStreamTracer implements TestStreamTracer {
private final TestBaseStreamTracer delegate = new TestBaseStreamTracer();
protected final AtomicReference<ServerCall<?,?>> serverCall =
new AtomicReference<ServerCall<?,?>>();
@Override
public void await() throws InterruptedException {
delegate.await();
}
@Override
public boolean await(long timeout, TimeUnit timeUnit) throws InterruptedException {
return delegate.await(timeout, timeUnit);
}
/**
* Returns the ServerCall passed to {@link ServerStreamTracer#serverCallStarted}.
*/
public ServerCall<?, ?> getServerCall() {
return serverCall.get();
}
@Override
public int getInboundMessageCount() {
return delegate.getInboundMessageCount();
}
@Override
public Status getStatus() {
return delegate.getStatus();
}
@Override
public long getInboundWireSize() {
return delegate.getInboundWireSize();
}
@Override
public long getInboundUncompressedSize() {
return delegate.getInboundUncompressedSize();
}
@Override
public int getOutboundMessageCount() {
return delegate.getOutboundMessageCount();
}
@Override
public long getOutboundWireSize() {
return delegate.getOutboundWireSize();
}
@Override
public long getOutboundUncompressedSize() {
return delegate.getOutboundUncompressedSize();
}
@Override
public void outboundWireSize(long bytes) {
delegate.outboundWireSize(bytes);
}
@Override
public void inboundWireSize(long bytes) {
delegate.inboundWireSize(bytes);
}
@Override
public void outboundUncompressedSize(long bytes) {
delegate.outboundUncompressedSize(bytes);
}
@Override
public void inboundUncompressedSize(long bytes) {
delegate.inboundUncompressedSize(bytes);
}
@Override
public void streamClosed(Status status) {
delegate.streamClosed(status);
}
@Override
public void inboundMessage() {
delegate.inboundMessage();
}
@Override
public void outboundMessage() {
delegate.outboundMessage();
}
@Override
public void serverCallStarted(ServerCall<?, ?> call) {
if (!serverCall.compareAndSet(null, call) && delegate.failDuplicateCallbacks.get()) {
throw new AssertionError("serverCallStarted called more than once");
}
}
}

View File

@ -0,0 +1,179 @@
/*
* Copyright 2017, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal.testing;
import io.grpc.Status;
import io.grpc.StreamTracer;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
/**
* A {@link StreamTracer} suitable for testing.
*/
public interface TestStreamTracer {
/**
* Waits for the stream to be done.
*/
void await() throws InterruptedException;
/**
* Waits for the stream to be done.
*/
boolean await(long timeout, TimeUnit timeUnit) throws InterruptedException;
/**
* Returns how many times {@link StreamTracer#inboundMessage} has been called.
*/
int getInboundMessageCount();
/**
* Returns how many times {@link StreamTracer#outboundMessage} has been called.
*/
int getOutboundMessageCount();
/**
* Returns the status passed to {@link StreamTracer#streamClosed}.
*/
Status getStatus();
/**
* Returns to sum of all sizes passed to {@link StreamTracer#inboundWireSize}.
*/
long getInboundWireSize();
/**
* Returns to sum of all sizes passed to {@link StreamTracer#inboundUncompressedSize}.
*/
long getInboundUncompressedSize();
/**
* Returns to sum of all sizes passed to {@link StreamTracer#outboundWireSize}.
*/
long getOutboundWireSize();
/**
* Returns to sum of al sizes passed to {@link StreamTracer#outboundUncompressedSize}.
*/
long getOutboundUncompressedSize();
/**
* A {@link StreamTracer} suitable for testing.
*/
public static class TestBaseStreamTracer extends StreamTracer implements TestStreamTracer {
protected final AtomicLong outboundWireSize = new AtomicLong();
protected final AtomicLong inboundWireSize = new AtomicLong();
protected final AtomicLong outboundUncompressedSize = new AtomicLong();
protected final AtomicLong inboundUncompressedSize = new AtomicLong();
protected final AtomicInteger inboundMessageCount = new AtomicInteger();
protected final AtomicInteger outboundMessageCount = new AtomicInteger();
protected final AtomicReference<Status> streamClosedStatus = new AtomicReference<Status>();
protected final CountDownLatch streamClosed = new CountDownLatch(1);
protected final AtomicBoolean failDuplicateCallbacks = new AtomicBoolean(true);
@Override
public void await() throws InterruptedException {
streamClosed.await();
}
@Override
public boolean await(long timeout, TimeUnit timeUnit) throws InterruptedException {
return streamClosed.await(timeout, timeUnit);
}
@Override
public int getInboundMessageCount() {
return inboundMessageCount.get();
}
@Override
public int getOutboundMessageCount() {
return outboundMessageCount.get();
}
@Override
public Status getStatus() {
return streamClosedStatus.get();
}
@Override
public long getInboundWireSize() {
return inboundWireSize.get();
}
@Override
public long getInboundUncompressedSize() {
return inboundUncompressedSize.get();
}
@Override
public long getOutboundWireSize() {
return outboundWireSize.get();
}
@Override
public long getOutboundUncompressedSize() {
return outboundUncompressedSize.get();
}
@Override
public void outboundWireSize(long bytes) {
outboundWireSize.addAndGet(bytes);
}
@Override
public void inboundWireSize(long bytes) {
inboundWireSize.addAndGet(bytes);
}
@Override
public void outboundUncompressedSize(long bytes) {
outboundUncompressedSize.addAndGet(bytes);
}
@Override
public void inboundUncompressedSize(long bytes) {
inboundUncompressedSize.addAndGet(bytes);
}
@Override
public void streamClosed(Status status) {
if (!streamClosedStatus.compareAndSet(null, status)) {
if (failDuplicateCallbacks.get()) {
throw new AssertionError("streamClosed called more than once");
}
} else {
streamClosed.countDown();
}
}
@Override
public void inboundMessage() {
inboundMessageCount.incrementAndGet();
}
@Override
public void outboundMessage() {
outboundMessageCount.incrementAndGet();
}
}
}