mirror of https://github.com/grpc/grpc-java.git
Ensure that trailers are not lost when making blocking calls.
Fixes #2036
This commit is contained in:
parent
c1ef8061d1
commit
780b2696c1
|
|
@ -31,6 +31,8 @@
|
|||
|
||||
package io.grpc.stub;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkNotNull;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.util.concurrent.AbstractFuture;
|
||||
import com.google.common.util.concurrent.ListenableFuture;
|
||||
|
|
@ -38,9 +40,11 @@ import com.google.common.util.concurrent.ListenableFuture;
|
|||
import io.grpc.CallOptions;
|
||||
import io.grpc.Channel;
|
||||
import io.grpc.ClientCall;
|
||||
import io.grpc.ExperimentalApi;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.MethodDescriptor;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusException;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
|
@ -205,10 +209,32 @@ public class ClientCalls {
|
|||
Thread.currentThread().interrupt();
|
||||
throw Status.CANCELLED.withCause(e).asRuntimeException();
|
||||
} catch (ExecutionException e) {
|
||||
throw Status.fromThrowable(e).asRuntimeException();
|
||||
throw toStatusRuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps the given {@link Throwable} in a {@link StatusRuntimeException}. If it contains an
|
||||
* embedded {@link StatusException} or {@link StatusRuntimeException}, the returned exception will
|
||||
* contain the embedded trailers and status, with the given exception as the cause. Otherwise, an
|
||||
* exception will be generated from an {@link Status#UNKNOWN} status.
|
||||
*/
|
||||
private static StatusRuntimeException toStatusRuntimeException(Throwable t) {
|
||||
Throwable cause = checkNotNull(t);
|
||||
while (cause != null) {
|
||||
// If we have an embedded status, use it and replace the cause
|
||||
if (cause instanceof StatusException) {
|
||||
StatusException se = (StatusException) cause;
|
||||
return new StatusRuntimeException(se.getStatus().withCause(t), se.getTrailers());
|
||||
} else if (cause instanceof StatusRuntimeException) {
|
||||
StatusRuntimeException se = (StatusRuntimeException) cause;
|
||||
return new StatusRuntimeException(se.getStatus().withCause(t), se.getTrailers());
|
||||
}
|
||||
cause = cause.getCause();
|
||||
}
|
||||
return Status.UNKNOWN.withCause(t).asRuntimeException();
|
||||
}
|
||||
|
||||
private static <ReqT, RespT> void asyncUnaryRequestCall(
|
||||
ClientCall<ReqT, RespT> call, ReqT param, StreamObserver<RespT> responseObserver,
|
||||
boolean streamingResponse) {
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ import static org.junit.Assert.assertSame;
|
|||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
|
|
@ -44,6 +45,7 @@ import com.google.common.util.concurrent.SettableFuture;
|
|||
|
||||
import io.grpc.CallOptions;
|
||||
import io.grpc.ClientCall;
|
||||
import io.grpc.ClientCall.Listener;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.MethodDescriptor;
|
||||
|
|
@ -51,6 +53,7 @@ import io.grpc.Server;
|
|||
import io.grpc.ServerServiceDefinition;
|
||||
import io.grpc.ServiceDescriptor;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import io.grpc.inprocess.InProcessChannelBuilder;
|
||||
import io.grpc.inprocess.InProcessServerBuilder;
|
||||
import io.grpc.stub.ServerCalls.NoopStreamObserver;
|
||||
|
|
@ -63,7 +66,10 @@ import org.junit.runner.RunWith;
|
|||
import org.junit.runners.JUnit4;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.Mockito;
|
||||
import org.mockito.MockitoAnnotations;
|
||||
import org.mockito.invocation.InvocationOnMock;
|
||||
import org.mockito.stubbing.Answer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
|
@ -81,10 +87,11 @@ import java.util.concurrent.TimeUnit;
|
|||
@RunWith(JUnit4.class)
|
||||
public class ClientCallsTest {
|
||||
|
||||
static final MethodDescriptor<Integer, Integer> STREAMING_METHOD = MethodDescriptor.create(
|
||||
MethodDescriptor.MethodType.BIDI_STREAMING,
|
||||
"some/method",
|
||||
new IntegerMarshaller(), new IntegerMarshaller());
|
||||
private static final MethodDescriptor<Integer, Integer> STREAMING_METHOD = MethodDescriptor
|
||||
.create(
|
||||
MethodDescriptor.MethodType.BIDI_STREAMING,
|
||||
"some/method",
|
||||
new IntegerMarshaller(), new IntegerMarshaller());
|
||||
|
||||
private Server server;
|
||||
private ManagedChannel channel;
|
||||
|
|
@ -107,6 +114,53 @@ public class ClientCallsTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void unaryBlockingCallSuccess() throws Exception {
|
||||
Integer req = 2;
|
||||
final String resp = "bar";
|
||||
final Status status = Status.OK;
|
||||
final Metadata trailers = new Metadata();
|
||||
|
||||
doAnswer(new Answer<Void>() {
|
||||
@Override
|
||||
public Void answer(InvocationOnMock in) throws Throwable {
|
||||
@SuppressWarnings("unchecked")
|
||||
Listener<String> listener = (Listener<String>) in.getArguments()[0];
|
||||
listener.onMessage(resp);
|
||||
listener.onClose(status, trailers);
|
||||
return null;
|
||||
}
|
||||
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class));
|
||||
|
||||
String actualResponse = ClientCalls.blockingUnaryCall(call, req);
|
||||
assertEquals(resp, actualResponse);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void unaryBlockingCallFailed() throws Exception {
|
||||
Integer req = 2;
|
||||
final Status status = Status.INTERNAL;
|
||||
final Metadata trailers = new Metadata();
|
||||
|
||||
doAnswer(new Answer<Void>() {
|
||||
@Override
|
||||
public Void answer(InvocationOnMock in) throws Throwable {
|
||||
@SuppressWarnings("unchecked")
|
||||
Listener<String> listener = (Listener<String>) in.getArguments()[0];
|
||||
listener.onClose(status, trailers);
|
||||
return null;
|
||||
}
|
||||
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class));
|
||||
|
||||
try {
|
||||
ClientCalls.blockingUnaryCall(call, req);
|
||||
fail("Should fail");
|
||||
} catch (StatusRuntimeException e) {
|
||||
assertSame(status.getCode(), e.getStatus().getCode());
|
||||
assertSame(trailers, e.getTrailers());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void unaryFutureCallSuccess() throws Exception {
|
||||
Integer req = 2;
|
||||
|
|
@ -246,24 +300,24 @@ public class ClientCallsTest {
|
|||
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
|
||||
ClientResponseObserver<Integer, String> responseObserver =
|
||||
new ClientResponseObserver<Integer, String>() {
|
||||
@Override
|
||||
public void beforeStart(ClientCallStreamObserver<Integer> requestStream) {
|
||||
requestStream.disableAutoInboundFlowControl();
|
||||
requestStream.request(5);
|
||||
}
|
||||
@Override
|
||||
public void beforeStart(ClientCallStreamObserver<Integer> requestStream) {
|
||||
requestStream.disableAutoInboundFlowControl();
|
||||
requestStream.request(5);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(String value) {
|
||||
}
|
||||
@Override
|
||||
public void onNext(String value) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable t) {
|
||||
}
|
||||
@Override
|
||||
public void onError(Throwable t) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCompleted() {
|
||||
}
|
||||
};
|
||||
@Override
|
||||
public void onCompleted() {
|
||||
}
|
||||
};
|
||||
ClientCalls.asyncServerStreamingCall(call, 1, responseObserver);
|
||||
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
|
||||
listenerCaptor.getValue().onMessage("message");
|
||||
|
|
@ -398,6 +452,7 @@ public class ClientCallsTest {
|
|||
public void beforeStart(final ClientCallStreamObserver<Integer> requestStream) {
|
||||
requestStream.setOnReadyHandler(new Runnable() {
|
||||
int iteration;
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
while (requestStream.isReady()) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue