Ensure that trailers are not lost when making blocking calls.

Fixes #2036
This commit is contained in:
nmittler 2016-07-18 09:26:48 -07:00
parent c1ef8061d1
commit 780b2696c1
2 changed files with 101 additions and 20 deletions

View File

@ -31,6 +31,8 @@
package io.grpc.stub; package io.grpc.stub;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
@ -38,9 +40,11 @@ import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.Channel; import io.grpc.Channel;
import io.grpc.ClientCall; import io.grpc.ClientCall;
import io.grpc.ExperimentalApi;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.StatusRuntimeException; import io.grpc.StatusRuntimeException;
import java.util.Iterator; import java.util.Iterator;
@ -205,10 +209,32 @@ public class ClientCalls {
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
throw Status.CANCELLED.withCause(e).asRuntimeException(); throw Status.CANCELLED.withCause(e).asRuntimeException();
} catch (ExecutionException e) { } catch (ExecutionException e) {
throw Status.fromThrowable(e).asRuntimeException(); throw toStatusRuntimeException(e);
} }
} }
/**
* Wraps the given {@link Throwable} in a {@link StatusRuntimeException}. If it contains an
* embedded {@link StatusException} or {@link StatusRuntimeException}, the returned exception will
* contain the embedded trailers and status, with the given exception as the cause. Otherwise, an
* exception will be generated from an {@link Status#UNKNOWN} status.
*/
private static StatusRuntimeException toStatusRuntimeException(Throwable t) {
Throwable cause = checkNotNull(t);
while (cause != null) {
// If we have an embedded status, use it and replace the cause
if (cause instanceof StatusException) {
StatusException se = (StatusException) cause;
return new StatusRuntimeException(se.getStatus().withCause(t), se.getTrailers());
} else if (cause instanceof StatusRuntimeException) {
StatusRuntimeException se = (StatusRuntimeException) cause;
return new StatusRuntimeException(se.getStatus().withCause(t), se.getTrailers());
}
cause = cause.getCause();
}
return Status.UNKNOWN.withCause(t).asRuntimeException();
}
private static <ReqT, RespT> void asyncUnaryRequestCall( private static <ReqT, RespT> void asyncUnaryRequestCall(
ClientCall<ReqT, RespT> call, ReqT param, StreamObserver<RespT> responseObserver, ClientCall<ReqT, RespT> call, ReqT param, StreamObserver<RespT> responseObserver,
boolean streamingResponse) { boolean streamingResponse) {

View File

@ -36,6 +36,7 @@ import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -44,6 +45,7 @@ import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.ClientCall; import io.grpc.ClientCall;
import io.grpc.ClientCall.Listener;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
@ -51,6 +53,7 @@ import io.grpc.Server;
import io.grpc.ServerServiceDefinition; import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor; import io.grpc.ServiceDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.ServerCalls.NoopStreamObserver; import io.grpc.stub.ServerCalls.NoopStreamObserver;
@ -63,7 +66,10 @@ import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -81,10 +87,11 @@ import java.util.concurrent.TimeUnit;
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class ClientCallsTest { public class ClientCallsTest {
static final MethodDescriptor<Integer, Integer> STREAMING_METHOD = MethodDescriptor.create( private static final MethodDescriptor<Integer, Integer> STREAMING_METHOD = MethodDescriptor
MethodDescriptor.MethodType.BIDI_STREAMING, .create(
"some/method", MethodDescriptor.MethodType.BIDI_STREAMING,
new IntegerMarshaller(), new IntegerMarshaller()); "some/method",
new IntegerMarshaller(), new IntegerMarshaller());
private Server server; private Server server;
private ManagedChannel channel; private ManagedChannel channel;
@ -107,6 +114,53 @@ public class ClientCallsTest {
} }
} }
@Test
public void unaryBlockingCallSuccess() throws Exception {
Integer req = 2;
final String resp = "bar";
final Status status = Status.OK;
final Metadata trailers = new Metadata();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
@SuppressWarnings("unchecked")
Listener<String> listener = (Listener<String>) in.getArguments()[0];
listener.onMessage(resp);
listener.onClose(status, trailers);
return null;
}
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class));
String actualResponse = ClientCalls.blockingUnaryCall(call, req);
assertEquals(resp, actualResponse);
}
@Test
public void unaryBlockingCallFailed() throws Exception {
Integer req = 2;
final Status status = Status.INTERNAL;
final Metadata trailers = new Metadata();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
@SuppressWarnings("unchecked")
Listener<String> listener = (Listener<String>) in.getArguments()[0];
listener.onClose(status, trailers);
return null;
}
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class));
try {
ClientCalls.blockingUnaryCall(call, req);
fail("Should fail");
} catch (StatusRuntimeException e) {
assertSame(status.getCode(), e.getStatus().getCode());
assertSame(trailers, e.getTrailers());
}
}
@Test @Test
public void unaryFutureCallSuccess() throws Exception { public void unaryFutureCallSuccess() throws Exception {
Integer req = 2; Integer req = 2;
@ -246,24 +300,24 @@ public class ClientCallsTest {
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
ClientResponseObserver<Integer, String> responseObserver = ClientResponseObserver<Integer, String> responseObserver =
new ClientResponseObserver<Integer, String>() { new ClientResponseObserver<Integer, String>() {
@Override @Override
public void beforeStart(ClientCallStreamObserver<Integer> requestStream) { public void beforeStart(ClientCallStreamObserver<Integer> requestStream) {
requestStream.disableAutoInboundFlowControl(); requestStream.disableAutoInboundFlowControl();
requestStream.request(5); requestStream.request(5);
} }
@Override @Override
public void onNext(String value) { public void onNext(String value) {
} }
@Override @Override
public void onError(Throwable t) { public void onError(Throwable t) {
} }
@Override @Override
public void onCompleted() { public void onCompleted() {
} }
}; };
ClientCalls.asyncServerStreamingCall(call, 1, responseObserver); ClientCalls.asyncServerStreamingCall(call, 1, responseObserver);
verify(call).start(listenerCaptor.capture(), any(Metadata.class)); verify(call).start(listenerCaptor.capture(), any(Metadata.class));
listenerCaptor.getValue().onMessage("message"); listenerCaptor.getValue().onMessage("message");
@ -398,6 +452,7 @@ public class ClientCallsTest {
public void beforeStart(final ClientCallStreamObserver<Integer> requestStream) { public void beforeStart(final ClientCallStreamObserver<Integer> requestStream) {
requestStream.setOnReadyHandler(new Runnable() { requestStream.setOnReadyHandler(new Runnable() {
int iteration; int iteration;
@Override @Override
public void run() { public void run() {
while (requestStream.isReady()) { while (requestStream.isReady()) {