From 2bd74c5a921b45652cfae10355153d1ae6be2c67 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Wed, 28 Sep 2016 10:59:09 -0700 Subject: [PATCH] core: cancel RPC when exception in server onReady Fixes #2305 --- .../java/io/grpc/internal/ServerImpl.java | 17 ++- .../java/io/grpc/internal/ServerImplTest.java | 130 ++++++++++++++++++ 2 files changed, 144 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index 33fa3b0bf8..bc17761f6b 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -38,6 +38,7 @@ import static io.grpc.Status.DEADLINE_EXCEEDED; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; import static java.util.concurrent.TimeUnit.NANOSECONDS; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; @@ -524,7 +525,8 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { * Dispatches callbacks onto an application-provided executor and correctly propagates * exceptions. */ - private static class JumpToApplicationThreadServerStreamListener implements ServerStreamListener { + @VisibleForTesting + static class JumpToApplicationThreadServerStreamListener implements ServerStreamListener { private final Executor callExecutor; private final Context.CancellableContext context; private final ServerStream stream; @@ -545,7 +547,8 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { return listener; } - private void setListener(ServerStreamListener listener) { + @VisibleForTesting + void setListener(ServerStreamListener listener) { Preconditions.checkNotNull(listener, "listener must not be null"); Preconditions.checkState(this.listener == null, "Listener already set"); this.listener = listener; @@ -616,7 +619,15 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { callExecutor.execute(new ContextRunnable(context) { @Override public void runInContext() { - getListener().onReady(); + try { + getListener().onReady(); + } catch (RuntimeException e) { + internalClose(Status.fromThrowable(e), new Metadata()); + throw e; + } catch (Error e) { + internalClose(Status.fromThrowable(e), new Metadata()); + throw e; + } } }); } diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index dda7d17a6f..d51e41c309 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -45,6 +45,7 @@ import static org.mockito.Matchers.isNotNull; import static org.mockito.Matchers.notNull; import static org.mockito.Matchers.same; import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -77,6 +78,7 @@ import io.grpc.ServerTransportFilter; import io.grpc.ServiceDescriptor; import io.grpc.Status; import io.grpc.StringMarshaller; +import io.grpc.internal.ServerImpl.JumpToApplicationThreadServerStreamListener; import io.grpc.internal.testing.StatsTestUtils; import io.grpc.internal.testing.StatsTestUtils.FakeStatsContextFactory; import io.grpc.util.MutableHandlerRegistry; @@ -898,6 +900,134 @@ public class ServerImplTest { verifyNoMoreInteractions(fallbackRegistry); } + @Test + public void messageRead_errorCancelsCall() throws Exception { + JumpToApplicationThreadServerStreamListener listener + = new JumpToApplicationThreadServerStreamListener( + executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation()); + ServerStreamListener mockListener = mock(ServerStreamListener.class); + listener.setListener(mockListener); + + Throwable expectedT = new AssertionError(); + doThrow(expectedT).when(mockListener).messageRead(any(InputStream.class)); + // Closing the InputStream is done by the delegated listener (generally ServerCallImpl) + listener.messageRead(mock(InputStream.class)); + try { + executor.runDueTasks(); + fail("Expected exception"); + } catch (Throwable t) { + assertSame(expectedT, t); + verify(stream).close(statusCaptor.capture(), any(Metadata.class)); + assertSame(expectedT, statusCaptor.getValue().getCause()); + } + } + + @Test + public void messageRead_runtimeExceptionCancelsCall() throws Exception { + JumpToApplicationThreadServerStreamListener listener + = new JumpToApplicationThreadServerStreamListener( + executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation()); + ServerStreamListener mockListener = mock(ServerStreamListener.class); + listener.setListener(mockListener); + + Throwable expectedT = new RuntimeException(); + doThrow(expectedT).when(mockListener).messageRead(any(InputStream.class)); + // Closing the InputStream is done by the delegated listener (generally ServerCallImpl) + listener.messageRead(mock(InputStream.class)); + try { + executor.runDueTasks(); + fail("Expected exception"); + } catch (Throwable t) { + assertSame(expectedT, t); + verify(stream).close(statusCaptor.capture(), any(Metadata.class)); + assertSame(expectedT, statusCaptor.getValue().getCause()); + } + } + + @Test + public void halfClosed_errorCancelsCall() { + JumpToApplicationThreadServerStreamListener listener + = new JumpToApplicationThreadServerStreamListener( + executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation()); + ServerStreamListener mockListener = mock(ServerStreamListener.class); + listener.setListener(mockListener); + + Throwable expectedT = new AssertionError(); + doThrow(expectedT).when(mockListener).halfClosed(); + listener.halfClosed(); + try { + executor.runDueTasks(); + fail("Expected exception"); + } catch (Throwable t) { + assertSame(expectedT, t); + verify(stream).close(statusCaptor.capture(), any(Metadata.class)); + assertSame(expectedT, statusCaptor.getValue().getCause()); + } + } + + @Test + public void halfClosed_runtimeExceptionCancelsCall() { + JumpToApplicationThreadServerStreamListener listener + = new JumpToApplicationThreadServerStreamListener( + executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation()); + ServerStreamListener mockListener = mock(ServerStreamListener.class); + listener.setListener(mockListener); + + Throwable expectedT = new RuntimeException(); + doThrow(expectedT).when(mockListener).halfClosed(); + listener.halfClosed(); + try { + executor.runDueTasks(); + fail("Expected exception"); + } catch (Throwable t) { + assertSame(expectedT, t); + verify(stream).close(statusCaptor.capture(), any(Metadata.class)); + assertSame(expectedT, statusCaptor.getValue().getCause()); + } + } + + @Test + public void onReady_errorCancelsCall() { + JumpToApplicationThreadServerStreamListener listener + = new JumpToApplicationThreadServerStreamListener( + executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation()); + ServerStreamListener mockListener = mock(ServerStreamListener.class); + listener.setListener(mockListener); + + Throwable expectedT = new AssertionError(); + doThrow(expectedT).when(mockListener).onReady(); + listener.onReady(); + try { + executor.runDueTasks(); + fail("Expected exception"); + } catch (Throwable t) { + assertSame(expectedT, t); + verify(stream).close(statusCaptor.capture(), any(Metadata.class)); + assertSame(expectedT, statusCaptor.getValue().getCause()); + } + } + + @Test + public void onReady_runtimeExceptionCancelsCall() { + JumpToApplicationThreadServerStreamListener listener + = new JumpToApplicationThreadServerStreamListener( + executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation()); + ServerStreamListener mockListener = mock(ServerStreamListener.class); + listener.setListener(mockListener); + + Throwable expectedT = new RuntimeException(); + doThrow(expectedT).when(mockListener).onReady(); + listener.onReady(); + try { + executor.runDueTasks(); + fail("Expected exception"); + } catch (Throwable t) { + assertSame(expectedT, t); + verify(stream).close(statusCaptor.capture(), any(Metadata.class)); + assertSame(expectedT, statusCaptor.getValue().getCause()); + } + } + private void createAndStartServer(List filters) throws IOException { createServer(filters); server.start();