From 2b1eee90e5bd7f5ad905e34f73f2040d6c9a3568 Mon Sep 17 00:00:00 2001 From: zpencer Date: Mon, 19 Jun 2017 11:15:22 -0700 Subject: [PATCH] core: Do not leak server state when application callbacks throw exceptions (#3064) Today JumpToApplicationThreadServerStreamListener leaks server state by transmitting details about uncaught StatusRuntimeException throwables to the client. This is a security problem. This PR ensures that uncaught exceptions always close the ServerCall without leaking any state information. Users running in a trusted environment who want to transmit error details can install the TransmitStatusRuntimeExceptionInterceptor. fixes #2189 --- .../java/io/grpc/internal/ServerImpl.java | 16 +- ...smitStatusRuntimeExceptionInterceptor.java | 264 ++++++++++++++++++ .../java/io/grpc/internal/ServerImplTest.java | 27 +- .../grpc/util/UtilServerInterceptorsTest.java | 118 ++++++++ 4 files changed, 405 insertions(+), 20 deletions(-) create mode 100644 core/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java create mode 100644 core/src/test/java/io/grpc/util/UtilServerInterceptorsTest.java diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index b5411fa730..27d14321f5 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -546,9 +546,9 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { /** * Like {@link ServerCall#close(Status, Metadata)}, but thread-safe for internal use. */ - private void internalClose(Status status, Metadata trailers) { + private void internalClose() { // TODO(ejona86): this is not thread-safe :) - stream.close(status, trailers); + stream.close(Status.UNKNOWN, new Metadata()); } @Override @@ -559,10 +559,10 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { try { getListener().messageRead(message); } catch (RuntimeException e) { - internalClose(Status.fromThrowable(e), new Metadata()); + internalClose(); throw e; } catch (Error e) { - internalClose(Status.fromThrowable(e), new Metadata()); + internalClose(); throw e; } } @@ -577,10 +577,10 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { try { getListener().halfClosed(); } catch (RuntimeException e) { - internalClose(Status.fromThrowable(e), new Metadata()); + internalClose(); throw e; } catch (Error e) { - internalClose(Status.fromThrowable(e), new Metadata()); + internalClose(); throw e; } } @@ -612,10 +612,10 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { try { getListener().onReady(); } catch (RuntimeException e) { - internalClose(Status.fromThrowable(e), new Metadata()); + internalClose(); throw e; } catch (Error e) { - internalClose(Status.fromThrowable(e), new Metadata()); + internalClose(); throw e; } } diff --git a/core/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java b/core/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java new file mode 100644 index 0000000000..420fb3caa7 --- /dev/null +++ b/core/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java @@ -0,0 +1,264 @@ +/* + * Copyright 2017, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Attributes; +import io.grpc.ExperimentalApi; +import io.grpc.ForwardingServerCall; +import io.grpc.ForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.internal.SerializingExecutor; +import java.util.concurrent.ExecutionException; +import javax.annotation.Nullable; + +/** + * A class that intercepts uncaught exceptions of type {@link StatusRuntimeException} and handles + * them by closing the {@link ServerCall}, and transmitting the exception's status and metadata + * to the client. + * + *

Without this interceptor, gRPC will strip all details and close the {@link ServerCall} with + * a generic {@link Status#UNKNOWN} code. + * + *

Security warning: the {@link Status} and {@link Metadata} may contain sensitive server-side + * state information, and generally should not be sent to clients. Only install this interceptor + * if all clients are trusted. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/2189") +public final class TransmitStatusRuntimeExceptionInterceptor implements ServerInterceptor { + private TransmitStatusRuntimeExceptionInterceptor() { + } + + public static ServerInterceptor instance() { + return new TransmitStatusRuntimeExceptionInterceptor(); + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + final ServerCall serverCall = new SerializingServerCall(call); + ServerCall.Listener listener = next.startCall(serverCall, headers); + return new ForwardingServerCallListener.SimpleForwardingServerCallListener(listener) { + @Override + public void onMessage(ReqT message) { + try { + super.onMessage(message); + } catch (StatusRuntimeException e) { + closeWithException(e); + } + } + + @Override + public void onHalfClose() { + try { + super.onHalfClose(); + } catch (StatusRuntimeException e) { + closeWithException(e); + } + } + + @Override + public void onCancel() { + try { + super.onCancel(); + } catch (StatusRuntimeException e) { + closeWithException(e); + } + } + + @Override + public void onComplete() { + try { + super.onComplete(); + } catch (StatusRuntimeException e) { + closeWithException(e); + } + } + + @Override + public void onReady() { + try { + super.onReady(); + } catch (StatusRuntimeException e) { + closeWithException(e); + } + } + + private void closeWithException(StatusRuntimeException t) { + Metadata metadata = t.getTrailers(); + if (metadata == null) { + metadata = new Metadata(); + } + serverCall.close(t.getStatus(), metadata); + } + }; + } + + /** + * A {@link ServerCall} that wraps around a non thread safe delegate and provides thread safe + * access by serializing everything on an executor. + */ + private static class SerializingServerCall extends + ForwardingServerCall.SimpleForwardingServerCall { + private static final String ERROR_MSG = "Encountered error during serialized access"; + private final SerializingExecutor serializingExecutor = + new SerializingExecutor(MoreExecutors.directExecutor()); + + SerializingServerCall(ServerCall delegate) { + super(delegate); + } + + @Override + public void sendMessage(final RespT message) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.sendMessage(message); + } + }); + } + + @Override + public void request(final int numMessages) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.request(numMessages); + } + }); + } + + @Override + public void sendHeaders(final Metadata headers) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.sendHeaders(headers); + } + }); + } + + @Override + public void close(final Status status, final Metadata trailers) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.close(status, trailers); + } + }); + } + + @Override + public boolean isReady() { + final SettableFuture retVal = SettableFuture.create(); + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + retVal.set(SerializingServerCall.super.isReady()); + } + }); + try { + return retVal.get(); + } catch (InterruptedException e) { + throw new RuntimeException(ERROR_MSG, e); + } catch (ExecutionException e) { + throw new RuntimeException(ERROR_MSG, e); + } + } + + @Override + public boolean isCancelled() { + final SettableFuture retVal = SettableFuture.create(); + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + retVal.set(SerializingServerCall.super.isCancelled()); + } + }); + try { + return retVal.get(); + } catch (InterruptedException e) { + throw new RuntimeException(ERROR_MSG, e); + } catch (ExecutionException e) { + throw new RuntimeException(ERROR_MSG, e); + } + } + + @Override + public void setMessageCompression(final boolean enabled) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.setMessageCompression(enabled); + } + }); + } + + @Override + public void setCompression(final String compressor) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.setCompression(compressor); + } + }); + } + + @Override + public Attributes getAttributes() { + final SettableFuture retVal = SettableFuture.create(); + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + retVal.set(SerializingServerCall.super.getAttributes()); + } + }); + try { + return retVal.get(); + } catch (InterruptedException e) { + throw new RuntimeException(ERROR_MSG, e); + } catch (ExecutionException e) { + throw new RuntimeException(ERROR_MSG, e); + } + } + + @Nullable + @Override + public String getAuthority() { + final SettableFuture retVal = SettableFuture.create(); + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + retVal.set(SerializingServerCall.super.getAuthority()); + } + }); + try { + return retVal.get(); + } catch (InterruptedException e) { + throw new RuntimeException(ERROR_MSG, e); + } catch (ExecutionException e) { + throw new RuntimeException(ERROR_MSG, e); + } + } + } +} diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index 09f90a4f71..b95c7edc54 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -146,6 +146,8 @@ public class ServerImplTest { @Captor private ArgumentCaptor statusCaptor; @Captor + private ArgumentCaptor metadataCaptor; + @Captor private ArgumentCaptor streamListenerCaptor; @Mock @@ -981,8 +983,7 @@ public class ServerImplTest { fail("Expected exception"); } catch (Throwable t) { assertSame(expectedT, t); - verify(stream).close(statusCaptor.capture(), any(Metadata.class)); - assertSame(expectedT, statusCaptor.getValue().getCause()); + ensureServerStateNotLeaked(); } } @@ -1006,8 +1007,7 @@ public class ServerImplTest { fail("Expected exception"); } catch (Throwable t) { assertSame(expectedT, t); - verify(stream).close(statusCaptor.capture(), any(Metadata.class)); - assertSame(expectedT, statusCaptor.getValue().getCause()); + ensureServerStateNotLeaked(); } } @@ -1030,8 +1030,7 @@ public class ServerImplTest { fail("Expected exception"); } catch (Throwable t) { assertSame(expectedT, t); - verify(stream).close(statusCaptor.capture(), any(Metadata.class)); - assertSame(expectedT, statusCaptor.getValue().getCause()); + ensureServerStateNotLeaked(); } } @@ -1054,8 +1053,7 @@ public class ServerImplTest { fail("Expected exception"); } catch (Throwable t) { assertSame(expectedT, t); - verify(stream).close(statusCaptor.capture(), any(Metadata.class)); - assertSame(expectedT, statusCaptor.getValue().getCause()); + ensureServerStateNotLeaked(); } } @@ -1078,8 +1076,7 @@ public class ServerImplTest { fail("Expected exception"); } catch (Throwable t) { assertSame(expectedT, t); - verify(stream).close(statusCaptor.capture(), any(Metadata.class)); - assertSame(expectedT, statusCaptor.getValue().getCause()); + ensureServerStateNotLeaked(); } } @@ -1102,8 +1099,7 @@ public class ServerImplTest { fail("Expected exception"); } catch (Throwable t) { assertSame(expectedT, t); - verify(stream).close(statusCaptor.capture(), any(Metadata.class)); - assertSame(expectedT, statusCaptor.getValue().getCause()); + ensureServerStateNotLeaked(); } } @@ -1137,6 +1133,13 @@ public class ServerImplTest { verifyNoMoreInteractions(timerPool); } + private void ensureServerStateNotLeaked() { + verify(stream).close(statusCaptor.capture(), metadataCaptor.capture()); + assertEquals(Status.UNKNOWN, statusCaptor.getValue()); + assertNull(statusCaptor.getValue().getCause()); + assertTrue(metadataCaptor.getValue().keys().isEmpty()); + } + private static class SimpleServer implements io.grpc.internal.InternalServer { ServerListener listener; diff --git a/core/src/test/java/io/grpc/util/UtilServerInterceptorsTest.java b/core/src/test/java/io/grpc/util/UtilServerInterceptorsTest.java new file mode 100644 index 0000000000..7bbb1847dc --- /dev/null +++ b/core/src/test/java/io/grpc/util/UtilServerInterceptorsTest.java @@ -0,0 +1,118 @@ +/* + * Copyright 2017, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static com.google.common.collect.Iterables.getOnlyElement; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptors; +import io.grpc.ServerMethodDefinition; +import io.grpc.ServerServiceDefinition; +import io.grpc.ServiceDescriptor; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.testing.NoopServerCall; +import io.grpc.testing.TestMethodDescriptors; +import java.util.Arrays; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +/** + * Unit test for {@link io.grpc.ServerInterceptor} implementations that come with gRPC. Not to be + * confused with the unit tests that validate gRPC's usage of interceptors. + */ +@RunWith(JUnit4.class) +public class UtilServerInterceptorsTest { + private MethodDescriptor flowMethod = TestMethodDescriptors.noopMethod(); + private ServerCall call = Mockito.spy(new NoopServerCall()); + private final Metadata headers = new Metadata(); + private ServerCallHandler handler = new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + return listener; + } + }; + private ServerServiceDefinition serviceDefinition = + ServerServiceDefinition.builder(new ServiceDescriptor("service_foo", flowMethod)) + .addMethod(flowMethod, handler) + .build(); + private ServerCall.Listener listener; + + @SuppressWarnings("unchecked") + private static ServerMethodDefinition getSoleMethod( + ServerServiceDefinition serviceDef) { + if (serviceDef.getMethods().size() != 1) { + throw new AssertionError("Not exactly one method present"); + } + return (ServerMethodDefinition) getOnlyElement(serviceDef.getMethods()); + } + + @Test + public void statusRuntimeExceptionTransmitter() { + final Status expectedStatus = Status.UNAVAILABLE; + final Metadata expectedMetadata = new Metadata(); + final StatusRuntimeException exception = + new StatusRuntimeException(expectedStatus, expectedMetadata); + listener = new ServerCall.Listener() { + @Override + public void onMessage(String message) { + throw exception; + } + + @Override + public void onHalfClose() { + throw exception; + } + + @Override + public void onCancel() { + throw exception; + } + + @Override + public void onComplete() { + throw exception; + } + + @Override + public void onReady() { + throw exception; + } + }; + + ServerServiceDefinition intercepted = ServerInterceptors.intercept( + serviceDefinition, + Arrays.asList(TransmitStatusRuntimeExceptionInterceptor.instance())); + // The interceptor should have handled the error by directly closing the ServerCall + // and the exception should not propagate to the method's caller + getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers).onMessage("hello"); + getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers).onCancel(); + getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers).onComplete(); + getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers).onHalfClose(); + getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers).onReady(); + verify(call, times(5)).close(same(expectedStatus), same(expectedMetadata)); + } +}