Revert "Revert "stub: Wait for onClose when blocking stub is interrupted"" (#6255)

This reverts commit 0604e14154.
This commit is contained in:
Chengyuan Zhang 2019-10-07 11:40:20 -07:00 committed by GitHub
parent 2caa77d48f
commit 0ec31c683e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 263 additions and 39 deletions

View File

@ -124,6 +124,7 @@ public final class ClientCalls {
public static <ReqT, RespT> RespT blockingUnaryCall(
Channel channel, MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, ReqT req) {
ThreadlessExecutor executor = new ThreadlessExecutor();
boolean interrupt = false;
ClientCall<ReqT, RespT> call = channel.newCall(method, callOptions.withExecutor(executor));
try {
ListenableFuture<RespT> responseFuture = futureUnaryCall(call, req);
@ -131,18 +132,22 @@ public final class ClientCalls {
try {
executor.waitAndDrain();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw Status.CANCELLED
.withDescription("Call was interrupted")
.withCause(e)
.asRuntimeException();
interrupt = true;
call.cancel("Thread interrupted", e);
// Now wait for onClose() to be called, so interceptors can clean up
}
}
return getUnchecked(responseFuture);
} catch (RuntimeException e) {
// Something very bad happened. All bets are off; it may be dangerous to wait for onClose().
throw cancelThrow(call, e);
} catch (Error e) {
// Something very bad happened. All bets are off; it may be dangerous to wait for onClose().
throw cancelThrow(call, e);
} finally {
if (interrupt) {
Thread.currentThread().interrupt();
}
}
}
@ -209,7 +214,7 @@ public final class ClientCalls {
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw Status.CANCELLED
.withDescription("Call was interrupted")
.withDescription("Thread interrupted")
.withCause(e)
.asRuntimeException();
} catch (ExecutionException e) {
@ -553,30 +558,45 @@ public final class ClientCalls {
return listener;
}
private Object waitForNext() throws InterruptedException {
if (threadless == null) {
return buffer.take();
} else {
Object next = buffer.poll();
while (next == null) {
threadless.waitAndDrain();
next = buffer.poll();
private Object waitForNext() {
boolean interrupt = false;
try {
if (threadless == null) {
while (true) {
try {
return buffer.take();
} catch (InterruptedException ie) {
interrupt = true;
call.cancel("Thread interrupted", ie);
// Now wait for onClose() to be called, to guarantee BlockingQueue doesn't fill
}
}
} else {
Object next;
while ((next = buffer.poll()) == null) {
try {
threadless.waitAndDrain();
} catch (InterruptedException ie) {
interrupt = true;
call.cancel("Thread interrupted", ie);
// Now wait for onClose() to be called, so interceptors can clean up
}
}
return next;
}
} finally {
if (interrupt) {
Thread.currentThread().interrupt();
}
return next;
}
}
@Override
public boolean hasNext() {
if (last == null) {
try {
// Will block here indefinitely waiting for content. RPC timeouts defend against permanent
// hangs here as the call will become closed.
last = waitForNext();
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw Status.CANCELLED.withDescription("interrupted").withCause(ie).asRuntimeException();
}
while (last == null) {
// Will block here indefinitely waiting for content. RPC timeouts defend against permanent
// hangs here as the call will become closed.
last = waitForNext();
}
if (last instanceof StatusRuntimeException) {
// Rethrow the exception with a new stacktrace.
@ -650,15 +670,14 @@ public final class ClientCalls {
* Must only be called by one thread at a time.
*/
public void waitAndDrain() throws InterruptedException {
final Thread currentThread = Thread.currentThread();
throwIfInterrupted(currentThread);
throwIfInterrupted();
Runnable runnable = poll();
if (runnable == null) {
waiter = currentThread;
waiter = Thread.currentThread();
try {
while ((runnable = poll()) == null) {
LockSupport.park(this);
throwIfInterrupted(currentThread);
throwIfInterrupted();
}
} finally {
waiter = null;
@ -673,8 +692,8 @@ public final class ClientCalls {
} while ((runnable = poll()) != null);
}
private static void throwIfInterrupted(Thread currentThread) throws InterruptedException {
if (currentThread.isInterrupted()) {
private static void throwIfInterrupted() throws InterruptedException {
if (Thread.interrupted()) {
throw new InterruptedException();
}
}

View File

@ -18,6 +18,7 @@ package io.grpc.stub;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
@ -26,7 +27,11 @@ import static org.junit.Assert.fail;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
@ -39,6 +44,8 @@ import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.NoopClientCall;
import io.grpc.stub.ServerCalls.NoopStreamObserver;
import io.grpc.stub.ServerCalls.ServerStreamingMethod;
import io.grpc.stub.ServerCalls.UnaryMethod;
import io.grpc.stub.ServerCallsTest.IntegerMarshaller;
import java.util.ArrayList;
import java.util.Arrays;
@ -62,14 +69,17 @@ import org.mockito.MockitoAnnotations;
*/
@RunWith(JUnit4.class)
public class ClientCallsTest {
private static final MethodDescriptor<Integer, Integer> STREAMING_METHOD =
private static final MethodDescriptor<Integer, Integer> UNARY_METHOD =
MethodDescriptor.<Integer, Integer>newBuilder()
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
.setType(MethodDescriptor.MethodType.UNARY)
.setFullMethodName("some/method")
.setRequestMarshaller(new IntegerMarshaller())
.setResponseMarshaller(new IntegerMarshaller())
.build();
private static final MethodDescriptor<Integer, Integer> SERVER_STREAMING_METHOD =
UNARY_METHOD.toBuilder().setType(MethodDescriptor.MethodType.SERVER_STREAMING).build();
private static final MethodDescriptor<Integer, Integer> BIDI_STREAMING_METHOD =
UNARY_METHOD.toBuilder().setType(MethodDescriptor.MethodType.BIDI_STREAMING).build();
private Server server;
private ManagedChannel channel;
@ -130,6 +140,69 @@ public class ClientCallsTest {
}
}
@Test
public void blockingUnaryCall2_success() throws Exception {
Integer req = 2;
final Integer resp = 3;
class BasicUnaryResponse implements UnaryMethod<Integer, Integer> {
Integer request;
@Override public void invoke(Integer request, StreamObserver<Integer> responseObserver) {
this.request = request;
responseObserver.onNext(resp);
responseObserver.onCompleted();
}
}
BasicUnaryResponse service = new BasicUnaryResponse();
server = InProcessServerBuilder.forName("simple-reply").directExecutor()
.addService(ServerServiceDefinition.builder("some")
.addMethod(UNARY_METHOD, ServerCalls.asyncUnaryCall(service))
.build())
.build().start();
channel = InProcessChannelBuilder.forName("simple-reply").directExecutor().build();
Integer actualResponse =
ClientCalls.blockingUnaryCall(channel, UNARY_METHOD, CallOptions.DEFAULT, req);
assertEquals(resp, actualResponse);
assertEquals(req, service.request);
}
@Test
public void blockingUnaryCall2_interruptedWaitsForOnClose() throws Exception {
Integer req = 2;
class NoopUnaryMethod implements UnaryMethod<Integer, Integer> {
ServerCallStreamObserver<Integer> observer;
@Override public void invoke(Integer request, StreamObserver<Integer> responseObserver) {
observer = (ServerCallStreamObserver<Integer>) responseObserver;
}
}
NoopUnaryMethod methodImpl = new NoopUnaryMethod();
server = InProcessServerBuilder.forName("noop").directExecutor()
.addService(ServerServiceDefinition.builder("some")
.addMethod(UNARY_METHOD, ServerCalls.asyncUnaryCall(methodImpl))
.build())
.build().start();
InterruptInterceptor interceptor = new InterruptInterceptor();
channel = InProcessChannelBuilder.forName("noop")
.directExecutor()
.intercept(interceptor)
.build();
try {
ClientCalls.blockingUnaryCall(channel, UNARY_METHOD, CallOptions.DEFAULT, req);
fail();
} catch (StatusRuntimeException ex) {
assertTrue(Thread.interrupted());
assertTrue("interrupted", ex.getCause() instanceof InterruptedException);
}
assertTrue("onCloseCalled", interceptor.onCloseCalled);
assertTrue("context not cancelled", methodImpl.observer.isCancelled());
}
@Test
public void unaryFutureCallSuccess() throws Exception {
final AtomicReference<ClientCall.Listener<String>> listener =
@ -372,8 +445,8 @@ public class ClientCallsTest {
public void inprocessTransportInboundFlowControl() throws Exception {
final Semaphore semaphore = new Semaphore(0);
ServerServiceDefinition service = ServerServiceDefinition.builder(
new ServiceDescriptor("some", STREAMING_METHOD))
.addMethod(STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall(
new ServiceDescriptor("some", BIDI_STREAMING_METHOD))
.addMethod(BIDI_STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall(
new ServerCalls.BidiStreamingMethod<Integer, Integer>() {
int iteration;
@ -404,7 +477,7 @@ public class ClientCallsTest {
server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor()
.addService(service).build().start();
channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build();
final ClientCall<Integer, Integer> clientCall = channel.newCall(STREAMING_METHOD,
final ClientCall<Integer, Integer> clientCall = channel.newCall(BIDI_STREAMING_METHOD,
CallOptions.DEFAULT);
final CountDownLatch latch = new CountDownLatch(1);
final List<Object> receivedMessages = new ArrayList<>(6);
@ -453,8 +526,8 @@ public class ClientCallsTest {
final SettableFuture<ServerCallStreamObserver<Integer>> observerFuture
= SettableFuture.create();
ServerServiceDefinition service = ServerServiceDefinition.builder(
new ServiceDescriptor("some", STREAMING_METHOD))
.addMethod(STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall(
new ServiceDescriptor("some", BIDI_STREAMING_METHOD))
.addMethod(BIDI_STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall(
new ServerCalls.BidiStreamingMethod<Integer, Integer>() {
@Override
public StreamObserver<Integer> invoke(StreamObserver<Integer> responseObserver) {
@ -485,7 +558,7 @@ public class ClientCallsTest {
server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor()
.addService(service).build().start();
channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build();
final ClientCall<Integer, Integer> clientCall = channel.newCall(STREAMING_METHOD,
final ClientCall<Integer, Integer> clientCall = channel.newCall(BIDI_STREAMING_METHOD,
CallOptions.DEFAULT);
final SettableFuture<Void> future = SettableFuture.create();
@ -564,4 +637,136 @@ public class ClientCallsTest {
assertSame(trailers, metadata);
}
}
@Test
public void blockingServerStreamingCall_interruptedWaitsForOnClose() throws Exception {
Integer req = 2;
class NoopServerStreamingMethod implements ServerStreamingMethod<Integer, Integer> {
ServerCallStreamObserver<Integer> observer;
@Override public void invoke(Integer request, StreamObserver<Integer> responseObserver) {
observer = (ServerCallStreamObserver<Integer>) responseObserver;
}
}
NoopServerStreamingMethod methodImpl = new NoopServerStreamingMethod();
server = InProcessServerBuilder.forName("noop").directExecutor()
.addService(ServerServiceDefinition.builder("some")
.addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(methodImpl))
.build())
.build().start();
InterruptInterceptor interceptor = new InterruptInterceptor();
channel = InProcessChannelBuilder.forName("noop")
.directExecutor()
.intercept(interceptor)
.build();
Iterator<Integer> iter = ClientCalls.blockingServerStreamingCall(
channel.newCall(SERVER_STREAMING_METHOD, CallOptions.DEFAULT), req);
try {
iter.next();
fail();
} catch (StatusRuntimeException ex) {
assertTrue(Thread.interrupted());
assertTrue("interrupted", ex.getCause() instanceof InterruptedException);
}
assertTrue("onCloseCalled", interceptor.onCloseCalled);
assertTrue("context not cancelled", methodImpl.observer.isCancelled());
}
@Test
public void blockingServerStreamingCall2_success() throws Exception {
Integer req = 2;
final Integer resp1 = 3;
final Integer resp2 = 4;
class BasicServerStreamingResponse implements ServerStreamingMethod<Integer, Integer> {
Integer request;
@Override public void invoke(Integer request, StreamObserver<Integer> responseObserver) {
this.request = request;
responseObserver.onNext(resp1);
responseObserver.onNext(resp2);
responseObserver.onCompleted();
}
}
BasicServerStreamingResponse service = new BasicServerStreamingResponse();
server = InProcessServerBuilder.forName("simple-reply").directExecutor()
.addService(ServerServiceDefinition.builder("some")
.addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(service))
.build())
.build().start();
channel = InProcessChannelBuilder.forName("simple-reply").directExecutor().build();
Iterator<Integer> iter = ClientCalls.blockingServerStreamingCall(
channel, SERVER_STREAMING_METHOD, CallOptions.DEFAULT, req);
assertEquals(resp1, iter.next());
assertTrue(iter.hasNext());
assertEquals(resp2, iter.next());
assertFalse(iter.hasNext());
assertEquals(req, service.request);
}
@Test
public void blockingServerStreamingCall2_interruptedWaitsForOnClose() throws Exception {
Integer req = 2;
class NoopServerStreamingMethod implements ServerStreamingMethod<Integer, Integer> {
ServerCallStreamObserver<Integer> observer;
@Override public void invoke(Integer request, StreamObserver<Integer> responseObserver) {
observer = (ServerCallStreamObserver<Integer>) responseObserver;
}
}
NoopServerStreamingMethod methodImpl = new NoopServerStreamingMethod();
server = InProcessServerBuilder.forName("noop").directExecutor()
.addService(ServerServiceDefinition.builder("some")
.addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(methodImpl))
.build())
.build().start();
InterruptInterceptor interceptor = new InterruptInterceptor();
channel = InProcessChannelBuilder.forName("noop")
.directExecutor()
.intercept(interceptor)
.build();
Iterator<Integer> iter = ClientCalls.blockingServerStreamingCall(
channel, SERVER_STREAMING_METHOD, CallOptions.DEFAULT, req);
try {
iter.next();
fail();
} catch (StatusRuntimeException ex) {
assertTrue(Thread.interrupted());
assertTrue("interrupted", ex.getCause() instanceof InterruptedException);
}
assertTrue("onCloseCalled", interceptor.onCloseCalled);
assertTrue("context not cancelled", methodImpl.observer.isCancelled());
}
// Used for blocking tests to check interrupt behavior and make sure onClose is still called.
class InterruptInterceptor implements ClientInterceptor {
boolean onCloseCalled;
@Override
public <ReqT,RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return new SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
@Override public void start(ClientCall.Listener<RespT> listener, Metadata headers) {
super.start(new SimpleForwardingClientCallListener<RespT>(listener) {
@Override public void onClose(Status status, Metadata trailers) {
onCloseCalled = true;
super.onClose(status, trailers);
}
}, headers);
}
@Override public void halfClose() {
Thread.currentThread().interrupt();
super.halfClose();
}
};
}
}
}