diff --git a/core/src/main/java/io/grpc/StreamTracer.java b/core/src/main/java/io/grpc/StreamTracer.java index 63b4c97a46..bf1b33f5f4 100644 --- a/core/src/main/java/io/grpc/StreamTracer.java +++ b/core/src/main/java/io/grpc/StreamTracer.java @@ -16,6 +16,7 @@ package io.grpc; +import com.google.errorprone.annotations.DoNotMock; import javax.annotation.concurrent.ThreadSafe; /** @@ -23,6 +24,7 @@ import javax.annotation.concurrent.ThreadSafe; */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2861") @ThreadSafe +@DoNotMock public abstract class StreamTracer { /** * Stream is closed. This will be called exactly once. diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index 96b75aa0f0..660ecef79d 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -17,14 +17,13 @@ package io.grpc.internal; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; import io.grpc.Attributes; import io.grpc.CallOptions; @@ -36,6 +35,7 @@ import io.grpc.Status.Code; import io.grpc.StreamTracer; import io.grpc.internal.AbstractClientStream.TransportState; import io.grpc.internal.MessageFramerTest.ByteWritableBuffer; +import io.grpc.internal.testing.TestClientStreamTracer; import java.io.ByteArrayInputStream; import org.junit.Before; import org.junit.Rule; @@ -213,7 +213,7 @@ public class AbstractClientStreamTest { @Test public void getRequest() { AbstractClientStream.Sink sink = mock(AbstractClientStream.Sink.class); - final ClientStreamTracer tracer = spy(new ClientStreamTracer() {}); + final TestClientStreamTracer tracer = new TestClientStreamTracer(); ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.Factory() { @Override @@ -237,10 +237,9 @@ public class AbstractClientStreamTest { // GET requests don't have BODY. verify(sink, never()) .writeFrame(any(WritableBuffer.class), any(Boolean.class), any(Boolean.class)); - verify(tracer).outboundMessage(); - verify(tracer).outboundWireSize(1); - verify(tracer).outboundUncompressedSize(1); - verifyNoMoreInteractions(tracer); + assertEquals(1, tracer.getOutboundMessageCount()); + assertEquals(1, tracer.getOutboundWireSize()); + assertEquals(1, tracer.getOutboundUncompressedSize()); } /** diff --git a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java index dbc0325460..9310e7b359 100644 --- a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java @@ -20,7 +20,6 @@ import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.anyInt; -import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -36,6 +35,7 @@ import io.grpc.StatusRuntimeException; import io.grpc.StreamTracer; import io.grpc.internal.MessageDeframer.Listener; import io.grpc.internal.MessageDeframer.SizeEnforcingInputStream; +import io.grpc.internal.testing.TestStreamTracer.TestBaseStreamTracer; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -60,7 +60,7 @@ public class MessageDeframerTest { @Rule public final ExpectedException thrown = ExpectedException.none(); private Listener listener = mock(Listener.class); - private StreamTracer tracer = mock(StreamTracer.class); + private TestBaseStreamTracer tracer = new TestBaseStreamTracer(); private StatsTraceContext statsTraceCtx = new StatsTraceContext(new StreamTracer[]{tracer}); private ArgumentCaptor wireSizeCaptor = ArgumentCaptor.forClass(Long.class); private ArgumentCaptor uncompressedSizeCaptor = ArgumentCaptor.forClass(Long.class); @@ -374,23 +374,9 @@ public class MessageDeframerTest { private void checkStats( int messagesReceived, long wireBytesReceived, long uncompressedBytesReceived) { - long actualWireSize = 0; - long actualUncompressedSize = 0; - - verify(tracer, times(messagesReceived)).inboundMessage(); - verify(tracer, atLeast(0)).inboundWireSize(wireSizeCaptor.capture()); - for (Long portion : wireSizeCaptor.getAllValues()) { - actualWireSize += portion; - } - - verify(tracer, atLeast(0)).inboundUncompressedSize(uncompressedSizeCaptor.capture()); - for (Long portion : uncompressedSizeCaptor.getAllValues()) { - actualUncompressedSize += portion; - } - - verifyNoMoreInteractions(tracer); - assertEquals(wireBytesReceived, actualWireSize); - assertEquals(uncompressedBytesReceived, actualUncompressedSize); + assertEquals(messagesReceived, tracer.getInboundMessageCount()); + assertEquals(wireBytesReceived, tracer.getInboundWireSize()); + assertEquals(uncompressedBytesReceived, tracer.getInboundUncompressedSize()); } private static List bytes(ArgumentCaptor captor) { diff --git a/core/src/test/java/io/grpc/internal/MessageFramerTest.java b/core/src/test/java/io/grpc/internal/MessageFramerTest.java index 4554d4c278..68ed71ff6f 100644 --- a/core/src/test/java/io/grpc/internal/MessageFramerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageFramerTest.java @@ -20,7 +20,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -28,6 +27,7 @@ import static org.mockito.Mockito.verifyZeroInteractions; import io.grpc.Codec; import io.grpc.StreamTracer; +import io.grpc.internal.testing.TestStreamTracer.TestBaseStreamTracer; import java.io.BufferedInputStream; import java.io.ByteArrayInputStream; import java.nio.ByteBuffer; @@ -49,8 +49,8 @@ import org.mockito.MockitoAnnotations; public class MessageFramerTest { @Mock private MessageFramer.Sink sink; - @Mock - private StreamTracer tracer; + + private final TestBaseStreamTracer tracer = new TestBaseStreamTracer(); private MessageFramer framer; @Captor @@ -371,20 +371,9 @@ public class MessageFramerTest { long actualWireSize = 0; long actualUncompressedSize = 0; - verify(tracer, times(messagesSent)).outboundMessage(); - verify(tracer, atLeast(0)).outboundWireSize(wireSizeCaptor.capture()); - for (Long portion : wireSizeCaptor.getAllValues()) { - actualWireSize += portion; - } - - verify(tracer, atLeast(0)).outboundUncompressedSize(uncompressedSizeCaptor.capture()); - for (Long portion : uncompressedSizeCaptor.getAllValues()) { - actualUncompressedSize += portion; - } - - verifyNoMoreInteractions(tracer); - assertEquals(wireBytesSent, actualWireSize); - assertEquals(uncompressedBytesSent, actualUncompressedSize); + assertEquals(messagesSent, tracer.getOutboundMessageCount()); + assertEquals(uncompressedBytesSent, tracer.getOutboundUncompressedSize()); + assertEquals(wireBytesSent, tracer.getOutboundWireSize()); } static class ByteWritableBuffer implements WritableBuffer { diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index 4d2fc56417..e24d619fa5 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -36,7 +36,6 @@ import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -62,6 +61,7 @@ import io.grpc.ServiceDescriptor; import io.grpc.Status; import io.grpc.StringMarshaller; import io.grpc.internal.ServerImpl.JumpToApplicationThreadServerStreamListener; +import io.grpc.internal.testing.TestServerStreamTracer; import io.grpc.util.MutableHandlerRegistry; import java.io.ByteArrayInputStream; import java.io.File; @@ -129,12 +129,13 @@ public class ServerImplTest { @Mock private ServerStreamTracer.Factory streamTracerFactory; private List streamTracerFactories; - private final ServerStreamTracer streamTracer = spy(new ServerStreamTracer() { + private final TestServerStreamTracer streamTracer = new TestServerStreamTracer() { @Override public Context filterContext(Context context) { - return context.withValue(SERVER_TRACER_ADDED_KEY, "context added by tracer"); + Context newCtx = super.filterContext(context); + return newCtx.withValue(SERVER_TRACER_ADDED_KEY, "context added by tracer"); } - }); + }; @Mock private ObjectPool executorPool; private Builder builder = new Builder(); @@ -365,7 +366,7 @@ public class ServerImplTest { assertEquals("Method not found: Waiter/nonexist", status.getDescription()); verify(streamTracerFactory).newServerStreamTracer(eq("Waiter/nonexist"), same(requestHeaders)); - verify(streamTracer, never()).serverCallStarted(any(ServerCall.class)); + assertNull(streamTracer.getServerCall()); assertEquals(Status.Code.UNIMPLEMENTED, statusCaptor.getValue().getCode()); } @@ -435,7 +436,7 @@ public class ServerImplTest { assertEquals(1, executor.runDueTasks()); ServerCall call = callReference.get(); assertNotNull(call); - verify(streamTracer).serverCallStarted(same(call)); + assertSame(call, streamTracer.getServerCall()); verify(stream).getAuthority(); Context callContext = callContextReference.get(); assertNotNull(callContext); diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index b6dc9e87f6..1224f62e17 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -25,10 +25,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; @@ -66,12 +63,14 @@ import io.grpc.ServerInterceptors; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.StatusRuntimeException; -import io.grpc.StreamTracer; import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.AbstractServerImplBuilder; import io.grpc.internal.GrpcUtil; import io.grpc.internal.testing.StatsTestUtils.FakeStatsContextFactory; import io.grpc.internal.testing.StatsTestUtils.MetricsRecord; +import io.grpc.internal.testing.TestClientStreamTracer; +import io.grpc.internal.testing.TestServerStreamTracer; +import io.grpc.internal.testing.TestStreamTracer; import io.grpc.protobuf.ProtoUtils; import io.grpc.stub.ClientCallStreamObserver; import io.grpc.stub.ClientCalls; @@ -139,21 +138,32 @@ public abstract class AbstractInteropTest { private static final LinkedBlockingQueue serverStreamTracers = new LinkedBlockingQueue(); - private static class ServerStreamTracerInfo { + private static final class ServerStreamTracerInfo { final String fullMethodName; - final ServerStreamTracer tracer; + final InteropServerStreamTracer tracer; - ServerStreamTracerInfo(String fullMethodName, ServerStreamTracer tracer) { + ServerStreamTracerInfo(String fullMethodName, InteropServerStreamTracer tracer) { this.fullMethodName = fullMethodName; this.tracer = tracer; } + + private static final class InteropServerStreamTracer extends TestServerStreamTracer { + private volatile Context contextCapture; + + @Override + public Context filterContext(Context context) { + contextCapture = context; + return super.filterContext(context); + } + } } private static final ServerStreamTracer.Factory serverStreamTracerFactory = new ServerStreamTracer.Factory() { @Override public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { - ServerStreamTracer tracer = spy(new ServerStreamTracer() {}); + ServerStreamTracerInfo.InteropServerStreamTracer tracer + = new ServerStreamTracerInfo.InteropServerStreamTracer(); serverStreamTracers.add(new ServerStreamTracerInfo(fullMethodName, tracer)); return tracer; } @@ -200,14 +210,14 @@ public abstract class AbstractInteropTest { protected TestServiceGrpc.TestServiceBlockingStub blockingStub; protected TestServiceGrpc.TestServiceStub asyncStub; - private final LinkedBlockingQueue clientStreamTracers = - new LinkedBlockingQueue(); + private final LinkedBlockingQueue clientStreamTracers = + new LinkedBlockingQueue(); private final ClientStreamTracer.Factory clientStreamTracerFactory = new ClientStreamTracer.Factory() { @Override public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) { - ClientStreamTracer tracer = spy(new ClientStreamTracer() {}); + TestClientStreamTracer tracer = new TestClientStreamTracer(); clientStreamTracers.add(tracer); return tracer; } @@ -1655,22 +1665,26 @@ public abstract class AbstractInteropTest { assertMetrics(method, status, null, null); } - private void assertClientMetrics(String method, Status.Code status, + private void assertClientMetrics(String method, Status.Code code, Collection requests, Collection responses) { // Tracer-based stats - ClientStreamTracer tracer = clientStreamTracers.poll(); + TestClientStreamTracer tracer = clientStreamTracers.poll(); assertNotNull(tracer); - verify(tracer).outboundHeaders(); + assertTrue(tracer.getOutboundHeaders()); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); // assertClientMetrics() is called right after application receives status, // but streamClosed() may be called slightly later than that. So we need a timeout. - verify(tracer, timeout(5000)).streamClosed(statusCaptor.capture()); - assertEquals(status, statusCaptor.getValue().getCode()); + try { + assertTrue(tracer.await(5, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + assertEquals(code, tracer.getStatus().getCode()); // CensusStreamTracerModule records final status in interceptor, which is guaranteed to be done // before application receives status. MetricsRecord clientRecord = clientStatsCtxFactory.pollRecord(); - checkTags(clientRecord, false, method, status); + checkTags(clientRecord, false, method, code); if (requests != null && responses != null) { checkTracerMetrics(tracer, requests, responses); @@ -1682,7 +1696,7 @@ public abstract class AbstractInteropTest { assertClientMetrics(method, status, null, null); } - private void assertServerMetrics(String method, Status.Code status, + private void assertServerMetrics(String method, Status.Code code, Collection requests, Collection responses) { AssertionError checkFailure = null; boolean passed = false; @@ -1703,7 +1717,7 @@ public abstract class AbstractInteropTest { break; } try { - checkTags(serverRecord, true, method, status); + checkTags(serverRecord, true, method, code); if (requests != null && responses != null) { checkCensusMetrics(serverRecord, true, requests, responses); } @@ -1731,12 +1745,16 @@ public abstract class AbstractInteropTest { } try { assertEquals(method, tracerInfo.fullMethodName); - verify(tracerInfo.tracer).filterContext(any(Context.class)); + assertNotNull(tracerInfo.tracer.contextCapture); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); // On the server, streamClosed() may be called after the client receives the final status. // So we use a timeout. - verify(tracerInfo.tracer, timeout(1000)).streamClosed(statusCaptor.capture()); - assertEquals(status, statusCaptor.getValue().getCode()); + try { + assertTrue(tracerInfo.tracer.await(1, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + assertEquals(code, tracerInfo.tracer.getStatus().getCode()); if (requests != null && responses != null) { checkTracerMetrics(tracerInfo.tracer, responses, requests); } @@ -1768,11 +1786,11 @@ public abstract class AbstractInteropTest { } private static void checkTracerMetrics( - StreamTracer tracer, + TestStreamTracer tracer, Collection sentMessages, Collection receivedMessages) { - verify(tracer, times(sentMessages.size())).outboundMessage(); - verify(tracer, times(receivedMessages.size())).inboundMessage(); + assertEquals(sentMessages.size(), tracer.getOutboundMessageCount()); + assertEquals(receivedMessages.size(), tracer.getInboundMessageCount()); long uncompressedSentSize = 0; for (MessageLite msg : sentMessages) { @@ -1782,20 +1800,9 @@ public abstract class AbstractInteropTest { for (MessageLite msg : receivedMessages) { uncompressedReceivedSize += msg.getSerializedSize(); } - ArgumentCaptor outboundSizeCaptor = ArgumentCaptor.forClass(Long.class); - ArgumentCaptor inboundSizeCaptor = ArgumentCaptor.forClass(Long.class); - verify(tracer, atLeast(0)).outboundUncompressedSize(outboundSizeCaptor.capture()); - verify(tracer, atLeast(0)).inboundUncompressedSize(inboundSizeCaptor.capture()); - long recordedUncompressedOutboundSize = 0; - for (Long size : outboundSizeCaptor.getAllValues()) { - recordedUncompressedOutboundSize += size; - } - long recordedUncompressedInboundSize = 0; - for (Long size : inboundSizeCaptor.getAllValues()) { - recordedUncompressedInboundSize += size; - } - assertEquals(uncompressedSentSize, recordedUncompressedOutboundSize); - assertEquals(uncompressedReceivedSize, recordedUncompressedInboundSize); + + assertEquals(uncompressedSentSize, tracer.getOutboundUncompressedSize()); + assertEquals(uncompressedReceivedSize, tracer.getInboundUncompressedSize()); } private static void checkCensusMetrics(MetricsRecord record, boolean server, diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index ea47200f4a..891553dbe6 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -64,6 +64,7 @@ import io.grpc.internal.ServerStream; import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerTransportListener; import io.grpc.internal.StatsTraceContext; +import io.grpc.internal.testing.TestServerStreamTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; @@ -106,7 +107,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase> serverCall = + new AtomicReference>(); + + @Override + public void await() throws InterruptedException { + delegate.await(); + } + + @Override + public boolean await(long timeout, TimeUnit timeUnit) throws InterruptedException { + return delegate.await(timeout, timeUnit); + } + + /** + * Returns the ServerCall passed to {@link ServerStreamTracer#serverCallStarted}. + */ + public ServerCall getServerCall() { + return serverCall.get(); + } + + @Override + public int getInboundMessageCount() { + return delegate.getInboundMessageCount(); + } + + @Override + public Status getStatus() { + return delegate.getStatus(); + } + + @Override + public long getInboundWireSize() { + return delegate.getInboundWireSize(); + } + + @Override + public long getInboundUncompressedSize() { + return delegate.getInboundUncompressedSize(); + } + + @Override + public int getOutboundMessageCount() { + return delegate.getOutboundMessageCount(); + } + + @Override + public long getOutboundWireSize() { + return delegate.getOutboundWireSize(); + } + + @Override + public long getOutboundUncompressedSize() { + return delegate.getOutboundUncompressedSize(); + } + + @Override + public void outboundWireSize(long bytes) { + delegate.outboundWireSize(bytes); + } + + @Override + public void inboundWireSize(long bytes) { + delegate.inboundWireSize(bytes); + } + + @Override + public void outboundUncompressedSize(long bytes) { + delegate.outboundUncompressedSize(bytes); + } + + @Override + public void inboundUncompressedSize(long bytes) { + delegate.inboundUncompressedSize(bytes); + } + + @Override + public void streamClosed(Status status) { + delegate.streamClosed(status); + } + + @Override + public void inboundMessage() { + delegate.inboundMessage(); + } + + @Override + public void outboundMessage() { + delegate.outboundMessage(); + } + + @Override + public void serverCallStarted(ServerCall call) { + if (!serverCall.compareAndSet(null, call) && delegate.failDuplicateCallbacks.get()) { + throw new AssertionError("serverCallStarted called more than once"); + } + } +} diff --git a/testing/src/main/java/io/grpc/internal/testing/TestStreamTracer.java b/testing/src/main/java/io/grpc/internal/testing/TestStreamTracer.java new file mode 100644 index 0000000000..1a98a0f031 --- /dev/null +++ b/testing/src/main/java/io/grpc/internal/testing/TestStreamTracer.java @@ -0,0 +1,179 @@ +/* + * Copyright 2017, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal.testing; + +import io.grpc.Status; +import io.grpc.StreamTracer; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +/** + * A {@link StreamTracer} suitable for testing. + */ +public interface TestStreamTracer { + + /** + * Waits for the stream to be done. + */ + void await() throws InterruptedException; + + /** + * Waits for the stream to be done. + */ + boolean await(long timeout, TimeUnit timeUnit) throws InterruptedException; + + /** + * Returns how many times {@link StreamTracer#inboundMessage} has been called. + */ + int getInboundMessageCount(); + + /** + * Returns how many times {@link StreamTracer#outboundMessage} has been called. + */ + int getOutboundMessageCount(); + + /** + * Returns the status passed to {@link StreamTracer#streamClosed}. + */ + Status getStatus(); + + /** + * Returns to sum of all sizes passed to {@link StreamTracer#inboundWireSize}. + */ + long getInboundWireSize(); + + /** + * Returns to sum of all sizes passed to {@link StreamTracer#inboundUncompressedSize}. + */ + long getInboundUncompressedSize(); + + /** + * Returns to sum of all sizes passed to {@link StreamTracer#outboundWireSize}. + */ + long getOutboundWireSize(); + + /** + * Returns to sum of al sizes passed to {@link StreamTracer#outboundUncompressedSize}. + */ + long getOutboundUncompressedSize(); + + /** + * A {@link StreamTracer} suitable for testing. + */ + public static class TestBaseStreamTracer extends StreamTracer implements TestStreamTracer { + + protected final AtomicLong outboundWireSize = new AtomicLong(); + protected final AtomicLong inboundWireSize = new AtomicLong(); + protected final AtomicLong outboundUncompressedSize = new AtomicLong(); + protected final AtomicLong inboundUncompressedSize = new AtomicLong(); + protected final AtomicInteger inboundMessageCount = new AtomicInteger(); + protected final AtomicInteger outboundMessageCount = new AtomicInteger(); + protected final AtomicReference streamClosedStatus = new AtomicReference(); + protected final CountDownLatch streamClosed = new CountDownLatch(1); + protected final AtomicBoolean failDuplicateCallbacks = new AtomicBoolean(true); + + @Override + public void await() throws InterruptedException { + streamClosed.await(); + } + + @Override + public boolean await(long timeout, TimeUnit timeUnit) throws InterruptedException { + return streamClosed.await(timeout, timeUnit); + } + + @Override + public int getInboundMessageCount() { + return inboundMessageCount.get(); + } + + @Override + public int getOutboundMessageCount() { + return outboundMessageCount.get(); + } + + @Override + public Status getStatus() { + return streamClosedStatus.get(); + } + + @Override + public long getInboundWireSize() { + return inboundWireSize.get(); + } + + @Override + public long getInboundUncompressedSize() { + return inboundUncompressedSize.get(); + } + + @Override + public long getOutboundWireSize() { + return outboundWireSize.get(); + } + + @Override + public long getOutboundUncompressedSize() { + return outboundUncompressedSize.get(); + } + + @Override + public void outboundWireSize(long bytes) { + outboundWireSize.addAndGet(bytes); + } + + @Override + public void inboundWireSize(long bytes) { + inboundWireSize.addAndGet(bytes); + } + + @Override + public void outboundUncompressedSize(long bytes) { + outboundUncompressedSize.addAndGet(bytes); + } + + @Override + public void inboundUncompressedSize(long bytes) { + inboundUncompressedSize.addAndGet(bytes); + } + + @Override + public void streamClosed(Status status) { + if (!streamClosedStatus.compareAndSet(null, status)) { + if (failDuplicateCallbacks.get()) { + throw new AssertionError("streamClosed called more than once"); + } + } else { + streamClosed.countDown(); + } + } + + @Override + public void inboundMessage() { + inboundMessageCount.incrementAndGet(); + } + + @Override + public void outboundMessage() { + outboundMessageCount.incrementAndGet(); + } + } +}