mirror of https://github.com/grpc/grpc-java.git
Ensure that trailers are not lost when making blocking calls.
Fixes #2036
This commit is contained in:
parent
c1ef8061d1
commit
780b2696c1
|
|
@ -31,6 +31,8 @@
|
||||||
|
|
||||||
package io.grpc.stub;
|
package io.grpc.stub;
|
||||||
|
|
||||||
|
import static com.google.common.base.Preconditions.checkNotNull;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
import com.google.common.util.concurrent.AbstractFuture;
|
import com.google.common.util.concurrent.AbstractFuture;
|
||||||
import com.google.common.util.concurrent.ListenableFuture;
|
import com.google.common.util.concurrent.ListenableFuture;
|
||||||
|
|
@ -38,9 +40,11 @@ import com.google.common.util.concurrent.ListenableFuture;
|
||||||
import io.grpc.CallOptions;
|
import io.grpc.CallOptions;
|
||||||
import io.grpc.Channel;
|
import io.grpc.Channel;
|
||||||
import io.grpc.ClientCall;
|
import io.grpc.ClientCall;
|
||||||
|
import io.grpc.ExperimentalApi;
|
||||||
import io.grpc.Metadata;
|
import io.grpc.Metadata;
|
||||||
import io.grpc.MethodDescriptor;
|
import io.grpc.MethodDescriptor;
|
||||||
import io.grpc.Status;
|
import io.grpc.Status;
|
||||||
|
import io.grpc.StatusException;
|
||||||
import io.grpc.StatusRuntimeException;
|
import io.grpc.StatusRuntimeException;
|
||||||
|
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
|
|
@ -205,10 +209,32 @@ public class ClientCalls {
|
||||||
Thread.currentThread().interrupt();
|
Thread.currentThread().interrupt();
|
||||||
throw Status.CANCELLED.withCause(e).asRuntimeException();
|
throw Status.CANCELLED.withCause(e).asRuntimeException();
|
||||||
} catch (ExecutionException e) {
|
} catch (ExecutionException e) {
|
||||||
throw Status.fromThrowable(e).asRuntimeException();
|
throw toStatusRuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps the given {@link Throwable} in a {@link StatusRuntimeException}. If it contains an
|
||||||
|
* embedded {@link StatusException} or {@link StatusRuntimeException}, the returned exception will
|
||||||
|
* contain the embedded trailers and status, with the given exception as the cause. Otherwise, an
|
||||||
|
* exception will be generated from an {@link Status#UNKNOWN} status.
|
||||||
|
*/
|
||||||
|
private static StatusRuntimeException toStatusRuntimeException(Throwable t) {
|
||||||
|
Throwable cause = checkNotNull(t);
|
||||||
|
while (cause != null) {
|
||||||
|
// If we have an embedded status, use it and replace the cause
|
||||||
|
if (cause instanceof StatusException) {
|
||||||
|
StatusException se = (StatusException) cause;
|
||||||
|
return new StatusRuntimeException(se.getStatus().withCause(t), se.getTrailers());
|
||||||
|
} else if (cause instanceof StatusRuntimeException) {
|
||||||
|
StatusRuntimeException se = (StatusRuntimeException) cause;
|
||||||
|
return new StatusRuntimeException(se.getStatus().withCause(t), se.getTrailers());
|
||||||
|
}
|
||||||
|
cause = cause.getCause();
|
||||||
|
}
|
||||||
|
return Status.UNKNOWN.withCause(t).asRuntimeException();
|
||||||
|
}
|
||||||
|
|
||||||
private static <ReqT, RespT> void asyncUnaryRequestCall(
|
private static <ReqT, RespT> void asyncUnaryRequestCall(
|
||||||
ClientCall<ReqT, RespT> call, ReqT param, StreamObserver<RespT> responseObserver,
|
ClientCall<ReqT, RespT> call, ReqT param, StreamObserver<RespT> responseObserver,
|
||||||
boolean streamingResponse) {
|
boolean streamingResponse) {
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ import static org.junit.Assert.assertSame;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
import static org.mockito.Matchers.any;
|
import static org.mockito.Matchers.any;
|
||||||
|
import static org.mockito.Mockito.doAnswer;
|
||||||
import static org.mockito.Mockito.times;
|
import static org.mockito.Mockito.times;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
|
|
||||||
|
|
@ -44,6 +45,7 @@ import com.google.common.util.concurrent.SettableFuture;
|
||||||
|
|
||||||
import io.grpc.CallOptions;
|
import io.grpc.CallOptions;
|
||||||
import io.grpc.ClientCall;
|
import io.grpc.ClientCall;
|
||||||
|
import io.grpc.ClientCall.Listener;
|
||||||
import io.grpc.ManagedChannel;
|
import io.grpc.ManagedChannel;
|
||||||
import io.grpc.Metadata;
|
import io.grpc.Metadata;
|
||||||
import io.grpc.MethodDescriptor;
|
import io.grpc.MethodDescriptor;
|
||||||
|
|
@ -51,6 +53,7 @@ import io.grpc.Server;
|
||||||
import io.grpc.ServerServiceDefinition;
|
import io.grpc.ServerServiceDefinition;
|
||||||
import io.grpc.ServiceDescriptor;
|
import io.grpc.ServiceDescriptor;
|
||||||
import io.grpc.Status;
|
import io.grpc.Status;
|
||||||
|
import io.grpc.StatusRuntimeException;
|
||||||
import io.grpc.inprocess.InProcessChannelBuilder;
|
import io.grpc.inprocess.InProcessChannelBuilder;
|
||||||
import io.grpc.inprocess.InProcessServerBuilder;
|
import io.grpc.inprocess.InProcessServerBuilder;
|
||||||
import io.grpc.stub.ServerCalls.NoopStreamObserver;
|
import io.grpc.stub.ServerCalls.NoopStreamObserver;
|
||||||
|
|
@ -63,7 +66,10 @@ import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.JUnit4;
|
import org.junit.runners.JUnit4;
|
||||||
import org.mockito.ArgumentCaptor;
|
import org.mockito.ArgumentCaptor;
|
||||||
import org.mockito.Mock;
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.Mockito;
|
||||||
import org.mockito.MockitoAnnotations;
|
import org.mockito.MockitoAnnotations;
|
||||||
|
import org.mockito.invocation.InvocationOnMock;
|
||||||
|
import org.mockito.stubbing.Answer;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
@ -81,10 +87,11 @@ import java.util.concurrent.TimeUnit;
|
||||||
@RunWith(JUnit4.class)
|
@RunWith(JUnit4.class)
|
||||||
public class ClientCallsTest {
|
public class ClientCallsTest {
|
||||||
|
|
||||||
static final MethodDescriptor<Integer, Integer> STREAMING_METHOD = MethodDescriptor.create(
|
private static final MethodDescriptor<Integer, Integer> STREAMING_METHOD = MethodDescriptor
|
||||||
MethodDescriptor.MethodType.BIDI_STREAMING,
|
.create(
|
||||||
"some/method",
|
MethodDescriptor.MethodType.BIDI_STREAMING,
|
||||||
new IntegerMarshaller(), new IntegerMarshaller());
|
"some/method",
|
||||||
|
new IntegerMarshaller(), new IntegerMarshaller());
|
||||||
|
|
||||||
private Server server;
|
private Server server;
|
||||||
private ManagedChannel channel;
|
private ManagedChannel channel;
|
||||||
|
|
@ -107,6 +114,53 @@ public class ClientCallsTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void unaryBlockingCallSuccess() throws Exception {
|
||||||
|
Integer req = 2;
|
||||||
|
final String resp = "bar";
|
||||||
|
final Status status = Status.OK;
|
||||||
|
final Metadata trailers = new Metadata();
|
||||||
|
|
||||||
|
doAnswer(new Answer<Void>() {
|
||||||
|
@Override
|
||||||
|
public Void answer(InvocationOnMock in) throws Throwable {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Listener<String> listener = (Listener<String>) in.getArguments()[0];
|
||||||
|
listener.onMessage(resp);
|
||||||
|
listener.onClose(status, trailers);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class));
|
||||||
|
|
||||||
|
String actualResponse = ClientCalls.blockingUnaryCall(call, req);
|
||||||
|
assertEquals(resp, actualResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void unaryBlockingCallFailed() throws Exception {
|
||||||
|
Integer req = 2;
|
||||||
|
final Status status = Status.INTERNAL;
|
||||||
|
final Metadata trailers = new Metadata();
|
||||||
|
|
||||||
|
doAnswer(new Answer<Void>() {
|
||||||
|
@Override
|
||||||
|
public Void answer(InvocationOnMock in) throws Throwable {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Listener<String> listener = (Listener<String>) in.getArguments()[0];
|
||||||
|
listener.onClose(status, trailers);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class));
|
||||||
|
|
||||||
|
try {
|
||||||
|
ClientCalls.blockingUnaryCall(call, req);
|
||||||
|
fail("Should fail");
|
||||||
|
} catch (StatusRuntimeException e) {
|
||||||
|
assertSame(status.getCode(), e.getStatus().getCode());
|
||||||
|
assertSame(trailers, e.getTrailers());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void unaryFutureCallSuccess() throws Exception {
|
public void unaryFutureCallSuccess() throws Exception {
|
||||||
Integer req = 2;
|
Integer req = 2;
|
||||||
|
|
@ -246,24 +300,24 @@ public class ClientCallsTest {
|
||||||
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
|
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
|
||||||
ClientResponseObserver<Integer, String> responseObserver =
|
ClientResponseObserver<Integer, String> responseObserver =
|
||||||
new ClientResponseObserver<Integer, String>() {
|
new ClientResponseObserver<Integer, String>() {
|
||||||
@Override
|
@Override
|
||||||
public void beforeStart(ClientCallStreamObserver<Integer> requestStream) {
|
public void beforeStart(ClientCallStreamObserver<Integer> requestStream) {
|
||||||
requestStream.disableAutoInboundFlowControl();
|
requestStream.disableAutoInboundFlowControl();
|
||||||
requestStream.request(5);
|
requestStream.request(5);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onNext(String value) {
|
public void onNext(String value) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onError(Throwable t) {
|
public void onError(Throwable t) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onCompleted() {
|
public void onCompleted() {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ClientCalls.asyncServerStreamingCall(call, 1, responseObserver);
|
ClientCalls.asyncServerStreamingCall(call, 1, responseObserver);
|
||||||
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
|
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
|
||||||
listenerCaptor.getValue().onMessage("message");
|
listenerCaptor.getValue().onMessage("message");
|
||||||
|
|
@ -398,6 +452,7 @@ public class ClientCallsTest {
|
||||||
public void beforeStart(final ClientCallStreamObserver<Integer> requestStream) {
|
public void beforeStart(final ClientCallStreamObserver<Integer> requestStream) {
|
||||||
requestStream.setOnReadyHandler(new Runnable() {
|
requestStream.setOnReadyHandler(new Runnable() {
|
||||||
int iteration;
|
int iteration;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
while (requestStream.isReady()) {
|
while (requestStream.isReady()) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue