Fix flakiness in Cascading cancellation tests

Add explicit shutdown for other executors
This commit is contained in:
Louis Ryan 2016-02-17 15:55:32 -08:00
parent 6715f1202c
commit 7fc986e6d0
1 changed files with 44 additions and 36 deletions

View File

@ -58,10 +58,9 @@ import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -75,23 +74,26 @@ public class CascadingTest {
TestServiceGrpc.TestService service; TestServiceGrpc.TestService service;
private ManagedChannelImpl channel; private ManagedChannelImpl channel;
private ServerImpl server; private ServerImpl server;
private AtomicInteger depth; private AtomicInteger nodeCount;
private AtomicInteger observedCancellations; private AtomicInteger observedCancellations;
private AtomicInteger receivedCancellations; private AtomicInteger receivedCancellations;
private TestServiceGrpc.TestServiceBlockingStub blockingStub; private TestServiceGrpc.TestServiceBlockingStub blockingStub;
private TestServiceGrpc.TestServiceStub asyncStub; private TestServiceGrpc.TestServiceStub asyncStub;
private ScheduledExecutorService scheduler; private ScheduledExecutorService scheduler;
private ExecutorService otherWork;
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
channel = InProcessChannelBuilder.forName("channel").build(); nodeCount = new AtomicInteger();
depth = new AtomicInteger();
observedCancellations = new AtomicInteger(); observedCancellations = new AtomicInteger();
receivedCancellations = new AtomicInteger(); receivedCancellations = new AtomicInteger();
scheduler = Executors.newScheduledThreadPool(1);
// Use a cached thread pool as we need a thread for each blocked call
otherWork = Executors.newCachedThreadPool();
channel = InProcessChannelBuilder.forName("channel").executor(otherWork).build();
blockingStub = TestServiceGrpc.newBlockingStub(channel); blockingStub = TestServiceGrpc.newBlockingStub(channel);
asyncStub = TestServiceGrpc.newStub(channel); asyncStub = TestServiceGrpc.newStub(channel);
scheduler = Executors.newScheduledThreadPool(1);
} }
@After @After
@ -99,6 +101,8 @@ public class CascadingTest {
Context.ROOT.attach(); Context.ROOT.attach();
channel.shutdownNow(); channel.shutdownNow();
server.shutdownNow(); server.shutdownNow();
otherWork.shutdownNow();
scheduler.shutdownNow();
} }
/** /**
@ -108,18 +112,21 @@ public class CascadingTest {
@Test @Test
public void testCascadingCancellationViaOuterContextExpiration() throws Exception { public void testCascadingCancellationViaOuterContextExpiration() throws Exception {
startChainingServer(3); startChainingServer(3);
Context.current().withDeadlineAfter(150, TimeUnit.MILLISECONDS, scheduler).attach(); Context.current().withDeadlineAfter(500, TimeUnit.MILLISECONDS, scheduler).attach();
try { try {
blockingStub.unaryCall(Messages.SimpleRequest.getDefaultInstance()); blockingStub.unaryCall(Messages.SimpleRequest.getDefaultInstance());
fail("Expected cancellation"); fail("Expected cancellation");
} catch (StatusRuntimeException sre) { } catch (StatusRuntimeException sre) {
// Wait for the workers to finish // Wait for the workers to finish
Thread.sleep(500);
Status status = Status.fromThrowable(sre); Status status = Status.fromThrowable(sre);
assertEquals(Status.Code.CANCELLED, status.getCode()); assertEquals(Status.Code.CANCELLED, status.getCode());
// Wait for the channel to shutdown so we know all the calls have completed
channel.shutdown();
channel.awaitTermination(5, TimeUnit.SECONDS);
// Should have 3 calls before timeout propagates // Should have 3 calls before timeout propagates
assertEquals(3, depth.get()); assertEquals(3, nodeCount.get());
// Should have observed 2 cancellations responses from downstream servers // Should have observed 2 cancellations responses from downstream servers
assertEquals(2, observedCancellations.get()); assertEquals(2, observedCancellations.get());
@ -135,19 +142,22 @@ public class CascadingTest {
public void testCascadingCancellationViaMethodTimeout() throws Exception { public void testCascadingCancellationViaMethodTimeout() throws Exception {
startChainingServer(3); startChainingServer(3);
try { try {
blockingStub.withDeadlineAfter(150, TimeUnit.MILLISECONDS) blockingStub.withDeadlineAfter(500, TimeUnit.MILLISECONDS)
.unaryCall(Messages.SimpleRequest.getDefaultInstance()); .unaryCall(Messages.SimpleRequest.getDefaultInstance());
fail("Expected cancellation"); fail("Expected cancellation");
} catch (StatusRuntimeException sre) { } catch (StatusRuntimeException sre) {
// Wait for the workers to finish // Wait for the workers to finish
Thread.sleep(150);
Status status = Status.fromThrowable(sre); Status status = Status.fromThrowable(sre);
// Outermost caller observes deadline exceeded, the descendant RPCs are cancelled so they // Outermost caller observes deadline exceeded, the descendant RPCs are cancelled so they
// receive cancellation. // receive cancellation.
assertEquals(Status.Code.DEADLINE_EXCEEDED, status.getCode()); assertEquals(Status.Code.DEADLINE_EXCEEDED, status.getCode());
// Wait for the channel to shutdown so we know all the calls have completed
channel.shutdown();
channel.awaitTermination(5, TimeUnit.SECONDS);
// Should have 3 calls before deadline propagates // Should have 3 calls before deadline propagates
assertEquals(3, depth.get()); assertEquals(3, nodeCount.get());
// Server should have observed 2 cancellations from downstream calls // Server should have observed 2 cancellations from downstream calls
assertEquals(2, observedCancellations.get()); assertEquals(2, observedCancellations.get());
// and received 2 cancellations // and received 2 cancellations
@ -161,18 +171,22 @@ public class CascadingTest {
*/ */
@Test @Test
public void testCascadingCancellationViaLeafFailure() throws Exception { public void testCascadingCancellationViaLeafFailure() throws Exception {
startCallTreeServer(); startCallTreeServer(3);
try { try {
// Use response size limit to control tree depth. // Use response size limit to control tree nodeCount.
blockingStub.unaryCall(Messages.SimpleRequest.newBuilder().setResponseSize(3).build()); blockingStub.unaryCall(Messages.SimpleRequest.newBuilder().setResponseSize(3).build());
fail("Expected abort"); fail("Expected abort");
} catch (StatusRuntimeException sre) { } catch (StatusRuntimeException sre) {
// Wait for the workers to finish // Wait for the workers to finish
Thread.sleep(100);
Status status = Status.fromThrowable(sre); Status status = Status.fromThrowable(sre);
// Outermost caller observes ABORTED propagating up from the failing leaf, // Outermost caller observes ABORTED propagating up from the failing leaf,
// The descendant RPCs are cancelled so they receive CANCELLED. // The descendant RPCs are cancelled so they receive CANCELLED.
assertEquals(Status.Code.ABORTED, status.getCode()); assertEquals(Status.Code.ABORTED, status.getCode());
// Wait for the channel to shutdown so we know all the calls have completed
channel.shutdown();
channel.awaitTermination(5, TimeUnit.SECONDS);
// All nodes (15) except one edge of the tree (4) will be cancelled. // All nodes (15) except one edge of the tree (4) will be cancelled.
assertEquals(11, observedCancellations.get()); assertEquals(11, observedCancellations.get());
assertEquals(11, receivedCancellations.get()); assertEquals(11, receivedCancellations.get());
@ -184,7 +198,6 @@ public class CascadingTest {
*/ */
private void startChainingServer(final int depthThreshold) private void startChainingServer(final int depthThreshold)
throws IOException { throws IOException {
final Executor otherWork = Context.propagate(Executors.newCachedThreadPool());
server = InProcessServerBuilder.forName("channel").addService( server = InProcessServerBuilder.forName("channel").addService(
ServerInterceptors.intercept(TestServiceGrpc.bindService(service), ServerInterceptors.intercept(TestServiceGrpc.bindService(service),
new ServerInterceptor() { new ServerInterceptor() {
@ -200,13 +213,12 @@ public class CascadingTest {
return new ServerCall.Listener<ReqT>() { return new ServerCall.Listener<ReqT>() {
@Override @Override
public void onMessage(final ReqT message) { public void onMessage(final ReqT message) {
// Wait and then recurse. if (nodeCount.incrementAndGet() == depthThreshold) {
if (depth.incrementAndGet() == depthThreshold) {
// No need to abort so just wait for top-down cancellation // No need to abort so just wait for top-down cancellation
return; return;
} }
otherWork.execute(new Runnable() { Context.propagate(otherWork).execute(new Runnable() {
@Override @Override
public void run() { public void run() {
try { try {
@ -237,11 +249,11 @@ public class CascadingTest {
/** /**
* Create a tree of client to server calls where each received call on the server * Create a tree of client to server calls where each received call on the server
* fans out to two downstream calls. Uses SimpleRequest.response_size to limit the depth * fans out to two downstream calls. Uses SimpleRequest.response_size to limit the nodeCount
* of the tree. One of the leaves will ABORT to trigger cancellation back up to tree. * of the tree. One of the leaves will ABORT to trigger cancellation back up to tree.
*/ */
private void startCallTreeServer() throws IOException { private void startCallTreeServer(int depthThreshold) throws IOException {
final Semaphore semaphore = new Semaphore(1); final AtomicInteger nodeCount = new AtomicInteger((2 << depthThreshold) - 1);
server = InProcessServerBuilder.forName("channel").addService( server = InProcessServerBuilder.forName("channel").addService(
ServerInterceptors.intercept(TestServiceGrpc.bindService(service), ServerInterceptors.intercept(TestServiceGrpc.bindService(service),
new ServerInterceptor() { new ServerInterceptor() {
@ -258,20 +270,16 @@ public class CascadingTest {
@Override @Override
public void onMessage(final ReqT message) { public void onMessage(final ReqT message) {
Messages.SimpleRequest req = (Messages.SimpleRequest) message; Messages.SimpleRequest req = (Messages.SimpleRequest) message;
// we are at a leaf node, acquire the semaphore and cause this edge of the if (nodeCount.decrementAndGet() == 0) {
// tree to ABORT. // we are in the final leaf node so trigger an ABORT upwards
if (req.getResponseSize() == 0) { Context.propagate(otherWork).execute(new Runnable() {
if (semaphore.tryAcquire(1)) {
Executors.newScheduledThreadPool(1).schedule(
new Runnable() {
@Override @Override
public void run() { public void run() {
call.close(Status.ABORTED, new Metadata()); call.close(Status.ABORTED, new Metadata());
} }
}, 50, TimeUnit.MILLISECONDS); });
} } else if (req.getResponseSize() != 0) {
} else { // We are in a non leaf node so fire off two requests
// Decrement tree depth limit
req = req.toBuilder().setResponseSize(req.getResponseSize() - 1).build(); req = req.toBuilder().setResponseSize(req.getResponseSize() - 1).build();
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
asyncStub.unaryCall(req, asyncStub.unaryCall(req,