core: cancel RPC when exception in server onReady

Fixes #2305
This commit is contained in:
Eric Anderson 2016-09-28 10:59:09 -07:00
parent f51316b84a
commit 2bd74c5a92
2 changed files with 144 additions and 3 deletions

View File

@ -38,6 +38,7 @@ import static io.grpc.Status.DEADLINE_EXCEEDED;
import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier;
@ -524,7 +525,8 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
* Dispatches callbacks onto an application-provided executor and correctly propagates
* exceptions.
*/
private static class JumpToApplicationThreadServerStreamListener implements ServerStreamListener {
@VisibleForTesting
static class JumpToApplicationThreadServerStreamListener implements ServerStreamListener {
private final Executor callExecutor;
private final Context.CancellableContext context;
private final ServerStream stream;
@ -545,7 +547,8 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
return listener;
}
private void setListener(ServerStreamListener listener) {
@VisibleForTesting
void setListener(ServerStreamListener listener) {
Preconditions.checkNotNull(listener, "listener must not be null");
Preconditions.checkState(this.listener == null, "Listener already set");
this.listener = listener;
@ -616,7 +619,15 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
callExecutor.execute(new ContextRunnable(context) {
@Override
public void runInContext() {
getListener().onReady();
try {
getListener().onReady();
} catch (RuntimeException e) {
internalClose(Status.fromThrowable(e), new Metadata());
throw e;
} catch (Error e) {
internalClose(Status.fromThrowable(e), new Metadata());
throw e;
}
}
});
}

View File

@ -45,6 +45,7 @@ import static org.mockito.Matchers.isNotNull;
import static org.mockito.Matchers.notNull;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
@ -77,6 +78,7 @@ import io.grpc.ServerTransportFilter;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;
import io.grpc.StringMarshaller;
import io.grpc.internal.ServerImpl.JumpToApplicationThreadServerStreamListener;
import io.grpc.internal.testing.StatsTestUtils;
import io.grpc.internal.testing.StatsTestUtils.FakeStatsContextFactory;
import io.grpc.util.MutableHandlerRegistry;
@ -898,6 +900,134 @@ public class ServerImplTest {
verifyNoMoreInteractions(fallbackRegistry);
}
@Test
public void messageRead_errorCancelsCall() throws Exception {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);
Throwable expectedT = new AssertionError();
doThrow(expectedT).when(mockListener).messageRead(any(InputStream.class));
// Closing the InputStream is done by the delegated listener (generally ServerCallImpl)
listener.messageRead(mock(InputStream.class));
try {
executor.runDueTasks();
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
verify(stream).close(statusCaptor.capture(), any(Metadata.class));
assertSame(expectedT, statusCaptor.getValue().getCause());
}
}
@Test
public void messageRead_runtimeExceptionCancelsCall() throws Exception {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);
Throwable expectedT = new RuntimeException();
doThrow(expectedT).when(mockListener).messageRead(any(InputStream.class));
// Closing the InputStream is done by the delegated listener (generally ServerCallImpl)
listener.messageRead(mock(InputStream.class));
try {
executor.runDueTasks();
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
verify(stream).close(statusCaptor.capture(), any(Metadata.class));
assertSame(expectedT, statusCaptor.getValue().getCause());
}
}
@Test
public void halfClosed_errorCancelsCall() {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);
Throwable expectedT = new AssertionError();
doThrow(expectedT).when(mockListener).halfClosed();
listener.halfClosed();
try {
executor.runDueTasks();
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
verify(stream).close(statusCaptor.capture(), any(Metadata.class));
assertSame(expectedT, statusCaptor.getValue().getCause());
}
}
@Test
public void halfClosed_runtimeExceptionCancelsCall() {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);
Throwable expectedT = new RuntimeException();
doThrow(expectedT).when(mockListener).halfClosed();
listener.halfClosed();
try {
executor.runDueTasks();
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
verify(stream).close(statusCaptor.capture(), any(Metadata.class));
assertSame(expectedT, statusCaptor.getValue().getCause());
}
}
@Test
public void onReady_errorCancelsCall() {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);
Throwable expectedT = new AssertionError();
doThrow(expectedT).when(mockListener).onReady();
listener.onReady();
try {
executor.runDueTasks();
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
verify(stream).close(statusCaptor.capture(), any(Metadata.class));
assertSame(expectedT, statusCaptor.getValue().getCause());
}
}
@Test
public void onReady_runtimeExceptionCancelsCall() {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);
Throwable expectedT = new RuntimeException();
doThrow(expectedT).when(mockListener).onReady();
listener.onReady();
try {
executor.runDueTasks();
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
verify(stream).close(statusCaptor.capture(), any(Metadata.class));
assertSame(expectedT, statusCaptor.getValue().getCause());
}
}
private void createAndStartServer(List<ServerTransportFilter> filters) throws IOException {
createServer(filters);
server.start();