stub: Wait for onClose when blocking stub is interrupted

Interceptors need to see the onClose to clean up properly.

This also changes an isInterrupted() to interrupted(), since previously
the interrupted flag was still set when InterruptedException was thrown.
This caused an infinite loop with the new code. Previously, all callers
immediately re-set the interrupted flag, so there was no issue.

Fixes #5576
This commit is contained in:
Eric Anderson 2019-04-22 16:32:06 -07:00 committed by GitHub
parent f4d48fec62
commit 6d44f46f18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 263 additions and 39 deletions

View File

@ -123,6 +123,7 @@ public final class ClientCalls {
public static <ReqT, RespT> RespT blockingUnaryCall( public static <ReqT, RespT> RespT blockingUnaryCall(
Channel channel, MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, ReqT req) { Channel channel, MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, ReqT req) {
ThreadlessExecutor executor = new ThreadlessExecutor(); ThreadlessExecutor executor = new ThreadlessExecutor();
boolean interrupt = false;
ClientCall<ReqT, RespT> call = channel.newCall(method, callOptions.withExecutor(executor)); ClientCall<ReqT, RespT> call = channel.newCall(method, callOptions.withExecutor(executor));
try { try {
ListenableFuture<RespT> responseFuture = futureUnaryCall(call, req); ListenableFuture<RespT> responseFuture = futureUnaryCall(call, req);
@ -130,18 +131,22 @@ public final class ClientCalls {
try { try {
executor.waitAndDrain(); executor.waitAndDrain();
} catch (InterruptedException e) { } catch (InterruptedException e) {
Thread.currentThread().interrupt(); interrupt = true;
throw Status.CANCELLED call.cancel("Thread interrupted", e);
.withDescription("Call was interrupted") // Now wait for onClose() to be called, so interceptors can clean up
.withCause(e)
.asRuntimeException();
} }
} }
return getUnchecked(responseFuture); return getUnchecked(responseFuture);
} catch (RuntimeException e) { } catch (RuntimeException e) {
// Something very bad happened. All bets are off; it may be dangerous to wait for onClose().
throw cancelThrow(call, e); throw cancelThrow(call, e);
} catch (Error e) { } catch (Error e) {
// Something very bad happened. All bets are off; it may be dangerous to wait for onClose().
throw cancelThrow(call, e); throw cancelThrow(call, e);
} finally {
if (interrupt) {
Thread.currentThread().interrupt();
}
} }
} }
@ -208,7 +213,7 @@ public final class ClientCalls {
} catch (InterruptedException e) { } catch (InterruptedException e) {
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
throw Status.CANCELLED throw Status.CANCELLED
.withDescription("Call was interrupted") .withDescription("Thread interrupted")
.withCause(e) .withCause(e)
.asRuntimeException(); .asRuntimeException();
} catch (ExecutionException e) { } catch (ExecutionException e) {
@ -546,30 +551,45 @@ public final class ClientCalls {
return listener; return listener;
} }
private Object waitForNext() throws InterruptedException { private Object waitForNext() {
boolean interrupt = false;
try {
if (threadless == null) { if (threadless == null) {
while (true) {
try {
return buffer.take(); return buffer.take();
} catch (InterruptedException ie) {
interrupt = true;
call.cancel("Thread interrupted", ie);
// Now wait for onClose() to be called, to guarantee BlockingQueue doesn't fill
}
}
} else { } else {
Object next = buffer.poll(); Object next;
while (next == null) { while ((next = buffer.poll()) == null) {
try {
threadless.waitAndDrain(); threadless.waitAndDrain();
next = buffer.poll(); } catch (InterruptedException ie) {
interrupt = true;
call.cancel("Thread interrupted", ie);
// Now wait for onClose() to be called, so interceptors can clean up
}
} }
return next; return next;
} }
} finally {
if (interrupt) {
Thread.currentThread().interrupt();
}
}
} }
@Override @Override
public boolean hasNext() { public boolean hasNext() {
if (last == null) { while (last == null) {
try {
// Will block here indefinitely waiting for content. RPC timeouts defend against permanent // Will block here indefinitely waiting for content. RPC timeouts defend against permanent
// hangs here as the call will become closed. // hangs here as the call will become closed.
last = waitForNext(); last = waitForNext();
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw Status.CANCELLED.withDescription("interrupted").withCause(ie).asRuntimeException();
}
} }
if (last instanceof StatusRuntimeException) { if (last instanceof StatusRuntimeException) {
// Rethrow the exception with a new stacktrace. // Rethrow the exception with a new stacktrace.
@ -643,15 +663,14 @@ public final class ClientCalls {
* Must only be called by one thread at a time. * Must only be called by one thread at a time.
*/ */
public void waitAndDrain() throws InterruptedException { public void waitAndDrain() throws InterruptedException {
final Thread currentThread = Thread.currentThread(); throwIfInterrupted();
throwIfInterrupted(currentThread);
Runnable runnable = poll(); Runnable runnable = poll();
if (runnable == null) { if (runnable == null) {
waiter = currentThread; waiter = Thread.currentThread();
try { try {
while ((runnable = poll()) == null) { while ((runnable = poll()) == null) {
LockSupport.park(this); LockSupport.park(this);
throwIfInterrupted(currentThread); throwIfInterrupted();
} }
} finally { } finally {
waiter = null; waiter = null;
@ -666,8 +685,8 @@ public final class ClientCalls {
} while ((runnable = poll()) != null); } while ((runnable = poll()) != null);
} }
private static void throwIfInterrupted(Thread currentThread) throws InterruptedException { private static void throwIfInterrupted() throws InterruptedException {
if (currentThread.isInterrupted()) { if (Thread.interrupted()) {
throw new InterruptedException(); throw new InterruptedException();
} }
} }

View File

@ -18,6 +18,7 @@ package io.grpc.stub;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@ -26,7 +27,11 @@ import static org.junit.Assert.fail;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall; import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
@ -39,6 +44,8 @@ import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.NoopClientCall; import io.grpc.internal.NoopClientCall;
import io.grpc.stub.ServerCalls.NoopStreamObserver; import io.grpc.stub.ServerCalls.NoopStreamObserver;
import io.grpc.stub.ServerCalls.ServerStreamingMethod;
import io.grpc.stub.ServerCalls.UnaryMethod;
import io.grpc.stub.ServerCallsTest.IntegerMarshaller; import io.grpc.stub.ServerCallsTest.IntegerMarshaller;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -62,14 +69,17 @@ import org.mockito.MockitoAnnotations;
*/ */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class ClientCallsTest { public class ClientCallsTest {
private static final MethodDescriptor<Integer, Integer> UNARY_METHOD =
private static final MethodDescriptor<Integer, Integer> STREAMING_METHOD =
MethodDescriptor.<Integer, Integer>newBuilder() MethodDescriptor.<Integer, Integer>newBuilder()
.setType(MethodDescriptor.MethodType.BIDI_STREAMING) .setType(MethodDescriptor.MethodType.UNARY)
.setFullMethodName("some/method") .setFullMethodName("some/method")
.setRequestMarshaller(new IntegerMarshaller()) .setRequestMarshaller(new IntegerMarshaller())
.setResponseMarshaller(new IntegerMarshaller()) .setResponseMarshaller(new IntegerMarshaller())
.build(); .build();
private static final MethodDescriptor<Integer, Integer> SERVER_STREAMING_METHOD =
UNARY_METHOD.toBuilder().setType(MethodDescriptor.MethodType.SERVER_STREAMING).build();
private static final MethodDescriptor<Integer, Integer> BIDI_STREAMING_METHOD =
UNARY_METHOD.toBuilder().setType(MethodDescriptor.MethodType.BIDI_STREAMING).build();
private Server server; private Server server;
private ManagedChannel channel; private ManagedChannel channel;
@ -130,6 +140,69 @@ public class ClientCallsTest {
} }
} }
@Test
public void blockingUnaryCall2_success() throws Exception {
Integer req = 2;
final Integer resp = 3;
class BasicUnaryResponse implements UnaryMethod<Integer, Integer> {
Integer request;
@Override public void invoke(Integer request, StreamObserver<Integer> responseObserver) {
this.request = request;
responseObserver.onNext(resp);
responseObserver.onCompleted();
}
}
BasicUnaryResponse service = new BasicUnaryResponse();
server = InProcessServerBuilder.forName("simple-reply").directExecutor()
.addService(ServerServiceDefinition.builder("some")
.addMethod(UNARY_METHOD, ServerCalls.asyncUnaryCall(service))
.build())
.build().start();
channel = InProcessChannelBuilder.forName("simple-reply").directExecutor().build();
Integer actualResponse =
ClientCalls.blockingUnaryCall(channel, UNARY_METHOD, CallOptions.DEFAULT, req);
assertEquals(resp, actualResponse);
assertEquals(req, service.request);
}
@Test
public void blockingUnaryCall2_interruptedWaitsForOnClose() throws Exception {
Integer req = 2;
class NoopUnaryMethod implements UnaryMethod<Integer, Integer> {
ServerCallStreamObserver<Integer> observer;
@Override public void invoke(Integer request, StreamObserver<Integer> responseObserver) {
observer = (ServerCallStreamObserver<Integer>) responseObserver;
}
}
NoopUnaryMethod methodImpl = new NoopUnaryMethod();
server = InProcessServerBuilder.forName("noop").directExecutor()
.addService(ServerServiceDefinition.builder("some")
.addMethod(UNARY_METHOD, ServerCalls.asyncUnaryCall(methodImpl))
.build())
.build().start();
InterruptInterceptor interceptor = new InterruptInterceptor();
channel = InProcessChannelBuilder.forName("noop")
.directExecutor()
.intercept(interceptor)
.build();
try {
ClientCalls.blockingUnaryCall(channel, UNARY_METHOD, CallOptions.DEFAULT, req);
fail();
} catch (StatusRuntimeException ex) {
assertTrue(Thread.interrupted());
assertTrue("interrupted", ex.getCause() instanceof InterruptedException);
}
assertTrue("onCloseCalled", interceptor.onCloseCalled);
assertTrue("context not cancelled", methodImpl.observer.isCancelled());
}
@Test @Test
public void unaryFutureCallSuccess() throws Exception { public void unaryFutureCallSuccess() throws Exception {
final AtomicReference<ClientCall.Listener<String>> listener = final AtomicReference<ClientCall.Listener<String>> listener =
@ -372,8 +445,8 @@ public class ClientCallsTest {
public void inprocessTransportInboundFlowControl() throws Exception { public void inprocessTransportInboundFlowControl() throws Exception {
final Semaphore semaphore = new Semaphore(0); final Semaphore semaphore = new Semaphore(0);
ServerServiceDefinition service = ServerServiceDefinition.builder( ServerServiceDefinition service = ServerServiceDefinition.builder(
new ServiceDescriptor("some", STREAMING_METHOD)) new ServiceDescriptor("some", BIDI_STREAMING_METHOD))
.addMethod(STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall( .addMethod(BIDI_STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall(
new ServerCalls.BidiStreamingMethod<Integer, Integer>() { new ServerCalls.BidiStreamingMethod<Integer, Integer>() {
int iteration; int iteration;
@ -404,7 +477,7 @@ public class ClientCallsTest {
server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor() server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor()
.addService(service).build().start(); .addService(service).build().start();
channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build(); channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build();
final ClientCall<Integer, Integer> clientCall = channel.newCall(STREAMING_METHOD, final ClientCall<Integer, Integer> clientCall = channel.newCall(BIDI_STREAMING_METHOD,
CallOptions.DEFAULT); CallOptions.DEFAULT);
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
final List<Object> receivedMessages = new ArrayList<>(6); final List<Object> receivedMessages = new ArrayList<>(6);
@ -453,8 +526,8 @@ public class ClientCallsTest {
final SettableFuture<ServerCallStreamObserver<Integer>> observerFuture final SettableFuture<ServerCallStreamObserver<Integer>> observerFuture
= SettableFuture.create(); = SettableFuture.create();
ServerServiceDefinition service = ServerServiceDefinition.builder( ServerServiceDefinition service = ServerServiceDefinition.builder(
new ServiceDescriptor("some", STREAMING_METHOD)) new ServiceDescriptor("some", BIDI_STREAMING_METHOD))
.addMethod(STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall( .addMethod(BIDI_STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall(
new ServerCalls.BidiStreamingMethod<Integer, Integer>() { new ServerCalls.BidiStreamingMethod<Integer, Integer>() {
@Override @Override
public StreamObserver<Integer> invoke(StreamObserver<Integer> responseObserver) { public StreamObserver<Integer> invoke(StreamObserver<Integer> responseObserver) {
@ -485,7 +558,7 @@ public class ClientCallsTest {
server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor() server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor()
.addService(service).build().start(); .addService(service).build().start();
channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build(); channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build();
final ClientCall<Integer, Integer> clientCall = channel.newCall(STREAMING_METHOD, final ClientCall<Integer, Integer> clientCall = channel.newCall(BIDI_STREAMING_METHOD,
CallOptions.DEFAULT); CallOptions.DEFAULT);
final SettableFuture<Void> future = SettableFuture.create(); final SettableFuture<Void> future = SettableFuture.create();
@ -564,4 +637,136 @@ public class ClientCallsTest {
assertSame(trailers, metadata); assertSame(trailers, metadata);
} }
} }
@Test
public void blockingServerStreamingCall_interruptedWaitsForOnClose() throws Exception {
Integer req = 2;
class NoopServerStreamingMethod implements ServerStreamingMethod<Integer, Integer> {
ServerCallStreamObserver<Integer> observer;
@Override public void invoke(Integer request, StreamObserver<Integer> responseObserver) {
observer = (ServerCallStreamObserver<Integer>) responseObserver;
}
}
NoopServerStreamingMethod methodImpl = new NoopServerStreamingMethod();
server = InProcessServerBuilder.forName("noop").directExecutor()
.addService(ServerServiceDefinition.builder("some")
.addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(methodImpl))
.build())
.build().start();
InterruptInterceptor interceptor = new InterruptInterceptor();
channel = InProcessChannelBuilder.forName("noop")
.directExecutor()
.intercept(interceptor)
.build();
Iterator<Integer> iter = ClientCalls.blockingServerStreamingCall(
channel.newCall(SERVER_STREAMING_METHOD, CallOptions.DEFAULT), req);
try {
iter.next();
fail();
} catch (StatusRuntimeException ex) {
assertTrue(Thread.interrupted());
assertTrue("interrupted", ex.getCause() instanceof InterruptedException);
}
assertTrue("onCloseCalled", interceptor.onCloseCalled);
assertTrue("context not cancelled", methodImpl.observer.isCancelled());
}
@Test
public void blockingServerStreamingCall2_success() throws Exception {
Integer req = 2;
final Integer resp1 = 3;
final Integer resp2 = 4;
class BasicServerStreamingResponse implements ServerStreamingMethod<Integer, Integer> {
Integer request;
@Override public void invoke(Integer request, StreamObserver<Integer> responseObserver) {
this.request = request;
responseObserver.onNext(resp1);
responseObserver.onNext(resp2);
responseObserver.onCompleted();
}
}
BasicServerStreamingResponse service = new BasicServerStreamingResponse();
server = InProcessServerBuilder.forName("simple-reply").directExecutor()
.addService(ServerServiceDefinition.builder("some")
.addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(service))
.build())
.build().start();
channel = InProcessChannelBuilder.forName("simple-reply").directExecutor().build();
Iterator<Integer> iter = ClientCalls.blockingServerStreamingCall(
channel, SERVER_STREAMING_METHOD, CallOptions.DEFAULT, req);
assertEquals(resp1, iter.next());
assertTrue(iter.hasNext());
assertEquals(resp2, iter.next());
assertFalse(iter.hasNext());
assertEquals(req, service.request);
}
@Test
public void blockingServerStreamingCall2_interruptedWaitsForOnClose() throws Exception {
Integer req = 2;
class NoopServerStreamingMethod implements ServerStreamingMethod<Integer, Integer> {
ServerCallStreamObserver<Integer> observer;
@Override public void invoke(Integer request, StreamObserver<Integer> responseObserver) {
observer = (ServerCallStreamObserver<Integer>) responseObserver;
}
}
NoopServerStreamingMethod methodImpl = new NoopServerStreamingMethod();
server = InProcessServerBuilder.forName("noop").directExecutor()
.addService(ServerServiceDefinition.builder("some")
.addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(methodImpl))
.build())
.build().start();
InterruptInterceptor interceptor = new InterruptInterceptor();
channel = InProcessChannelBuilder.forName("noop")
.directExecutor()
.intercept(interceptor)
.build();
Iterator<Integer> iter = ClientCalls.blockingServerStreamingCall(
channel, SERVER_STREAMING_METHOD, CallOptions.DEFAULT, req);
try {
iter.next();
fail();
} catch (StatusRuntimeException ex) {
assertTrue(Thread.interrupted());
assertTrue("interrupted", ex.getCause() instanceof InterruptedException);
}
assertTrue("onCloseCalled", interceptor.onCloseCalled);
assertTrue("context not cancelled", methodImpl.observer.isCancelled());
}
// Used for blocking tests to check interrupt behavior and make sure onClose is still called.
class InterruptInterceptor implements ClientInterceptor {
boolean onCloseCalled;
@Override
public <ReqT,RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return new SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
@Override public void start(ClientCall.Listener<RespT> listener, Metadata headers) {
super.start(new SimpleForwardingClientCallListener<RespT>(listener) {
@Override public void onClose(Status status, Metadata trailers) {
onCloseCalled = true;
super.onClose(status, trailers);
}
}, headers);
}
@Override public void halfClose() {
Thread.currentThread().interrupt();
super.halfClose();
}
};
}
}
} }