Enforce request deadline

Use a ScheduledExecutorService in the ChannelImpl to terminate the
request by closing the ClientStream with status DEADLINE_EXCEEDED
This commit is contained in:
Jack Coughlin 2015-07-07 08:44:01 -07:00 committed by Eric Anderson
parent ac9db3b157
commit 3e26b993ce
14 changed files with 275 additions and 51 deletions

View File

@ -48,6 +48,10 @@ import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.logging.Logger; import java.util.logging.Logger;
@ -66,7 +70,7 @@ public final class ChannelImpl extends Channel {
@Override public void flush() {} @Override public void flush() {}
@Override public void cancel() {} @Override public void cancel(Status reason) {}
@Override public void halfClose() {} @Override public void halfClose() {}
@ -85,6 +89,10 @@ public final class ChannelImpl extends Channel {
private final ExecutorService executor; private final ExecutorService executor;
private final String userAgent; private final String userAgent;
/**
* Executor that runs deadline timers for requests.
*/
private ScheduledExecutorService deadlineCancellationExecutor;
/** /**
* All transports that are not stopped. At the very least {@link #activeTransport} will be * All transports that are not stopped. At the very least {@link #activeTransport} will be
* present, but previously used transports that still have streams or are stopping may also be * present, but previously used transports that still have streams or are stopping may also be
@ -108,6 +116,7 @@ public final class ChannelImpl extends Channel {
this.transportFactory = transportFactory; this.transportFactory = transportFactory;
this.executor = executor; this.executor = executor;
this.userAgent = userAgent; this.userAgent = userAgent;
deadlineCancellationExecutor = SharedResourceHolder.get(TIMER_SERVICE);
} }
/** Hack to allow executors to auto-shutdown. Not for general use. */ /** Hack to allow executors to auto-shutdown. Not for general use. */
@ -125,6 +134,9 @@ public final class ChannelImpl extends Channel {
return this; return this;
} }
shutdown = true; shutdown = true;
// After shutdown there are no new calls, so no new cancellation tasks are needed
deadlineCancellationExecutor =
SharedResourceHolder.release(TIMER_SERVICE, deadlineCancellationExecutor);
if (activeTransport != null) { if (activeTransport != null) {
activeTransport.shutdown(); activeTransport.shutdown();
activeTransport = null; activeTransport = null;
@ -300,6 +312,7 @@ public final class ChannelImpl extends Channel {
private final boolean unaryRequest; private final boolean unaryRequest;
private final CallOptions callOptions; private final CallOptions callOptions;
private ClientStream stream; private ClientStream stream;
private volatile ScheduledFuture<?> deadlineCancellationFuture;
public CallImpl(MethodDescriptor<ReqT, RespT> method, SerializingExecutor executor, public CallImpl(MethodDescriptor<ReqT, RespT> method, SerializingExecutor executor,
CallOptions callOptions) { CallOptions callOptions) {
@ -331,8 +344,9 @@ public final class ChannelImpl extends Channel {
// Convert the deadline to timeout. Timeout is more favorable than deadline on the wire // Convert the deadline to timeout. Timeout is more favorable than deadline on the wire
// because timeout tolerates the clock difference between machines. // because timeout tolerates the clock difference between machines.
Long deadlineNanoTime = callOptions.getDeadlineNanoTime(); Long deadlineNanoTime = callOptions.getDeadlineNanoTime();
long timeoutMicros = 0;
if (deadlineNanoTime != null) { if (deadlineNanoTime != null) {
long timeoutMicros = TimeUnit.NANOSECONDS.toMicros(deadlineNanoTime - System.nanoTime()); timeoutMicros = TimeUnit.NANOSECONDS.toMicros(deadlineNanoTime - System.nanoTime());
if (timeoutMicros <= 0) { if (timeoutMicros <= 0) {
closeCallPrematurely(listener, Status.DEADLINE_EXCEEDED); closeCallPrematurely(listener, Status.DEADLINE_EXCEEDED);
return; return;
@ -353,6 +367,10 @@ public final class ChannelImpl extends Channel {
// TODO(ejona86): Improve the API to remove the possibility of the race. // TODO(ejona86): Improve the API to remove the possibility of the race.
closeCallPrematurely(listener, Status.fromThrowable(ex)); closeCallPrematurely(listener, Status.fromThrowable(ex));
} }
// Start the deadline timer after stream creation because it will close the stream
if (deadlineNanoTime != null) {
deadlineCancellationFuture = startDeadlineTimer(timeoutMicros);
}
} }
@Override @Override
@ -366,7 +384,7 @@ public final class ChannelImpl extends Channel {
// Cancel is called in exception handling cases, so it may be the case that the // Cancel is called in exception handling cases, so it may be the case that the
// stream was never successfully created. // stream was never successfully created.
if (stream != null) { if (stream != null) {
stream.cancel(); stream.cancel(Status.CANCELLED);
} }
} }
@ -411,6 +429,15 @@ public final class ChannelImpl extends Channel {
listener.closed(status, new Metadata.Trailers()); listener.closed(status, new Metadata.Trailers());
} }
private ScheduledFuture<?> startDeadlineTimer(long timeoutMicros) {
return deadlineCancellationExecutor.schedule(new Runnable() {
@Override
public void run() {
stream.cancel(Status.DEADLINE_EXCEEDED);
}
}, timeoutMicros, TimeUnit.MICROSECONDS);
}
private class ClientStreamListenerImpl implements ClientStreamListener { private class ClientStreamListenerImpl implements ClientStreamListener {
private final Listener<RespT> observer; private final Listener<RespT> observer;
private boolean closed; private boolean closed;
@ -468,6 +495,11 @@ public final class ChannelImpl extends Channel {
@Override @Override
public void run() { public void run() {
closed = true; closed = true;
// manually optimize the volatile read
ScheduledFuture<?> future = deadlineCancellationFuture;
if (future != null) {
future.cancel(false);
}
observer.onClose(status, trailers); observer.onClose(status, trailers);
} }
}); });
@ -561,4 +593,24 @@ public final class ChannelImpl extends Channel {
return Long.parseLong(valuePart) * factor; return Long.parseLong(valuePart) * factor;
} }
} }
private static final SharedResourceHolder.Resource<ScheduledExecutorService> TIMER_SERVICE =
new SharedResourceHolder.Resource<ScheduledExecutorService>() {
@Override
public ScheduledExecutorService create() {
return Executors.newSingleThreadScheduledExecutor(new ThreadFactory() {
@Override
public Thread newThread(Runnable r) {
Thread thread = new Thread(r);
thread.setDaemon(true);
return thread;
}
});
}
@Override
public void close(ScheduledExecutorService instance) {
instance.shutdown();
}
};
} }

