From 7fc986e6d0e4b5ec6b05ce4e50fa151799114d14 Mon Sep 17 00:00:00 2001 From: Louis Ryan Date: Wed, 17 Feb 2016 15:55:32 -0800 Subject: [PATCH] Fix flakiness in Cascading cancellation tests Add explicit shutdown for other executors --- .../testing/integration/CascadingTest.java | 80 ++++++++++--------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/CascadingTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/CascadingTest.java index 7375dfc37e..60a4555d6c 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/CascadingTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/CascadingTest.java @@ -58,10 +58,9 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import java.io.IOException; -import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -75,23 +74,26 @@ public class CascadingTest { TestServiceGrpc.TestService service; private ManagedChannelImpl channel; private ServerImpl server; - private AtomicInteger depth; + private AtomicInteger nodeCount; private AtomicInteger observedCancellations; private AtomicInteger receivedCancellations; private TestServiceGrpc.TestServiceBlockingStub blockingStub; private TestServiceGrpc.TestServiceStub asyncStub; private ScheduledExecutorService scheduler; + private ExecutorService otherWork; @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); - channel = InProcessChannelBuilder.forName("channel").build(); - depth = new AtomicInteger(); + nodeCount = new AtomicInteger(); observedCancellations = new AtomicInteger(); receivedCancellations = new AtomicInteger(); + scheduler = Executors.newScheduledThreadPool(1); + // 
Use a cached thread pool as we need a thread for each blocked call + otherWork = Executors.newCachedThreadPool(); + channel = InProcessChannelBuilder.forName("channel").executor(otherWork).build(); blockingStub = TestServiceGrpc.newBlockingStub(channel); asyncStub = TestServiceGrpc.newStub(channel); - scheduler = Executors.newScheduledThreadPool(1); } @After @@ -99,6 +101,8 @@ public class CascadingTest { Context.ROOT.attach(); channel.shutdownNow(); server.shutdownNow(); + otherWork.shutdownNow(); + scheduler.shutdownNow(); } /** @@ -108,18 +112,21 @@ public class CascadingTest { @Test public void testCascadingCancellationViaOuterContextExpiration() throws Exception { startChainingServer(3); - Context.current().withDeadlineAfter(150, TimeUnit.MILLISECONDS, scheduler).attach(); + Context.current().withDeadlineAfter(500, TimeUnit.MILLISECONDS, scheduler).attach(); try { blockingStub.unaryCall(Messages.SimpleRequest.getDefaultInstance()); fail("Expected cancellation"); } catch (StatusRuntimeException sre) { // Wait for the workers to finish - Thread.sleep(500); Status status = Status.fromThrowable(sre); assertEquals(Status.Code.CANCELLED, status.getCode()); + // Wait for the channel to shutdown so we know all the calls have completed + channel.shutdown(); + channel.awaitTermination(5, TimeUnit.SECONDS); + // Should have 3 calls before timeout propagates - assertEquals(3, depth.get()); + assertEquals(3, nodeCount.get()); // Should have observed 2 cancellations responses from downstream servers assertEquals(2, observedCancellations.get()); @@ -135,19 +142,22 @@ public class CascadingTest { public void testCascadingCancellationViaMethodTimeout() throws Exception { startChainingServer(3); try { - blockingStub.withDeadlineAfter(150, TimeUnit.MILLISECONDS) + blockingStub.withDeadlineAfter(500, TimeUnit.MILLISECONDS) .unaryCall(Messages.SimpleRequest.getDefaultInstance()); fail("Expected cancellation"); } catch (StatusRuntimeException sre) { // Wait for the workers to 
finish - Thread.sleep(150); Status status = Status.fromThrowable(sre); // Outermost caller observes deadline exceeded, the descendant RPCs are cancelled so they // receive cancellation. assertEquals(Status.Code.DEADLINE_EXCEEDED, status.getCode()); + // Wait for the channel to shutdown so we know all the calls have completed + channel.shutdown(); + channel.awaitTermination(5, TimeUnit.SECONDS); + // Should have 3 calls before deadline propagates - assertEquals(3, depth.get()); + assertEquals(3, nodeCount.get()); // Server should have observed 2 cancellations from downstream calls assertEquals(2, observedCancellations.get()); // and received 2 cancellations @@ -161,18 +171,22 @@ public class CascadingTest { */ @Test public void testCascadingCancellationViaLeafFailure() throws Exception { - startCallTreeServer(); + startCallTreeServer(3); try { - // Use response size limit to control tree depth. + // Use the response size to control the depth of the call tree. blockingStub.unaryCall(Messages.SimpleRequest.newBuilder().setResponseSize(3).build()); fail("Expected abort"); } catch (StatusRuntimeException sre) { // Wait for the workers to finish - Thread.sleep(100); Status status = Status.fromThrowable(sre); // Outermost caller observes ABORTED propagating up from the failing leaf, // The descendant RPCs are cancelled so they receive CANCELLED. assertEquals(Status.Code.ABORTED, status.getCode()); + + // Wait for the channel to shutdown so we know all the calls have completed + channel.shutdown(); + channel.awaitTermination(5, TimeUnit.SECONDS); + // All nodes (15) except one edge of the tree (4) will be cancelled. 
assertEquals(11, observedCancellations.get()); assertEquals(11, receivedCancellations.get()); @@ -184,7 +198,6 @@ public class CascadingTest { */ private void startChainingServer(final int depthThreshold) throws IOException { - final Executor otherWork = Context.propagate(Executors.newCachedThreadPool()); server = InProcessServerBuilder.forName("channel").addService( ServerInterceptors.intercept(TestServiceGrpc.bindService(service), new ServerInterceptor() { @@ -200,13 +213,12 @@ public class CascadingTest { return new ServerCall.Listener() { @Override public void onMessage(final ReqT message) { - // Wait and then recurse. - if (depth.incrementAndGet() == depthThreshold) { + if (nodeCount.incrementAndGet() == depthThreshold) { // No need to abort so just wait for top-down cancellation return; } - otherWork.execute(new Runnable() { + Context.propagate(otherWork).execute(new Runnable() { @Override public void run() { try { @@ -237,11 +249,11 @@ public class CascadingTest { /** * Create a tree of client to server calls where each received call on the server - * fans out to two downstream calls. Uses SimpleRequest.response_size to limit the depth + * fans out to two downstream calls. Uses SimpleRequest.response_size to bound the depth * of the tree. One of the leaves will ABORT to trigger cancellation back up to tree. 
*/ - private void startCallTreeServer() throws IOException { - final Semaphore semaphore = new Semaphore(1); + private void startCallTreeServer(int depthThreshold) throws IOException { + final AtomicInteger nodeCount = new AtomicInteger((2 << depthThreshold) - 1); server = InProcessServerBuilder.forName("channel").addService( ServerInterceptors.intercept(TestServiceGrpc.bindService(service), new ServerInterceptor() { @@ -258,20 +270,16 @@ public class CascadingTest { @Override public void onMessage(final ReqT message) { Messages.SimpleRequest req = (Messages.SimpleRequest) message; - // we are at a leaf node, acquire the semaphore and cause this edge of the - // tree to ABORT. - if (req.getResponseSize() == 0) { - if (semaphore.tryAcquire(1)) { - Executors.newScheduledThreadPool(1).schedule( - new Runnable() { - @Override - public void run() { - call.close(Status.ABORTED, new Metadata()); - } - }, 50, TimeUnit.MILLISECONDS); - } - } else { - // Decrement tree depth limit + if (nodeCount.decrementAndGet() == 0) { + // we are in the final leaf node so trigger an ABORT upwards + Context.propagate(otherWork).execute(new Runnable() { + @Override + public void run() { + call.close(Status.ABORTED, new Metadata()); + } + }); + } else if (req.getResponseSize() != 0) { + // We are in a non leaf node so fire off two requests req = req.toBuilder().setResponseSize(req.getResponseSize() - 1).build(); for (int i = 0; i < 2; i++) { asyncStub.unaryCall(req,