From 780b2696c1941fb1f8d1f75c7853ca5d18c636ab Mon Sep 17 00:00:00 2001 From: nmittler Date: Mon, 18 Jul 2016 09:26:48 -0700 Subject: [PATCH] Ensure that trailers are not lost when making blocking calls. Fixes #2036 --- .../main/java/io/grpc/stub/ClientCalls.java | 28 +++++- .../java/io/grpc/stub/ClientCallsTest.java | 93 +++++++++++++++---- 2 files changed, 101 insertions(+), 20 deletions(-) diff --git a/stub/src/main/java/io/grpc/stub/ClientCalls.java b/stub/src/main/java/io/grpc/stub/ClientCalls.java index 5296497635..b3cb880ac4 100644 --- a/stub/src/main/java/io/grpc/stub/ClientCalls.java +++ b/stub/src/main/java/io/grpc/stub/ClientCalls.java @@ -31,6 +31,8 @@ package io.grpc.stub; +import static com.google.common.base.Preconditions.checkNotNull; + import com.google.common.base.Preconditions; import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.ListenableFuture; @@ -38,9 +40,11 @@ import com.google.common.util.concurrent.ListenableFuture; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; +import io.grpc.ExperimentalApi; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; +import io.grpc.StatusException; import io.grpc.StatusRuntimeException; import java.util.Iterator; @@ -205,10 +209,32 @@ public class ClientCalls { Thread.currentThread().interrupt(); throw Status.CANCELLED.withCause(e).asRuntimeException(); } catch (ExecutionException e) { - throw Status.fromThrowable(e).asRuntimeException(); + throw toStatusRuntimeException(e); } } + /** + * Wraps the given {@link Throwable} in a {@link StatusRuntimeException}. If it contains an + * embedded {@link StatusException} or {@link StatusRuntimeException}, the returned exception will + * contain the embedded trailers and status, with the given exception as the cause. Otherwise, an + * exception will be generated from an {@link Status#UNKNOWN} status. + */ + private static StatusRuntimeException toStatusRuntimeException(Throwable t) { + Throwable cause = checkNotNull(t); + while (cause != null) { + // If we have an embedded status, use it and replace the cause + if (cause instanceof StatusException) { + StatusException se = (StatusException) cause; + return new StatusRuntimeException(se.getStatus().withCause(t), se.getTrailers()); + } else if (cause instanceof StatusRuntimeException) { + StatusRuntimeException se = (StatusRuntimeException) cause; + return new StatusRuntimeException(se.getStatus().withCause(t), se.getTrailers()); + } + cause = cause.getCause(); + } + return Status.UNKNOWN.withCause(t).asRuntimeException(); + } + private static void asyncUnaryRequestCall( ClientCall call, ReqT param, StreamObserver responseObserver, boolean streamingResponse) { diff --git a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java index 4547e94681..079a1d49a0 100644 --- a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java @@ -36,6 +36,7 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -44,6 +45,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.ClientCall.Listener; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -51,6 +53,7 @@ import io.grpc.Server; import io.grpc.ServerServiceDefinition; import io.grpc.ServiceDescriptor; import io.grpc.Status; +import io.grpc.StatusRuntimeException; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.stub.ServerCalls.NoopStreamObserver; @@ -63,7 +66,10 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import java.util.ArrayList; import java.util.Arrays; @@ -81,10 +87,11 @@ import java.util.concurrent.TimeUnit; @RunWith(JUnit4.class) public class ClientCallsTest { - static final MethodDescriptor STREAMING_METHOD = MethodDescriptor.create( - MethodDescriptor.MethodType.BIDI_STREAMING, - "some/method", - new IntegerMarshaller(), new IntegerMarshaller()); + private static final MethodDescriptor STREAMING_METHOD = MethodDescriptor + .create( + MethodDescriptor.MethodType.BIDI_STREAMING, + "some/method", + new IntegerMarshaller(), new IntegerMarshaller()); private Server server; private ManagedChannel channel; @@ -107,6 +114,53 @@ public class ClientCallsTest { } } + @Test + public void unaryBlockingCallSuccess() throws Exception { + Integer req = 2; + final String resp = "bar"; + final Status status = Status.OK; + final Metadata trailers = new Metadata(); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) throws Throwable { + @SuppressWarnings("unchecked") + Listener listener = (Listener) in.getArguments()[0]; + listener.onMessage(resp); + listener.onClose(status, trailers); + return null; + } + }).when(call).start(Mockito.>any(), any(Metadata.class)); + + String actualResponse = ClientCalls.blockingUnaryCall(call, req); + assertEquals(resp, actualResponse); + } + + @Test + public void unaryBlockingCallFailed() throws Exception { + Integer req = 2; + final Status status = Status.INTERNAL; + final Metadata trailers = new Metadata(); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) throws Throwable { + @SuppressWarnings("unchecked") + Listener listener = (Listener) in.getArguments()[0]; + listener.onClose(status, trailers); + return null; + } + }).when(call).start(Mockito.>any(), any(Metadata.class)); + + try { + ClientCalls.blockingUnaryCall(call, req); + fail("Should fail"); + } catch (StatusRuntimeException e) { + assertSame(status.getCode(), e.getStatus().getCode()); + assertSame(trailers, e.getTrailers()); + } + } + @Test public void unaryFutureCallSuccess() throws Exception { Integer req = 2; @@ -246,24 +300,24 @@ public class ClientCallsTest { ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); ClientResponseObserver responseObserver = new ClientResponseObserver() { - @Override - public void beforeStart(ClientCallStreamObserver requestStream) { - requestStream.disableAutoInboundFlowControl(); - requestStream.request(5); - } + @Override + public void beforeStart(ClientCallStreamObserver requestStream) { + requestStream.disableAutoInboundFlowControl(); + requestStream.request(5); + } - @Override - public void onNext(String value) { - } + @Override + public void onNext(String value) { + } - @Override - public void onError(Throwable t) { - } + @Override + public void onError(Throwable t) { + } - @Override - public void onCompleted() { - } - }; + @Override + public void onCompleted() { + } + }; ClientCalls.asyncServerStreamingCall(call, 1, responseObserver); verify(call).start(listenerCaptor.capture(), any(Metadata.class)); listenerCaptor.getValue().onMessage("message"); @@ -398,6 +452,7 @@ public class ClientCallsTest { public void beforeStart(final ClientCallStreamObserver requestStream) { requestStream.setOnReadyHandler(new Runnable() { int iteration; + @Override public void run() { while (requestStream.isReady()) {