View File

@ -31,6 +31,9 @@
package io.grpc.transport; package io.grpc.transport;
import static io.grpc.Status.Code.CANCELLED;
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
import com.google.common.base.Objects; import com.google.common.base.Objects;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
@ -38,6 +41,7 @@ import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import java.io.InputStream; import java.io.InputStream;
import java.util.EnumSet;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@ -153,7 +157,7 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
@Override @Override
protected final void deframeFailed(Throwable cause) { protected final void deframeFailed(Throwable cause) {
log.log(Level.WARNING, "Exception processing message", cause); log.log(Level.WARNING, "Exception processing message", cause);
cancel(); cancel(Status.CANCELLED);
} }
/** /**
@ -278,9 +282,11 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
* Cancel the stream. Called by the application layer, never called by the transport. * Cancel the stream. Called by the application layer, never called by the transport.
*/ */
@Override @Override
public void cancel() { public void cancel(Status reason) {
Preconditions.checkArgument(EnumSet.of(CANCELLED, DEADLINE_EXCEEDED).contains(reason.getCode()),
"Invalid cancellation reason");
outboundPhase(Phase.STATUS); outboundPhase(Phase.STATUS);
sendCancel(); sendCancel(reason);
dispose(); dispose();
} }
@ -289,7 +295,7 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
* Can be called by either the application or transport layers. This method is safe to be called * Can be called by either the application or transport layers. This method is safe to be called
* at any time and multiple times. * at any time and multiple times.
*/ */
protected abstract void sendCancel(); protected abstract void sendCancel(Status reason);
// We support Guava 14 // We support Guava 14
@SuppressWarnings("deprecation") @SuppressWarnings("deprecation")

View File

