diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java
index b5411fa730..27d14321f5 100644
--- a/core/src/main/java/io/grpc/internal/ServerImpl.java
+++ b/core/src/main/java/io/grpc/internal/ServerImpl.java
@@ -546,9 +546,9 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
/**
* Like {@link ServerCall#close(Status, Metadata)}, but thread-safe for internal use.
*/
- private void internalClose(Status status, Metadata trailers) {
+ private void internalClose() {
// TODO(ejona86): this is not thread-safe :)
- stream.close(status, trailers);
+ stream.close(Status.UNKNOWN, new Metadata());
}
@Override
@@ -559,10 +559,10 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
try {
getListener().messageRead(message);
} catch (RuntimeException e) {
- internalClose(Status.fromThrowable(e), new Metadata());
+ internalClose();
throw e;
} catch (Error e) {
- internalClose(Status.fromThrowable(e), new Metadata());
+ internalClose();
throw e;
}
}
@@ -577,10 +577,10 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
try {
getListener().halfClosed();
} catch (RuntimeException e) {
- internalClose(Status.fromThrowable(e), new Metadata());
+ internalClose();
throw e;
} catch (Error e) {
- internalClose(Status.fromThrowable(e), new Metadata());
+ internalClose();
throw e;
}
}
@@ -612,10 +612,10 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
try {
getListener().onReady();
} catch (RuntimeException e) {
- internalClose(Status.fromThrowable(e), new Metadata());
+ internalClose();
throw e;
} catch (Error e) {
- internalClose(Status.fromThrowable(e), new Metadata());
+ internalClose();
throw e;
}
}
diff --git a/core/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java b/core/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java
new file mode 100644
index 0000000000..420fb3caa7
--- /dev/null
+++ b/core/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java
@@ -0,0 +1,264 @@
+/*
+ * Copyright 2017, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.util;
+
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.SettableFuture;
+import io.grpc.Attributes;
+import io.grpc.ExperimentalApi;
+import io.grpc.ForwardingServerCall;
+import io.grpc.ForwardingServerCallListener;
+import io.grpc.Metadata;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.internal.SerializingExecutor;
+import java.util.concurrent.ExecutionException;
+import javax.annotation.Nullable;
+
+/**
+ * A class that intercepts uncaught exceptions of type {@link StatusRuntimeException} and handles
+ * them by closing the {@link ServerCall}, and transmitting the exception's status and metadata
+ * to the client.
+ *
+ *
Without this interceptor, gRPC will strip all details and close the {@link ServerCall} with
+ * a generic {@link Status#UNKNOWN} code.
+ *
+ *
Security warning: the {@link Status} and {@link Metadata} may contain sensitive server-side
+ * state information, and generally should not be sent to clients. Only install this interceptor
+ * if all clients are trusted.
+ */
+@ExperimentalApi("https://github.com/grpc/grpc-java/issues/2189")
+public final class TransmitStatusRuntimeExceptionInterceptor implements ServerInterceptor {
+ private TransmitStatusRuntimeExceptionInterceptor() {
+ }
+
+ public static ServerInterceptor instance() {
+ return new TransmitStatusRuntimeExceptionInterceptor();
+ }
+
+ @Override
+ public ServerCall.Listener interceptCall(
+ ServerCall call, Metadata headers, ServerCallHandler next) {
+ final ServerCall serverCall = new SerializingServerCall(call);
+ ServerCall.Listener listener = next.startCall(serverCall, headers);
+ return new ForwardingServerCallListener.SimpleForwardingServerCallListener(listener) {
+ @Override
+ public void onMessage(ReqT message) {
+ try {
+ super.onMessage(message);
+ } catch (StatusRuntimeException e) {
+ closeWithException(e);
+ }
+ }
+
+ @Override
+ public void onHalfClose() {
+ try {
+ super.onHalfClose();
+ } catch (StatusRuntimeException e) {
+ closeWithException(e);
+ }
+ }
+
+ @Override
+ public void onCancel() {
+ try {
+ super.onCancel();
+ } catch (StatusRuntimeException e) {
+ closeWithException(e);
+ }
+ }
+
+ @Override
+ public void onComplete() {
+ try {
+ super.onComplete();
+ } catch (StatusRuntimeException e) {
+ closeWithException(e);
+ }
+ }
+
+ @Override
+ public void onReady() {
+ try {
+ super.onReady();
+ } catch (StatusRuntimeException e) {
+ closeWithException(e);
+ }
+ }
+
+ private void closeWithException(StatusRuntimeException t) {
+ Metadata metadata = t.getTrailers();
+ if (metadata == null) {
+ metadata = new Metadata();
+ }
+ serverCall.close(t.getStatus(), metadata);
+ }
+ };
+ }
+
+ /**
+ * A {@link ServerCall} that wraps around a non thread safe delegate and provides thread safe
+ * access by serializing everything on an executor.
+ */
+ private static class SerializingServerCall extends
+ ForwardingServerCall.SimpleForwardingServerCall {
+ private static final String ERROR_MSG = "Encountered error during serialized access";
+ private final SerializingExecutor serializingExecutor =
+ new SerializingExecutor(MoreExecutors.directExecutor());
+
+ SerializingServerCall(ServerCall delegate) {
+ super(delegate);
+ }
+
+ @Override
+ public void sendMessage(final RespT message) {
+ serializingExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ SerializingServerCall.super.sendMessage(message);
+ }
+ });
+ }
+
+ @Override
+ public void request(final int numMessages) {
+ serializingExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ SerializingServerCall.super.request(numMessages);
+ }
+ });
+ }
+
+ @Override
+ public void sendHeaders(final Metadata headers) {
+ serializingExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ SerializingServerCall.super.sendHeaders(headers);
+ }
+ });
+ }
+
+ @Override
+ public void close(final Status status, final Metadata trailers) {
+ serializingExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ SerializingServerCall.super.close(status, trailers);
+ }
+ });
+ }
+
+ @Override
+ public boolean isReady() {
+ final SettableFuture retVal = SettableFuture.create();
+ serializingExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ retVal.set(SerializingServerCall.super.isReady());
+ }
+ });
+ try {
+ return retVal.get();
+ } catch (InterruptedException e) {
+ throw new RuntimeException(ERROR_MSG, e);
+ } catch (ExecutionException e) {
+ throw new RuntimeException(ERROR_MSG, e);
+ }
+ }
+
+ @Override
+ public boolean isCancelled() {
+ final SettableFuture retVal = SettableFuture.create();
+ serializingExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ retVal.set(SerializingServerCall.super.isCancelled());
+ }
+ });
+ try {
+ return retVal.get();
+ } catch (InterruptedException e) {
+ throw new RuntimeException(ERROR_MSG, e);
+ } catch (ExecutionException e) {
+ throw new RuntimeException(ERROR_MSG, e);
+ }
+ }
+
+ @Override
+ public void setMessageCompression(final boolean enabled) {
+ serializingExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ SerializingServerCall.super.setMessageCompression(enabled);
+ }
+ });
+ }
+
+ @Override
+ public void setCompression(final String compressor) {
+ serializingExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ SerializingServerCall.super.setCompression(compressor);
+ }
+ });
+ }
+
+ @Override
+ public Attributes getAttributes() {
+ final SettableFuture retVal = SettableFuture.create();
+ serializingExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ retVal.set(SerializingServerCall.super.getAttributes());
+ }
+ });
+ try {
+ return retVal.get();
+ } catch (InterruptedException e) {
+ throw new RuntimeException(ERROR_MSG, e);
+ } catch (ExecutionException e) {
+ throw new RuntimeException(ERROR_MSG, e);
+ }
+ }
+
+ @Nullable
+ @Override
+ public String getAuthority() {
+ final SettableFuture retVal = SettableFuture.create();
+ serializingExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ retVal.set(SerializingServerCall.super.getAuthority());
+ }
+ });
+ try {
+ return retVal.get();
+ } catch (InterruptedException e) {
+ throw new RuntimeException(ERROR_MSG, e);
+ } catch (ExecutionException e) {
+ throw new RuntimeException(ERROR_MSG, e);
+ }
+ }
+ }
+}
diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java
index 09f90a4f71..b95c7edc54 100644
--- a/core/src/test/java/io/grpc/internal/ServerImplTest.java
+++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java
@@ -146,6 +146,8 @@ public class ServerImplTest {
@Captor
private ArgumentCaptor statusCaptor;
@Captor
+ private ArgumentCaptor metadataCaptor;
+ @Captor
private ArgumentCaptor streamListenerCaptor;
@Mock
@@ -981,8 +983,7 @@ public class ServerImplTest {
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
- verify(stream).close(statusCaptor.capture(), any(Metadata.class));
- assertSame(expectedT, statusCaptor.getValue().getCause());
+ ensureServerStateNotLeaked();
}
}
@@ -1006,8 +1007,7 @@ public class ServerImplTest {
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
- verify(stream).close(statusCaptor.capture(), any(Metadata.class));
- assertSame(expectedT, statusCaptor.getValue().getCause());
+ ensureServerStateNotLeaked();
}
}
@@ -1030,8 +1030,7 @@ public class ServerImplTest {
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
- verify(stream).close(statusCaptor.capture(), any(Metadata.class));
- assertSame(expectedT, statusCaptor.getValue().getCause());
+ ensureServerStateNotLeaked();
}
}
@@ -1054,8 +1053,7 @@ public class ServerImplTest {
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
- verify(stream).close(statusCaptor.capture(), any(Metadata.class));
- assertSame(expectedT, statusCaptor.getValue().getCause());
+ ensureServerStateNotLeaked();
}
}
@@ -1078,8 +1076,7 @@ public class ServerImplTest {
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
- verify(stream).close(statusCaptor.capture(), any(Metadata.class));
- assertSame(expectedT, statusCaptor.getValue().getCause());
+ ensureServerStateNotLeaked();
}
}
@@ -1102,8 +1099,7 @@ public class ServerImplTest {
fail("Expected exception");
} catch (Throwable t) {
assertSame(expectedT, t);
- verify(stream).close(statusCaptor.capture(), any(Metadata.class));
- assertSame(expectedT, statusCaptor.getValue().getCause());
+ ensureServerStateNotLeaked();
}
}
@@ -1137,6 +1133,13 @@ public class ServerImplTest {
verifyNoMoreInteractions(timerPool);
}
+ private void ensureServerStateNotLeaked() {
+ verify(stream).close(statusCaptor.capture(), metadataCaptor.capture());
+ assertEquals(Status.UNKNOWN, statusCaptor.getValue());
+ assertNull(statusCaptor.getValue().getCause());
+ assertTrue(metadataCaptor.getValue().keys().isEmpty());
+ }
+
private static class SimpleServer implements io.grpc.internal.InternalServer {
ServerListener listener;
diff --git a/core/src/test/java/io/grpc/util/UtilServerInterceptorsTest.java b/core/src/test/java/io/grpc/util/UtilServerInterceptorsTest.java
new file mode 100644
index 0000000000..7bbb1847dc
--- /dev/null
+++ b/core/src/test/java/io/grpc/util/UtilServerInterceptorsTest.java
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2017, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.util;
+
+import static com.google.common.collect.Iterables.getOnlyElement;
+import static org.mockito.Matchers.same;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptors;
+import io.grpc.ServerMethodDefinition;
+import io.grpc.ServerServiceDefinition;
+import io.grpc.ServiceDescriptor;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.testing.NoopServerCall;
+import io.grpc.testing.TestMethodDescriptors;
+import java.util.Arrays;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+/**
+ * Unit test for {@link io.grpc.ServerInterceptor} implementations that come with gRPC. Not to be
+ * confused with the unit tests that validate gRPC's usage of interceptors.
+ */
+@RunWith(JUnit4.class)
+public class UtilServerInterceptorsTest {
+ private MethodDescriptor flowMethod = TestMethodDescriptors.noopMethod();
+ private ServerCall call = Mockito.spy(new NoopServerCall());
+ private final Metadata headers = new Metadata();
+ private ServerCallHandler handler = new ServerCallHandler() {
+ @Override
+ public ServerCall.Listener startCall(
+ ServerCall call, Metadata headers) {
+ return listener;
+ }
+ };
+ private ServerServiceDefinition serviceDefinition =
+ ServerServiceDefinition.builder(new ServiceDescriptor("service_foo", flowMethod))
+ .addMethod(flowMethod, handler)
+ .build();
+ private ServerCall.Listener listener;
+
+ @SuppressWarnings("unchecked")
+ private static ServerMethodDefinition getSoleMethod(
+ ServerServiceDefinition serviceDef) {
+ if (serviceDef.getMethods().size() != 1) {
+ throw new AssertionError("Not exactly one method present");
+ }
+ return (ServerMethodDefinition) getOnlyElement(serviceDef.getMethods());
+ }
+
+ @Test
+ public void statusRuntimeExceptionTransmitter() {
+ final Status expectedStatus = Status.UNAVAILABLE;
+ final Metadata expectedMetadata = new Metadata();
+ final StatusRuntimeException exception =
+ new StatusRuntimeException(expectedStatus, expectedMetadata);
+ listener = new ServerCall.Listener() {
+ @Override
+ public void onMessage(String message) {
+ throw exception;
+ }
+
+ @Override
+ public void onHalfClose() {
+ throw exception;
+ }
+
+ @Override
+ public void onCancel() {
+ throw exception;
+ }
+
+ @Override
+ public void onComplete() {
+ throw exception;
+ }
+
+ @Override
+ public void onReady() {
+ throw exception;
+ }
+ };
+
+ ServerServiceDefinition intercepted = ServerInterceptors.intercept(
+ serviceDefinition,
+ Arrays.asList(TransmitStatusRuntimeExceptionInterceptor.instance()));
+ // The interceptor should have handled the error by directly closing the ServerCall
+ // and the exception should not propagate to the method's caller
+ getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers).onMessage("hello");
+ getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers).onCancel();
+ getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers).onComplete();
+ getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers).onHalfClose();
+ getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers).onReady();
+ verify(call, times(5)).close(same(expectedStatus), same(expectedMetadata));
+ }
+}