diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/ConcurrencyTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/ConcurrencyTest.java index b566a8c888..fab130a454 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/ConcurrencyTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/ConcurrencyTest.java @@ -16,8 +16,10 @@ package io.grpc.testing.integration; -import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.SettableFuture; import io.grpc.ChannelCredentials; import io.grpc.Grpc; import io.grpc.ManagedChannel; @@ -32,7 +34,8 @@ import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; import java.io.File; import java.io.IOException; -import java.util.concurrent.CountDownLatch; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -40,9 +43,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -58,28 +59,29 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class ConcurrencyTest { - @Rule public final Timeout globalTimeout = Timeout.seconds(60); - /** - * A response observer that signals a {@code CountDownLatch} when the proper number of responses - * arrives and the server signals that the RPC is complete. + * A response observer that completes a {@code ListenableFuture} when the proper number of + * responses arrives and the server signals that the RPC is complete. */ private static class SignalingResponseObserver implements StreamObserver { - public SignalingResponseObserver(CountDownLatch responsesDoneSignal) { - this.responsesDoneSignal = responsesDoneSignal; + public SignalingResponseObserver(SettableFuture completionFuture) { + this.completionFuture = completionFuture; } @Override public void onCompleted() { - Preconditions.checkState(numResponsesReceived == NUM_RESPONSES_PER_REQUEST); - responsesDoneSignal.countDown(); + if (numResponsesReceived != NUM_RESPONSES_PER_REQUEST) { + completionFuture.setException( + new IllegalStateException("Wrong number of responses: " + numResponsesReceived)); + } else { + completionFuture.set(null); + } } @Override public void onError(Throwable error) { - // This should never happen. If it does happen, ensure that the error is visible. - error.printStackTrace(); + completionFuture.setException(error); } @Override @@ -87,19 +89,19 @@ public class ConcurrencyTest { numResponsesReceived++; } - private final CountDownLatch responsesDoneSignal; + private final SettableFuture completionFuture; private int numResponsesReceived = 0; } /** * A client worker task that waits until all client workers are ready, then sends a request for a - * server-streaming RPC and arranges for a {@code CountDownLatch} to be signaled when the RPC is + * server-streaming RPC and arranges for a {@code ListenableFuture} to be signaled when the RPC is * complete. */ private class ClientWorker implements Runnable { - public ClientWorker(CyclicBarrier startBarrier, CountDownLatch responsesDoneSignal) { + public ClientWorker(CyclicBarrier startBarrier, SettableFuture completionFuture) { this.startBarrier = startBarrier; - this.responsesDoneSignal = responsesDoneSignal; + this.completionFuture = completionFuture; } @Override @@ -117,14 +119,17 @@ public class ConcurrencyTest { // Wait until all client worker threads are poised & ready, then send the request. This way // all clients send their requests at approximately the same time. startBarrier.await(); - clientStub.streamingOutputCall(request, new SignalingResponseObserver(responsesDoneSignal)); - } catch (Exception e) { - throw e instanceof RuntimeException ? (RuntimeException) e : new RuntimeException(e); + clientStub.streamingOutputCall(request, new SignalingResponseObserver(completionFuture)); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + completionFuture.setException(ex); + } catch (Throwable t) { + completionFuture.setException(t); } } private final CyclicBarrier startBarrier; - private final CountDownLatch responsesDoneSignal; + private final SettableFuture completionFuture; } private static final int NUM_SERVER_THREADS = 10; @@ -168,14 +173,15 @@ public class ConcurrencyTest { @Test public void serverStreamingTest() throws Exception { CyclicBarrier startBarrier = new CyclicBarrier(NUM_CONCURRENT_REQUESTS); - CountDownLatch responsesDoneSignal = new CountDownLatch(NUM_CONCURRENT_REQUESTS); + List> workerFutures = new ArrayList<>(NUM_CONCURRENT_REQUESTS); for (int i = 0; i < NUM_CONCURRENT_REQUESTS; i++) { - clientExecutor.execute(new ClientWorker(startBarrier, responsesDoneSignal)); + SettableFuture future = SettableFuture.create(); + clientExecutor.execute(new ClientWorker(startBarrier, future)); + workerFutures.add(future); } - // Wait until the clients all receive their complete RPC response streams. - responsesDoneSignal.await(); + Futures.allAsList(workerFutures).get(60, TimeUnit.SECONDS); } /**