@ -31,6 +31,8 @@
package io.grpc.transport; package io.grpc.transport;
import io.grpc.Status;
/** /**
* Extension of {@link Stream} to support client-side termination semantics. * Extension of {@link Stream} to support client-side termination semantics.
*/ */
@ -41,8 +43,10 @@ public interface ClientStream extends Stream {
* sent or received, however it may still be possible to receive buffered messages for a brief * sent or received, however it may still be possible to receive buffered messages for a brief
* period until {@link ClientStreamListener#closed} is called. This method is safe to be called * period until {@link ClientStreamListener#closed} is called. This method is safe to be called
* at any time and multiple times. * at any time and multiple times.
*
* @param reason must be one of Status.CANCELLED or Status.DEADLINE_EXCEEDED
*/ */
void cancel(); void cancel(Status reason);
/** /**
* Closes the local side of this stream and flushes any remaining messages. After this is called, * Closes the local side of this stream and flushes any remaining messages. After this is called,

View File

@ -128,7 +128,7 @@ public abstract class Http2ClientStream extends AbstractClientStream<Integer> {
if (transportError.getDescription().length() > 1000 || endOfStream) { if (transportError.getDescription().length() > 1000 || endOfStream) {
inboundTransportError(transportError); inboundTransportError(transportError);
// We have enough error detail so lets cancel. // We have enough error detail so lets cancel.
sendCancel(); sendCancel(Status.CANCELLED);
} }
} else { } else {
inboundDataReceived(frame); inboundDataReceived(frame);
@ -155,7 +155,7 @@ public abstract class Http2ClientStream extends AbstractClientStream<Integer> {
} }
if (transportError != null) { if (transportError != null) {
inboundTransportError(transportError); inboundTransportError(transportError);
sendCancel(); sendCancel(Status.CANCELLED);
} else { } else {
Status status = statusFromTrailers(trailers); Status status = statusFromTrailers(trailers);
stripTransportDetails(trailers); stripTransportDetails(trailers);

View File

@ -35,8 +35,10 @@ import static io.grpc.testing.integration.Messages.PayloadType.COMPRESSABLE;
import static io.grpc.testing.integration.Util.assertEquals; import static io.grpc.testing.integration.Util.assertEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -86,6 +88,7 @@ import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
/** /**
@ -97,6 +100,7 @@ public abstract class AbstractTransportTest {
ProtoUtils.keyForProto(Messages.SimpleContext.getDefaultInstance()); ProtoUtils.keyForProto(Messages.SimpleContext.getDefaultInstance());
private static final AtomicReference<Metadata.Headers> requestHeadersCapture = private static final AtomicReference<Metadata.Headers> requestHeadersCapture =
new AtomicReference<Metadata.Headers>(); new AtomicReference<Metadata.Headers>();
private static final AtomicLong serverDelayMillis = new AtomicLong(0);
private static ScheduledExecutorService testServiceExecutor; private static ScheduledExecutorService testServiceExecutor;
private static ServerImpl server; private static ServerImpl server;
private static int OPERATION_TIMEOUT = 5000; private static int OPERATION_TIMEOUT = 5000;
@ -106,6 +110,7 @@ public abstract class AbstractTransportTest {
builder.addService(ServerInterceptors.intercept( builder.addService(ServerInterceptors.intercept(
TestServiceGrpc.bindService(new TestServiceImpl(testServiceExecutor)), TestServiceGrpc.bindService(new TestServiceImpl(testServiceExecutor)),
TestUtils.delayServerResponseInterceptor(serverDelayMillis),
TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture), TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture),
TestUtils.echoRequestHeadersInterceptor(Util.METADATA_KEY))); TestUtils.echoRequestHeadersInterceptor(Util.METADATA_KEY)));
try { try {
@ -133,6 +138,7 @@ public abstract class AbstractTransportTest {
blockingStub = TestServiceGrpc.newBlockingStub(channel); blockingStub = TestServiceGrpc.newBlockingStub(channel);
asyncStub = TestServiceGrpc.newStub(channel); asyncStub = TestServiceGrpc.newStub(channel);
requestHeadersCapture.set(null); requestHeadersCapture.set(null);
serverDelayMillis.set(0);
} }
/** Clean up. */ /** Clean up. */
@ -595,6 +601,60 @@ public abstract class AbstractTransportTest {
&& configuredTimeoutMinutes - transferredTimeoutMinutes <= 1); && configuredTimeoutMinutes - transferredTimeoutMinutes <= 1);
} }
@Test
public void deadlineNotExceeded() {
serverDelayMillis.set(0);
// warm up the channel and JVM
blockingStub.emptyCall(Empty.getDefaultInstance());
TestServiceGrpc.newBlockingStub(channel)
.configureNewStub()
.setDeadlineAfter(50, TimeUnit.MILLISECONDS)
.build().emptyCall(Empty.getDefaultInstance());
}
@Test(timeout = 10000)
public void deadlineExceeded() {
serverDelayMillis.set(20);
// warm up the channel and JVM
blockingStub.emptyCall(Empty.getDefaultInstance());
TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel)
.configureNewStub()
.setDeadlineAfter(10, TimeUnit.MILLISECONDS)
.build();
try {
stub.emptyCall(Empty.getDefaultInstance());
fail("Expected deadline to be exceeded");
} catch (Throwable t) {
assertEquals(Status.DEADLINE_EXCEEDED, Status.fromThrowable(t));
}
}
@Test(timeout = 10000)
public void deadlineExceededServerStreaming() throws Exception {
serverDelayMillis.set(10); // applied to every message
// warm up the channel and JVM
blockingStub.emptyCall(Empty.getDefaultInstance());
StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder()
.setResponseType(PayloadType.COMPRESSABLE)
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(1))
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(1))
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(1))
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(1))
.build();
StreamRecorder<StreamingOutputCallResponse> recorder = StreamRecorder.create();
TestServiceGrpc.newStub(channel)
.configureNewStub()
.setDeadlineAfter(30, TimeUnit.MILLISECONDS)
.build().streamingOutputCall(request, recorder);
recorder.awaitCompletion();
assertEquals(Status.DEADLINE_EXCEEDED, Status.fromThrowable(recorder.getError()));
assertNotEquals(0, recorder.getValues().size());
}
protected int unaryPayloadLength() { protected int unaryPayloadLength() {
// 10MiB. // 10MiB.
return 10485760; return 10485760;

View File

@ -31,19 +31,35 @@
package io.grpc.transport.netty; package io.grpc.transport.netty;
import static io.grpc.Status.Code.CANCELLED;
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.Status;
import java.util.EnumSet;
/** /**
* Command sent from a Netty client stream to the handler to cancel the stream. * Command sent from a Netty client stream to the handler to cancel the stream.
*/ */
class CancelStreamCommand { class CancelStreamCommand {
private final NettyClientStream stream; private final NettyClientStream stream;
private final Status reason;
CancelStreamCommand(NettyClientStream stream) { CancelStreamCommand(NettyClientStream stream, Status reason) {
this.stream = Preconditions.checkNotNull(stream, "stream"); this.stream = Preconditions.checkNotNull(stream, "stream");
Preconditions.checkNotNull(reason);
Preconditions.checkArgument(EnumSet.of(CANCELLED, DEADLINE_EXCEEDED).contains(reason.getCode()),
"Invalid cancellation reason");
this.reason = reason;
} }
NettyClientStream stream() { NettyClientStream stream() {
return stream; return stream;
} }
Status reason() {
return reason;
}
} }

View File

@ -318,7 +318,7 @@ class NettyClientHandler extends Http2ConnectionHandler {
private void cancelStream(ChannelHandlerContext ctx, CancelStreamCommand cmd, private void cancelStream(ChannelHandlerContext ctx, CancelStreamCommand cmd,
ChannelPromise promise) { ChannelPromise promise) {
NettyClientStream stream = cmd.stream(); NettyClientStream stream = cmd.stream();
stream.transportReportStatus(Status.CANCELLED, true, new Metadata.Trailers()); stream.transportReportStatus(cmd.reason(), true, new Metadata.Trailers());
encoder().writeRstStream(ctx, stream.id(), Http2Error.CANCEL.code(), promise); encoder().writeRstStream(ctx, stream.id(), Http2Error.CANCEL.code(), promise);
} }

View File

@ -35,6 +35,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import static io.netty.buffer.Unpooled.EMPTY_BUFFER; import static io.netty.buffer.Unpooled.EMPTY_BUFFER;
import io.grpc.Status;
import io.grpc.transport.ClientStreamListener; import io.grpc.transport.ClientStreamListener;
import io.grpc.transport.Http2ClientStream; import io.grpc.transport.Http2ClientStream;
import io.grpc.transport.WritableBuffer; import io.grpc.transport.WritableBuffer;
@ -118,9 +119,9 @@ class NettyClientStream extends Http2ClientStream {
} }
@Override @Override
protected void sendCancel() { protected void sendCancel(Status reason) {
// Send the cancel command to the handler. // Send the cancel command to the handler.
writeQueue.enqueue(new CancelStreamCommand(this), true); writeQueue.enqueue(new CancelStreamCommand(this, reason), true);
} }
@Override @Override

View File

@ -181,7 +181,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
verify(stream).id(eq(3)); verify(stream).id(eq(3));
when(stream.id()).thenReturn(3); when(stream.id()).thenReturn(3);
// Cancel the stream. // Cancel the stream.
writeQueue.enqueue(new CancelStreamCommand(stream), true); writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), true);
assertTrue(createPromise.isSuccess()); assertTrue(createPromise.isSuccess());
verify(stream).transportReportStatus(eq(Status.CANCELLED), eq(true), verify(stream).transportReportStatus(eq(Status.CANCELLED), eq(true),
@ -216,7 +216,18 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
public void cancelShouldSucceed() throws Exception { public void cancelShouldSucceed() throws Exception {
createStream(); createStream();
verify(channel, times(1)).flush(); verify(channel, times(1)).flush();
writeQueue.enqueue(new CancelStreamCommand(stream), true); writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), true);
ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code());
verify(ctx).write(eq(expected), eq(promise));
verify(channel, times(2)).flush();
}
@Test
public void cancelDeadlineExceededShouldSucceed() throws Exception {
createStream();
verify(channel, times(1)).flush();
writeQueue.enqueue(new CancelStreamCommand(stream, Status.DEADLINE_EXCEEDED), true);
ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code()); ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code());
verify(ctx).write(eq(expected), eq(promise)); verify(ctx).write(eq(expected), eq(promise));
@ -233,7 +244,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
verify(stream).id(idCaptor.capture()); verify(stream).id(idCaptor.capture());
when(stream.id()).thenReturn(idCaptor.getValue()); when(stream.id()).thenReturn(idCaptor.getValue());
ChannelPromise cancelPromise = mock(ChannelPromise.class); ChannelPromise cancelPromise = mock(ChannelPromise.class);
writeQueue.enqueue(new CancelStreamCommand(stream), cancelPromise, true); writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), cancelPromise, true);
verify(cancelPromise).setSuccess(); verify(cancelPromise).setSuccess();
verify(channel, times(2)).flush(); verify(channel, times(2)).flush();
verifyNoMoreInteractions(ctx); verifyNoMoreInteractions(ctx);
@ -248,14 +259,29 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
public void cancelTwiceShouldSucceed() throws Exception { public void cancelTwiceShouldSucceed() throws Exception {
createStream(); createStream();
writeQueue.enqueue(new CancelStreamCommand(stream), promise, true); writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), promise, true);
ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code()); ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code());
verify(ctx).write(eq(expected), any(ChannelPromise.class)); verify(ctx).write(eq(expected), any(ChannelPromise.class));
promise = mock(ChannelPromise.class); promise = mock(ChannelPromise.class);
writeQueue.enqueue(new CancelStreamCommand(stream), promise, true); writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), promise, true);
verify(promise).setSuccess();
}
@Test
public void cancelTwiceDifferentReasons() throws Exception {
createStream();
writeQueue.enqueue(new CancelStreamCommand(stream, Status.DEADLINE_EXCEEDED), promise, true);
ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code());
verify(ctx).write(eq(expected), any(ChannelPromise.class));
promise = mock(ChannelPromise.class);
writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), promise, true);
verify(promise).setSuccess(); verify(promise).setSuccess();
} }
@ -357,7 +383,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
writeQueue.enqueue(new CreateStreamCommand(grpcHeaders, stream), true); writeQueue.enqueue(new CreateStreamCommand(grpcHeaders, stream), true);
verify(stream).id(3); verify(stream).id(3);
when(stream.id()).thenReturn(3); when(stream.id()).thenReturn(3);
writeQueue.enqueue(new CancelStreamCommand(stream), true); writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), true);
verify(stream).transportReportStatus(eq(Status.CANCELLED), eq(true), verify(stream).transportReportStatus(eq(Status.CANCELLED), eq(true),
any(Metadata.Trailers.class)); any(Metadata.Trailers.class));
} }

