From 6d44f46f18fd2dcdbb5be8866cdbf79ab21b110f Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 22 Apr 2019 16:32:06 -0700 Subject: [PATCH] stub: Wait for onClose when blocking stub is interrupted Interceptors need to see the onClose to clean up properly. This also changes an isInterrupted() to interrupted(), since previously the interrupted flag was still set when InterruptedException was thrown. This caused an infinite loop with the new code. Previously, all callers immediately re-set the interrupted flag, so there was no issue. Fixes #5576 --- .../main/java/io/grpc/stub/ClientCalls.java | 79 ++++--- .../java/io/grpc/stub/ClientCallsTest.java | 223 +++++++++++++++++- 2 files changed, 263 insertions(+), 39 deletions(-) diff --git a/stub/src/main/java/io/grpc/stub/ClientCalls.java b/stub/src/main/java/io/grpc/stub/ClientCalls.java index d8ee7a9eb1..02f90ec629 100644 --- a/stub/src/main/java/io/grpc/stub/ClientCalls.java +++ b/stub/src/main/java/io/grpc/stub/ClientCalls.java @@ -123,6 +123,7 @@ public final class ClientCalls { public static RespT blockingUnaryCall( Channel channel, MethodDescriptor method, CallOptions callOptions, ReqT req) { ThreadlessExecutor executor = new ThreadlessExecutor(); + boolean interrupt = false; ClientCall call = channel.newCall(method, callOptions.withExecutor(executor)); try { ListenableFuture responseFuture = futureUnaryCall(call, req); @@ -130,18 +131,22 @@ public final class ClientCalls { try { executor.waitAndDrain(); } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw Status.CANCELLED - .withDescription("Call was interrupted") - .withCause(e) - .asRuntimeException(); + interrupt = true; + call.cancel("Thread interrupted", e); + // Now wait for onClose() to be called, so interceptors can clean up } } return getUnchecked(responseFuture); } catch (RuntimeException e) { + // Something very bad happened. All bets are off; it may be dangerous to wait for onClose(). throw cancelThrow(call, e); } catch (Error e) { + // Something very bad happened. All bets are off; it may be dangerous to wait for onClose(). throw cancelThrow(call, e); + } finally { + if (interrupt) { + Thread.currentThread().interrupt(); + } } } @@ -208,7 +213,7 @@ public final class ClientCalls { } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw Status.CANCELLED - .withDescription("Call was interrupted") + .withDescription("Thread interrupted") .withCause(e) .asRuntimeException(); } catch (ExecutionException e) { @@ -546,30 +551,45 @@ public final class ClientCalls { return listener; } - private Object waitForNext() throws InterruptedException { - if (threadless == null) { - return buffer.take(); - } else { - Object next = buffer.poll(); - while (next == null) { - threadless.waitAndDrain(); - next = buffer.poll(); + private Object waitForNext() { + boolean interrupt = false; + try { + if (threadless == null) { + while (true) { + try { + return buffer.take(); + } catch (InterruptedException ie) { + interrupt = true; + call.cancel("Thread interrupted", ie); + // Now wait for onClose() to be called, to guarantee BlockingQueue doesn't fill + } + } + } else { + Object next; + while ((next = buffer.poll()) == null) { + try { + threadless.waitAndDrain(); + } catch (InterruptedException ie) { + interrupt = true; + call.cancel("Thread interrupted", ie); + // Now wait for onClose() to be called, so interceptors can clean up + } + } + return next; + } + } finally { + if (interrupt) { + Thread.currentThread().interrupt(); } - return next; } } @Override public boolean hasNext() { - if (last == null) { - try { - // Will block here indefinitely waiting for content. RPC timeouts defend against permanent - // hangs here as the call will become closed. - last = waitForNext(); - } catch (InterruptedException ie) { - Thread.currentThread().interrupt(); - throw Status.CANCELLED.withDescription("interrupted").withCause(ie).asRuntimeException(); - } + while (last == null) { + // Will block here indefinitely waiting for content. RPC timeouts defend against permanent + // hangs here as the call will become closed. + last = waitForNext(); } if (last instanceof StatusRuntimeException) { // Rethrow the exception with a new stacktrace. @@ -643,15 +663,14 @@ public final class ClientCalls { * Must only be called by one thread at a time. */ public void waitAndDrain() throws InterruptedException { - final Thread currentThread = Thread.currentThread(); - throwIfInterrupted(currentThread); + throwIfInterrupted(); Runnable runnable = poll(); if (runnable == null) { - waiter = currentThread; + waiter = Thread.currentThread(); try { while ((runnable = poll()) == null) { LockSupport.park(this); - throwIfInterrupted(currentThread); + throwIfInterrupted(); } } finally { waiter = null; @@ -666,8 +685,8 @@ public final class ClientCalls { } while ((runnable = poll()) != null); } - private static void throwIfInterrupted(Thread currentThread) throws InterruptedException { - if (currentThread.isInterrupted()) { + private static void throwIfInterrupted() throws InterruptedException { + if (Thread.interrupted()) { throw new InterruptedException(); } } diff --git a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java index 298f58d2f8..a636487706 100644 --- a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java @@ -18,6 +18,7 @@ package io.grpc.stub; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; @@ -26,7 +27,11 @@ import static org.junit.Assert.fail; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; +import io.grpc.Channel; import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -39,6 +44,8 @@ import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.NoopClientCall; import io.grpc.stub.ServerCalls.NoopStreamObserver; +import io.grpc.stub.ServerCalls.ServerStreamingMethod; +import io.grpc.stub.ServerCalls.UnaryMethod; import io.grpc.stub.ServerCallsTest.IntegerMarshaller; import java.util.ArrayList; import java.util.Arrays; @@ -62,14 +69,17 @@ import org.mockito.MockitoAnnotations; */ @RunWith(JUnit4.class) public class ClientCallsTest { - - private static final MethodDescriptor STREAMING_METHOD = + private static final MethodDescriptor UNARY_METHOD = MethodDescriptor.newBuilder() - .setType(MethodDescriptor.MethodType.BIDI_STREAMING) + .setType(MethodDescriptor.MethodType.UNARY) .setFullMethodName("some/method") .setRequestMarshaller(new IntegerMarshaller()) .setResponseMarshaller(new IntegerMarshaller()) .build(); + private static final MethodDescriptor SERVER_STREAMING_METHOD = + UNARY_METHOD.toBuilder().setType(MethodDescriptor.MethodType.SERVER_STREAMING).build(); + private static final MethodDescriptor BIDI_STREAMING_METHOD = + UNARY_METHOD.toBuilder().setType(MethodDescriptor.MethodType.BIDI_STREAMING).build(); private Server server; private ManagedChannel channel; @@ -130,6 +140,69 @@ public class ClientCallsTest { } } + @Test + public void blockingUnaryCall2_success() throws Exception { + Integer req = 2; + final Integer resp = 3; + + class BasicUnaryResponse implements UnaryMethod { + Integer request; + + @Override public void invoke(Integer request, StreamObserver responseObserver) { + this.request = request; + responseObserver.onNext(resp); + responseObserver.onCompleted(); + } + } + + BasicUnaryResponse service = new BasicUnaryResponse(); + server = InProcessServerBuilder.forName("simple-reply").directExecutor() + .addService(ServerServiceDefinition.builder("some") + .addMethod(UNARY_METHOD, ServerCalls.asyncUnaryCall(service)) + .build()) + .build().start(); + channel = InProcessChannelBuilder.forName("simple-reply").directExecutor().build(); + Integer actualResponse = + ClientCalls.blockingUnaryCall(channel, UNARY_METHOD, CallOptions.DEFAULT, req); + assertEquals(resp, actualResponse); + assertEquals(req, service.request); + } + + @Test + public void blockingUnaryCall2_interruptedWaitsForOnClose() throws Exception { + Integer req = 2; + + class NoopUnaryMethod implements UnaryMethod { + ServerCallStreamObserver observer; + + @Override public void invoke(Integer request, StreamObserver responseObserver) { + observer = (ServerCallStreamObserver) responseObserver; + } + } + + NoopUnaryMethod methodImpl = new NoopUnaryMethod(); + server = InProcessServerBuilder.forName("noop").directExecutor() + .addService(ServerServiceDefinition.builder("some") + .addMethod(UNARY_METHOD, ServerCalls.asyncUnaryCall(methodImpl)) + .build()) + .build().start(); + + InterruptInterceptor interceptor = new InterruptInterceptor(); + channel = InProcessChannelBuilder.forName("noop") + .directExecutor() + .intercept(interceptor) + .build(); + try { + ClientCalls.blockingUnaryCall(channel, UNARY_METHOD, CallOptions.DEFAULT, req); + fail(); + } catch (StatusRuntimeException ex) { + assertTrue(Thread.interrupted()); + assertTrue("interrupted", ex.getCause() instanceof InterruptedException); + } + assertTrue("onCloseCalled", interceptor.onCloseCalled); + assertTrue("context not cancelled", methodImpl.observer.isCancelled()); + } + @Test public void unaryFutureCallSuccess() throws Exception { final AtomicReference> listener = @@ -372,8 +445,8 @@ public class ClientCallsTest { public void inprocessTransportInboundFlowControl() throws Exception { final Semaphore semaphore = new Semaphore(0); ServerServiceDefinition service = ServerServiceDefinition.builder( - new ServiceDescriptor("some", STREAMING_METHOD)) - .addMethod(STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall( + new ServiceDescriptor("some", BIDI_STREAMING_METHOD)) + .addMethod(BIDI_STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall( new ServerCalls.BidiStreamingMethod() { int iteration; @@ -404,7 +477,7 @@ public class ClientCallsTest { server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor() .addService(service).build().start(); channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build(); - final ClientCall clientCall = channel.newCall(STREAMING_METHOD, + final ClientCall clientCall = channel.newCall(BIDI_STREAMING_METHOD, CallOptions.DEFAULT); final CountDownLatch latch = new CountDownLatch(1); final List receivedMessages = new ArrayList<>(6); @@ -453,8 +526,8 @@ public class ClientCallsTest { final SettableFuture> observerFuture = SettableFuture.create(); ServerServiceDefinition service = ServerServiceDefinition.builder( - new ServiceDescriptor("some", STREAMING_METHOD)) - .addMethod(STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall( + new ServiceDescriptor("some", BIDI_STREAMING_METHOD)) + .addMethod(BIDI_STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall( new ServerCalls.BidiStreamingMethod() { @Override public StreamObserver invoke(StreamObserver responseObserver) { @@ -485,7 +558,7 @@ public class ClientCallsTest { server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor() .addService(service).build().start(); channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build(); - final ClientCall clientCall = channel.newCall(STREAMING_METHOD, + final ClientCall clientCall = channel.newCall(BIDI_STREAMING_METHOD, CallOptions.DEFAULT); final SettableFuture future = SettableFuture.create(); @@ -564,4 +637,136 @@ public class ClientCallsTest { assertSame(trailers, metadata); } } + + @Test + public void blockingServerStreamingCall_interruptedWaitsForOnClose() throws Exception { + Integer req = 2; + + class NoopServerStreamingMethod implements ServerStreamingMethod { + ServerCallStreamObserver observer; + + @Override public void invoke(Integer request, StreamObserver responseObserver) { + observer = (ServerCallStreamObserver) responseObserver; + } + } + + NoopServerStreamingMethod methodImpl = new NoopServerStreamingMethod(); + server = InProcessServerBuilder.forName("noop").directExecutor() + .addService(ServerServiceDefinition.builder("some") + .addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(methodImpl)) + .build()) + .build().start(); + + InterruptInterceptor interceptor = new InterruptInterceptor(); + channel = InProcessChannelBuilder.forName("noop") + .directExecutor() + .intercept(interceptor) + .build(); + Iterator iter = ClientCalls.blockingServerStreamingCall( + channel.newCall(SERVER_STREAMING_METHOD, CallOptions.DEFAULT), req); + try { + iter.next(); + fail(); + } catch (StatusRuntimeException ex) { + assertTrue(Thread.interrupted()); + assertTrue("interrupted", ex.getCause() instanceof InterruptedException); + } + assertTrue("onCloseCalled", interceptor.onCloseCalled); + assertTrue("context not cancelled", methodImpl.observer.isCancelled()); + } + + @Test + public void blockingServerStreamingCall2_success() throws Exception { + Integer req = 2; + final Integer resp1 = 3; + final Integer resp2 = 4; + + class BasicServerStreamingResponse implements ServerStreamingMethod { + Integer request; + + @Override public void invoke(Integer request, StreamObserver responseObserver) { + this.request = request; + responseObserver.onNext(resp1); + responseObserver.onNext(resp2); + responseObserver.onCompleted(); + } + } + + BasicServerStreamingResponse service = new BasicServerStreamingResponse(); + server = InProcessServerBuilder.forName("simple-reply").directExecutor() + .addService(ServerServiceDefinition.builder("some") + .addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(service)) + .build()) + .build().start(); + channel = InProcessChannelBuilder.forName("simple-reply").directExecutor().build(); + Iterator iter = ClientCalls.blockingServerStreamingCall( + channel, SERVER_STREAMING_METHOD, CallOptions.DEFAULT, req); + assertEquals(resp1, iter.next()); + assertTrue(iter.hasNext()); + assertEquals(resp2, iter.next()); + assertFalse(iter.hasNext()); + assertEquals(req, service.request); + } + + @Test + public void blockingServerStreamingCall2_interruptedWaitsForOnClose() throws Exception { + Integer req = 2; + + class NoopServerStreamingMethod implements ServerStreamingMethod { + ServerCallStreamObserver observer; + + @Override public void invoke(Integer request, StreamObserver responseObserver) { + observer = (ServerCallStreamObserver) responseObserver; + } + } + + NoopServerStreamingMethod methodImpl = new NoopServerStreamingMethod(); + server = InProcessServerBuilder.forName("noop").directExecutor() + .addService(ServerServiceDefinition.builder("some") + .addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(methodImpl)) + .build()) + .build().start(); + + InterruptInterceptor interceptor = new InterruptInterceptor(); + channel = InProcessChannelBuilder.forName("noop") + .directExecutor() + .intercept(interceptor) + .build(); + Iterator iter = ClientCalls.blockingServerStreamingCall( + channel, SERVER_STREAMING_METHOD, CallOptions.DEFAULT, req); + try { + iter.next(); + fail(); + } catch (StatusRuntimeException ex) { + assertTrue(Thread.interrupted()); + assertTrue("interrupted", ex.getCause() instanceof InterruptedException); + } + assertTrue("onCloseCalled", interceptor.onCloseCalled); + assertTrue("context not cancelled", methodImpl.observer.isCancelled()); + } + + // Used for blocking tests to check interrupt behavior and make sure onClose is still called. + class InterruptInterceptor implements ClientInterceptor { + boolean onCloseCalled; + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return new SimpleForwardingClientCall(next.newCall(method, callOptions)) { + @Override public void start(ClientCall.Listener listener, Metadata headers) { + super.start(new SimpleForwardingClientCallListener(listener) { + @Override public void onClose(Status status, Metadata trailers) { + onCloseCalled = true; + super.onClose(status, trailers); + } + }, headers); + } + + @Override public void halfClose() { + Thread.currentThread().interrupt(); + super.halfClose(); + } + }; + } + } }