Ensure that trailers are not lost when making blocking calls.

Fixes #2036
This commit is contained in:
nmittler 2016-07-18 09:26:48 -07:00
parent c1ef8061d1
commit 780b2696c1
2 changed files with 101 additions and 20 deletions

View File

@ -31,6 +31,8 @@
package io.grpc.stub;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.ListenableFuture;
@ -38,9 +40,11 @@ import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ExperimentalApi;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.StatusRuntimeException;
import java.util.Iterator;
@ -205,10 +209,32 @@ public class ClientCalls {
Thread.currentThread().interrupt();
throw Status.CANCELLED.withCause(e).asRuntimeException();
} catch (ExecutionException e) {
throw Status.fromThrowable(e).asRuntimeException();
throw toStatusRuntimeException(e);
}
}
/**
* Wraps the given {@link Throwable} in a {@link StatusRuntimeException}. If it contains an
* embedded {@link StatusException} or {@link StatusRuntimeException}, the returned exception will
* contain the embedded trailers and status, with the given exception as the cause. Otherwise, an
* exception will be generated from an {@link Status#UNKNOWN} status.
*/
private static StatusRuntimeException toStatusRuntimeException(Throwable t) {
Throwable cause = checkNotNull(t);
while (cause != null) {
// If we have an embedded status, use it and replace the cause
if (cause instanceof StatusException) {
StatusException se = (StatusException) cause;
return new StatusRuntimeException(se.getStatus().withCause(t), se.getTrailers());
} else if (cause instanceof StatusRuntimeException) {
StatusRuntimeException se = (StatusRuntimeException) cause;
return new StatusRuntimeException(se.getStatus().withCause(t), se.getTrailers());
}
cause = cause.getCause();
}
return Status.UNKNOWN.withCause(t).asRuntimeException();
}
private static <ReqT, RespT> void asyncUnaryRequestCall(
ClientCall<ReqT, RespT> call, ReqT param, StreamObserver<RespT> responseObserver,
boolean streamingResponse) {

View File

@ -36,6 +36,7 @@ import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -44,6 +45,7 @@ import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ClientCall.Listener;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
@ -51,6 +53,7 @@ import io.grpc.Server;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.ServerCalls.NoopStreamObserver;
@ -63,7 +66,10 @@ import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.Arrays;
@ -81,10 +87,11 @@ import java.util.concurrent.TimeUnit;
@RunWith(JUnit4.class)
public class ClientCallsTest {
static final MethodDescriptor<Integer, Integer> STREAMING_METHOD = MethodDescriptor.create(
MethodDescriptor.MethodType.BIDI_STREAMING,
"some/method",
new IntegerMarshaller(), new IntegerMarshaller());
private static final MethodDescriptor<Integer, Integer> STREAMING_METHOD = MethodDescriptor
.create(
MethodDescriptor.MethodType.BIDI_STREAMING,
"some/method",
new IntegerMarshaller(), new IntegerMarshaller());
private Server server;
private ManagedChannel channel;
@ -107,6 +114,53 @@ public class ClientCallsTest {
}
}
@Test
public void unaryBlockingCallSuccess() throws Exception {
Integer req = 2;
final String resp = "bar";
final Status status = Status.OK;
final Metadata trailers = new Metadata();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
@SuppressWarnings("unchecked")
Listener<String> listener = (Listener<String>) in.getArguments()[0];
listener.onMessage(resp);
listener.onClose(status, trailers);
return null;
}
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class));
String actualResponse = ClientCalls.blockingUnaryCall(call, req);
assertEquals(resp, actualResponse);
}
@Test
public void unaryBlockingCallFailed() throws Exception {
Integer req = 2;
final Status status = Status.INTERNAL;
final Metadata trailers = new Metadata();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
@SuppressWarnings("unchecked")
Listener<String> listener = (Listener<String>) in.getArguments()[0];
listener.onClose(status, trailers);
return null;
}
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class));
try {
ClientCalls.blockingUnaryCall(call, req);
fail("Should fail");
} catch (StatusRuntimeException e) {
assertSame(status.getCode(), e.getStatus().getCode());
assertSame(trailers, e.getTrailers());
}
}
@Test
public void unaryFutureCallSuccess() throws Exception {
Integer req = 2;
@ -246,24 +300,24 @@ public class ClientCallsTest {
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
ClientResponseObserver<Integer, String> responseObserver =
new ClientResponseObserver<Integer, String>() {
@Override
public void beforeStart(ClientCallStreamObserver<Integer> requestStream) {
requestStream.disableAutoInboundFlowControl();
requestStream.request(5);
}
@Override
public void beforeStart(ClientCallStreamObserver<Integer> requestStream) {
requestStream.disableAutoInboundFlowControl();
requestStream.request(5);
}
@Override
public void onNext(String value) {
}
@Override
public void onNext(String value) {
}
@Override
public void onError(Throwable t) {
}
@Override
public void onError(Throwable t) {
}
@Override
public void onCompleted() {
}
};
@Override
public void onCompleted() {
}
};
ClientCalls.asyncServerStreamingCall(call, 1, responseObserver);
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
listenerCaptor.getValue().onMessage("message");
@ -398,6 +452,7 @@ public class ClientCallsTest {
public void beforeStart(final ClientCallStreamObserver<Integer> requestStream) {
requestStream.setOnReadyHandler(new Runnable() {
int iteration;
@Override
public void run() {
while (requestStream.isReady()) {