View File

@ -42,6 +42,7 @@ import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -102,14 +103,28 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
public void cancelShouldSendCommand() { public void cancelShouldSendCommand() {
// Set stream id to indicate it has been created // Set stream id to indicate it has been created
stream().id(STREAM_ID); stream().id(STREAM_ID);
stream().cancel(); stream().cancel(Status.CANCELLED);
verify(writeQueue).enqueue(any(CancelStreamCommand.class), eq(true)); ArgumentCaptor<CancelStreamCommand> commandCaptor =
ArgumentCaptor.forClass(CancelStreamCommand.class);
verify(writeQueue).enqueue(commandCaptor.capture(), eq(true));
assertEquals(commandCaptor.getValue().reason(), Status.CANCELLED);
}
@Test
public void deadlineExceededCancelShouldSendCommand() {
// Set stream id to indicate it has been created
stream().id(STREAM_ID);
stream().cancel(Status.DEADLINE_EXCEEDED);
ArgumentCaptor<CancelStreamCommand> commandCaptor =
ArgumentCaptor.forClass(CancelStreamCommand.class);
verify(writeQueue).enqueue(commandCaptor.capture(), eq(true));
assertEquals(commandCaptor.getValue().reason(), Status.DEADLINE_EXCEEDED);
} }
@Test @Test
public void cancelShouldStillSendCommandIfStreamNotCreatedToCancelCreation() { public void cancelShouldStillSendCommandIfStreamNotCreatedToCancelCreation() {
stream().cancel(); stream().cancel(Status.CANCELLED);
verify(writeQueue).enqueue(any(CancelStreamCommand.class), eq(true)); verify(writeQueue).enqueue(isA(CancelStreamCommand.class), eq(true));
} }
@Test @Test
@ -340,7 +355,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
@Override @Override
protected void closeStream() { protected void closeStream() {
stream().cancel(); stream().cancel(Status.CANCELLED);
} }
private ByteBuf simpleGrpcFrame() { private ByteBuf simpleGrpcFrame() {

View File

@ -209,8 +209,8 @@ class OkHttpClientStream extends Http2ClientStream {
} }
@Override @Override
protected void sendCancel() { protected void sendCancel(Status reason) {
transport.finishStream(id(), Status.CANCELLED, ErrorCode.CANCEL); transport.finishStream(id(), reason, ErrorCode.CANCEL);
} }
@Override @Override

View File

@ -270,10 +270,10 @@ class OkHttpClientTransport implements ClientTransport {
} catch (InterruptedException e) { } catch (InterruptedException e) {
// Restore the interrupt. // Restore the interrupt.
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
clientStream.cancel(); clientStream.cancel(Status.CANCELLED);
throw new RuntimeException(e); throw new RuntimeException(e);
} catch (ExecutionException e) { } catch (ExecutionException e) {
clientStream.cancel(); clientStream.cancel(Status.CANCELLED);
throw new RuntimeException(e.getCause() != null ? e.getCause() : e); throw new RuntimeException(e.getCause() != null ? e.getCause() : e);
} }
} }
@ -457,7 +457,8 @@ class OkHttpClientTransport implements ClientTransport {
frameWriter.rstStream(streamId, ErrorCode.CANCEL); frameWriter.rstStream(streamId, ErrorCode.CANCEL);
} }
if (status != null) { if (status != null) {
boolean isCancelled = status.getCode() == Code.CANCELLED; boolean isCancelled = (status.getCode() == Code.CANCELLED
|| status.getCode() == Code.DEADLINE_EXCEEDED);
stream.transportReportStatus(status, isCancelled, new Metadata.Trailers()); stream.transportReportStatus(status, isCancelled, new Metadata.Trailers());
} }
if (!startPendingStreams()) { if (!startPendingStreams()) {

View File

@ -268,7 +268,7 @@ public class OkHttpClientTransportTest {
clientTransport.newStream(method, new Metadata.Headers(), listener); clientTransport.newStream(method, new Metadata.Headers(), listener);
OkHttpClientStream stream = streams.get(3); OkHttpClientStream stream = streams.get(3);
assertNotNull(stream); assertNotNull(stream);
stream.cancel(); stream.cancel(Status.CANCELLED);
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
listener.waitUntilStreamClosed(); listener.waitUntilStreamClosed();
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(),
@ -286,7 +286,7 @@ public class OkHttpClientTransportTest {
new Header(Header.TARGET_PATH, "/fakemethod"), new Header(Header.TARGET_PATH, "/fakemethod"),
userAgentHeader, CONTENT_TYPE_HEADER, TE_HEADER); userAgentHeader, CONTENT_TYPE_HEADER, TE_HEADER);
verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders)); verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders));
streams.get(3).cancel(); streams.get(3).cancel(Status.CANCELLED);
} }
@Test @Test
@ -303,7 +303,18 @@ public class OkHttpClientTransportTest {
HttpUtil.getGrpcUserAgent("okhttp", userAgent)), HttpUtil.getGrpcUserAgent("okhttp", userAgent)),
CONTENT_TYPE_HEADER, TE_HEADER); CONTENT_TYPE_HEADER, TE_HEADER);
verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders)); verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders));
streams.get(3).cancel(); streams.get(3).cancel(Status.CANCELLED);
}
@Test
public void cancelStreamForDeadlineExceeded() throws Exception {
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, new Metadata.Headers(), listener);
OkHttpClientStream stream = streams.get(3);
assertNotNull(stream);
stream.cancel(Status.DEADLINE_EXCEEDED);
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
listener.waitUntilStreamClosed();
} }
@Test @Test
@ -320,7 +331,7 @@ public class OkHttpClientTransportTest {
verify(frameWriter).data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH)); verify(frameWriter).data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH));
Buffer sentFrame = captor.getValue(); Buffer sentFrame = captor.getValue();
assertEquals(createMessageFrame(message), sentFrame); assertEquals(createMessageFrame(message), sentFrame);
stream.cancel(); stream.cancel(Status.CANCELLED);
} }
@Test @Test
@ -364,13 +375,13 @@ public class OkHttpClientTransportTest {
verify(frameWriter).windowUpdate(eq(5), eq((long) 2 * messageFrameLength)); verify(frameWriter).windowUpdate(eq(5), eq((long) 2 * messageFrameLength));
verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength)); verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength));
stream1.cancel(); stream1.cancel(Status.CANCELLED);
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
listener1.waitUntilStreamClosed(); listener1.waitUntilStreamClosed();
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(),
listener1.status.getCode()); listener1.status.getCode());
stream2.cancel(); stream2.cancel(Status.CANCELLED);
verify(frameWriter).rstStream(eq(5), eq(ErrorCode.CANCEL)); verify(frameWriter).rstStream(eq(5), eq(ErrorCode.CANCEL));
listener2.waitUntilStreamClosed(); listener2.waitUntilStreamClosed();
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(),
@ -394,7 +405,7 @@ public class OkHttpClientTransportTest {
// We return the bytes for the stream window as we read the message. // We return the bytes for the stream window as we read the message.
verify(frameWriter).windowUpdate(eq(3), eq(messageFrameLength)); verify(frameWriter).windowUpdate(eq(3), eq(messageFrameLength));
stream.cancel(); stream.cancel(Status.CANCELLED);
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
listener.waitUntilStreamClosed(); listener.waitUntilStreamClosed();
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(),
@ -429,7 +440,7 @@ public class OkHttpClientTransportTest {
verify(frameWriter).data( verify(frameWriter).data(
eq(false), eq(3), any(Buffer.class), eq(messageLength + HEADER_LENGTH - partiallySentSize)); eq(false), eq(3), any(Buffer.class), eq(messageLength + HEADER_LENGTH - partiallySentSize));
stream.cancel(); stream.cancel(Status.CANCELLED);
listener.waitUntilStreamClosed(); listener.waitUntilStreamClosed();
} }
@ -468,7 +479,7 @@ public class OkHttpClientTransportTest {
frameHandler.windowUpdate(3, HEADER_LENGTH + 20); frameHandler.windowUpdate(3, HEADER_LENGTH + 20);
verify(frameWriter).data(eq(false), eq(3), any(Buffer.class), eq(HEADER_LENGTH + 20)); verify(frameWriter).data(eq(false), eq(3), any(Buffer.class), eq(HEADER_LENGTH + 20));
stream.cancel(); stream.cancel(Status.CANCELLED);
listener.waitUntilStreamClosed(); listener.waitUntilStreamClosed();
} }
@ -486,8 +497,8 @@ public class OkHttpClientTransportTest {
assertEquals(2, streams.size()); assertEquals(2, streams.size());
verify(transportListener).transportShutdown(); verify(transportListener).transportShutdown();
stream1.cancel(); stream1.cancel(Status.CANCELLED);
stream2.cancel(); stream2.cancel(Status.CANCELLED);
listener1.waitUntilStreamClosed(); listener1.waitUntilStreamClosed();
listener2.waitUntilStreamClosed(); listener2.waitUntilStreamClosed();
assertEquals(0, streams.size()); assertEquals(0, streams.size());
@ -563,7 +574,7 @@ public class OkHttpClientTransportTest {
assertNewStreamFail(transport); assertNewStreamFail(transport);
streams.get(startId).cancel(); streams.get(startId).cancel(Status.CANCELLED);
listener1.waitUntilStreamClosed(); listener1.waitUntilStreamClosed();
verify(writer).rstStream(eq(startId), eq(ErrorCode.CANCEL)); verify(writer).rstStream(eq(startId), eq(ErrorCode.CANCEL));
verify(transportListener).transportShutdown(); verify(transportListener).transportShutdown();
@ -592,14 +603,14 @@ public class OkHttpClientTransportTest {
assertEquals(3, (int) stream1.id()); assertEquals(3, (int) stream1.id());
// Finish the first stream // Finish the first stream
stream1.cancel(); stream1.cancel(Status.CANCELLED);
assertTrue("newStream() call is still blocking", assertTrue("newStream() call is still blocking",
newStreamReturn.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)); newStreamReturn.await(TIME_OUT_MS, TimeUnit.MILLISECONDS));
assertEquals(1, streams.size()); assertEquals(1, streams.size());
assertEquals(0, clientTransport.getPendingStreamSize()); assertEquals(0, clientTransport.getPendingStreamSize());
OkHttpClientStream stream2 = streams.get(5); OkHttpClientStream stream2 = streams.get(5);
assertNotNull(stream2); assertNotNull(stream2);
stream2.cancel(); stream2.cancel(Status.CANCELLED);
} }
@Test @Test
@ -693,7 +704,7 @@ public class OkHttpClientTransportTest {
// Now finish stream1, stream2 should be started and exhaust the id, // Now finish stream1, stream2 should be started and exhaust the id,
// so stream3 should be failed. // so stream3 should be failed.
stream1.cancel(); stream1.cancel(Status.CANCELLED);
assertTrue("newStream() call for stream2 is still blocking", assertTrue("newStream() call for stream2 is still blocking",
newStreamReturn2.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)); newStreamReturn2.await(TIME_OUT_MS, TimeUnit.MILLISECONDS));
assertTrue("newStream() call for stream3 is still blocking", assertTrue("newStream() call for stream3 is still blocking",
@ -705,7 +716,7 @@ public class OkHttpClientTransportTest {
assertEquals(1, streams.size()); assertEquals(1, streams.size());
OkHttpClientStream stream2 = streams.get(startId + 2); OkHttpClientStream stream2 = streams.get(startId + 2);
assertNotNull(stream2); assertNotNull(stream2);
stream2.cancel(); stream2.cancel(Status.CANCELLED);
} }
@Test @Test
@ -762,7 +773,7 @@ public class OkHttpClientTransportTest {
} else { } else {
verify(frameWriter, times(0)).flush(); verify(frameWriter, times(0)).flush();
} }
stream.cancel(); stream.cancel(Status.CANCELLED);
} }
@Test @Test
@ -819,7 +830,7 @@ public class OkHttpClientTransportTest {
public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception { public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception {
MockStreamListener listener = new MockStreamListener(); MockStreamListener listener = new MockStreamListener();
OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener);
stream.cancel(); stream.cancel(Status.CANCELLED);
Buffer buffer = createMessageFrame( Buffer buffer = createMessageFrame(
new byte[OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 2 + 1]); new byte[OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 2 + 1]);
@ -841,7 +852,7 @@ public class OkHttpClientTransportTest {
public void receiveWindowUpdateForUnknownStream() throws Exception { public void receiveWindowUpdateForUnknownStream() throws Exception {
MockStreamListener listener = new MockStreamListener(); MockStreamListener listener = new MockStreamListener();
OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener);
stream.cancel(); stream.cancel(Status.CANCELLED);
// This should be ignored. // This should be ignored.
frameHandler.windowUpdate(3, 73); frameHandler.windowUpdate(3, 73);
listener.waitUntilStreamClosed(); listener.waitUntilStreamClosed();
@ -859,7 +870,7 @@ public class OkHttpClientTransportTest {
method,new Metadata.Headers(), listener); method,new Metadata.Headers(), listener);
assertTrue(stream.isReady()); assertTrue(stream.isReady());
assertTrue(listener.isOnReadyCalled()); assertTrue(listener.isOnReadyCalled());
stream.cancel(); stream.cancel(Status.CANCELLED);
assertFalse(stream.isReady()); assertFalse(stream.isReady());
} }
@ -915,7 +926,7 @@ public class OkHttpClientTransportTest {
stream.setOnReadyThreshold(HEADER_LENGTH + messageLength + 1); stream.setOnReadyThreshold(HEADER_LENGTH + messageLength + 1);
assertTrue(listener.isOnReadyCalled()); assertTrue(listener.isOnReadyCalled());
stream.cancel(); stream.cancel(Status.CANCELLED);
} }
@Test @Test

View File

@ -31,6 +31,8 @@
package io.grpc.testing; package io.grpc.testing;
import com.google.common.base.Throwables;
import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
@ -59,6 +61,7 @@ import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
@ -132,6 +135,35 @@ public class TestUtils {
}; };
} }
/**
* Delay each payload by the given number of milliseconds. Useful for simulating slow server
* responses.
* @param delayMillis the delay applied to each payload, in milliseconds.
*/
public static ServerInterceptor delayServerResponseInterceptor(final AtomicLong delayMillis) {
return new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(String method,
ServerCall<RespT> call,
Metadata.Headers headers,
ServerCallHandler<ReqT, RespT> next) {
return next.startCall(method, new SimpleForwardingServerCall<RespT>(call) {
@Override
public void sendPayload(RespT payload) {
if (delayMillis.get() != 0) {
try {
Thread.sleep(delayMillis.get());
} catch (InterruptedException e) {
Throwables.propagate(e);
}
}
super.sendPayload(payload);
}
}, headers);
}
};
}
/** /**
* Picks an unused port. * Picks an unused port.
*/ */