diff --git a/core/src/main/java/io/grpc/ChannelImpl.java b/core/src/main/java/io/grpc/ChannelImpl.java index 5dc9ea5224..5a19266eeb 100644 --- a/core/src/main/java/io/grpc/ChannelImpl.java +++ b/core/src/main/java/io/grpc/ChannelImpl.java @@ -48,6 +48,10 @@ import java.util.ArrayList; import java.util.Collection; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; @@ -66,7 +70,7 @@ public final class ChannelImpl extends Channel { @Override public void flush() {} - @Override public void cancel() {} + @Override public void cancel(Status reason) {} @Override public void halfClose() {} @@ -85,6 +89,10 @@ public final class ChannelImpl extends Channel { private final ExecutorService executor; private final String userAgent; + /** + * Executor that runs deadline timers for requests. + */ + private ScheduledExecutorService deadlineCancellationExecutor; /** * All transports that are not stopped. At the very least {@link #activeTransport} will be * present, but previously used transports that still have streams or are stopping may also be @@ -108,6 +116,7 @@ public final class ChannelImpl extends Channel { this.transportFactory = transportFactory; this.executor = executor; this.userAgent = userAgent; + deadlineCancellationExecutor = SharedResourceHolder.get(TIMER_SERVICE); } /** Hack to allow executors to auto-shutdown. Not for general use. */ @@ -125,6 +134,9 @@ public final class ChannelImpl extends Channel { return this; } shutdown = true; + // After shutdown there are no new calls, so no new cancellation tasks are needed + deadlineCancellationExecutor = + SharedResourceHolder.release(TIMER_SERVICE, deadlineCancellationExecutor); if (activeTransport != null) { activeTransport.shutdown(); activeTransport = null; @@ -300,6 +312,7 @@ public final class ChannelImpl extends Channel { private final boolean unaryRequest; private final CallOptions callOptions; private ClientStream stream; + private volatile ScheduledFuture deadlineCancellationFuture; public CallImpl(MethodDescriptor method, SerializingExecutor executor, CallOptions callOptions) { @@ -331,8 +344,9 @@ public final class ChannelImpl extends Channel { // Convert the deadline to timeout. Timeout is more favorable than deadline on the wire // because timeout tolerates the clock difference between machines. Long deadlineNanoTime = callOptions.getDeadlineNanoTime(); + long timeoutMicros = 0; if (deadlineNanoTime != null) { - long timeoutMicros = TimeUnit.NANOSECONDS.toMicros(deadlineNanoTime - System.nanoTime()); + timeoutMicros = TimeUnit.NANOSECONDS.toMicros(deadlineNanoTime - System.nanoTime()); if (timeoutMicros <= 0) { closeCallPrematurely(listener, Status.DEADLINE_EXCEEDED); return; @@ -353,6 +367,10 @@ public final class ChannelImpl extends Channel { // TODO(ejona86): Improve the API to remove the possibility of the race. closeCallPrematurely(listener, Status.fromThrowable(ex)); } + // Start the deadline timer after stream creation because it will close the stream + if (deadlineNanoTime != null) { + deadlineCancellationFuture = startDeadlineTimer(timeoutMicros); + } } @Override @@ -366,7 +384,7 @@ public final class ChannelImpl extends Channel { // Cancel is called in exception handling cases, so it may be the case that the // stream was never successfully created. if (stream != null) { - stream.cancel(); + stream.cancel(Status.CANCELLED); } } @@ -411,6 +429,15 @@ public final class ChannelImpl extends Channel { listener.closed(status, new Metadata.Trailers()); } + private ScheduledFuture startDeadlineTimer(long timeoutMicros) { + return deadlineCancellationExecutor.schedule(new Runnable() { + @Override + public void run() { + stream.cancel(Status.DEADLINE_EXCEEDED); + } + }, timeoutMicros, TimeUnit.MICROSECONDS); + } + private class ClientStreamListenerImpl implements ClientStreamListener { private final Listener observer; private boolean closed; @@ -468,6 +495,11 @@ public final class ChannelImpl extends Channel { @Override public void run() { closed = true; + // manually optimize the volatile read + ScheduledFuture future = deadlineCancellationFuture; + if (future != null) { + future.cancel(false); + } observer.onClose(status, trailers); } }); @@ -561,4 +593,24 @@ public final class ChannelImpl extends Channel { return Long.parseLong(valuePart) * factor; } } + + private static final SharedResourceHolder.Resource TIMER_SERVICE = + new SharedResourceHolder.Resource() { + @Override + public ScheduledExecutorService create() { + return Executors.newSingleThreadScheduledExecutor(new ThreadFactory() { + @Override + public Thread newThread(Runnable r) { + Thread thread = new Thread(r); + thread.setDaemon(true); + return thread; + } + }); + } + + @Override + public void close(ScheduledExecutorService instance) { + instance.shutdown(); + } + }; } diff --git a/core/src/main/java/io/grpc/transport/AbstractClientStream.java b/core/src/main/java/io/grpc/transport/AbstractClientStream.java index e15283d563..2f5be29dd6 100644 --- a/core/src/main/java/io/grpc/transport/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/transport/AbstractClientStream.java @@ -31,6 +31,9 @@ package io.grpc.transport; +import static io.grpc.Status.Code.CANCELLED; +import static io.grpc.Status.Code.DEADLINE_EXCEEDED; + import com.google.common.base.Objects; import com.google.common.base.Preconditions; @@ -38,6 +41,7 @@ import io.grpc.Metadata; import io.grpc.Status; import java.io.InputStream; +import java.util.EnumSet; import java.util.logging.Level; import java.util.logging.Logger; @@ -153,7 +157,7 @@ public abstract class AbstractClientStream extends AbstractStream @Override protected final void deframeFailed(Throwable cause) { log.log(Level.WARNING, "Exception processing message", cause); - cancel(); + cancel(Status.CANCELLED); } /** @@ -278,9 +282,11 @@ public abstract class AbstractClientStream extends AbstractStream * Cancel the stream. Called by the application layer, never called by the transport. */ @Override - public void cancel() { + public void cancel(Status reason) { + Preconditions.checkArgument(EnumSet.of(CANCELLED, DEADLINE_EXCEEDED).contains(reason.getCode()), + "Invalid cancellation reason"); outboundPhase(Phase.STATUS); - sendCancel(); + sendCancel(reason); dispose(); } @@ -289,7 +295,7 @@ public abstract class AbstractClientStream extends AbstractStream * Can be called by either the application or transport layers. This method is safe to be called * at any time and multiple times. */ - protected abstract void sendCancel(); + protected abstract void sendCancel(Status reason); // We support Guava 14 @SuppressWarnings("deprecation") diff --git a/core/src/main/java/io/grpc/transport/ClientStream.java b/core/src/main/java/io/grpc/transport/ClientStream.java index e79de52d8b..298ef69131 100644 --- a/core/src/main/java/io/grpc/transport/ClientStream.java +++ b/core/src/main/java/io/grpc/transport/ClientStream.java @@ -31,6 +31,8 @@ package io.grpc.transport; +import io.grpc.Status; + /** * Extension of {@link Stream} to support client-side termination semantics. */ @@ -41,8 +43,10 @@ public interface ClientStream extends Stream { * sent or received, however it may still be possible to receive buffered messages for a brief * period until {@link ClientStreamListener#closed} is called. This method is safe to be called * at any time and multiple times. + * + * @param reason must be one of Status.CANCELLED or Status.DEADLINE_EXCEEDED */ - void cancel(); + void cancel(Status reason); /** * Closes the local side of this stream and flushes any remaining messages. After this is called, diff --git a/core/src/main/java/io/grpc/transport/Http2ClientStream.java b/core/src/main/java/io/grpc/transport/Http2ClientStream.java index bf4ea435b5..27d0ae430f 100644 --- a/core/src/main/java/io/grpc/transport/Http2ClientStream.java +++ b/core/src/main/java/io/grpc/transport/Http2ClientStream.java @@ -128,7 +128,7 @@ public abstract class Http2ClientStream extends AbstractClientStream { if (transportError.getDescription().length() > 1000 || endOfStream) { inboundTransportError(transportError); // We have enough error detail so lets cancel. - sendCancel(); + sendCancel(Status.CANCELLED); } } else { inboundDataReceived(frame); @@ -155,7 +155,7 @@ public abstract class Http2ClientStream extends AbstractClientStream { } if (transportError != null) { inboundTransportError(transportError); - sendCancel(); + sendCancel(Status.CANCELLED); } else { Status status = statusFromTrailers(trailers); stripTransportDetails(trailers); diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java index 7700ac01aa..c9f9e09c18 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java @@ -35,8 +35,10 @@ import static io.grpc.testing.integration.Messages.PayloadType.COMPRESSABLE; import static io.grpc.testing.integration.Util.assertEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; @@ -86,6 +88,7 @@ import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; /** @@ -97,6 +100,7 @@ public abstract class AbstractTransportTest { ProtoUtils.keyForProto(Messages.SimpleContext.getDefaultInstance()); private static final AtomicReference requestHeadersCapture = new AtomicReference(); + private static final AtomicLong serverDelayMillis = new AtomicLong(0); private static ScheduledExecutorService testServiceExecutor; private static ServerImpl server; private static int OPERATION_TIMEOUT = 5000; @@ -106,6 +110,7 @@ public abstract class AbstractTransportTest { builder.addService(ServerInterceptors.intercept( TestServiceGrpc.bindService(new TestServiceImpl(testServiceExecutor)), + TestUtils.delayServerResponseInterceptor(serverDelayMillis), TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture), TestUtils.echoRequestHeadersInterceptor(Util.METADATA_KEY))); try { @@ -133,6 +138,7 @@ public abstract class AbstractTransportTest { blockingStub = TestServiceGrpc.newBlockingStub(channel); asyncStub = TestServiceGrpc.newStub(channel); requestHeadersCapture.set(null); + serverDelayMillis.set(0); } /** Clean up. */ @@ -595,6 +601,60 @@ public abstract class AbstractTransportTest { && configuredTimeoutMinutes - transferredTimeoutMinutes <= 1); } + @Test + public void deadlineNotExceeded() { + serverDelayMillis.set(0); + // warm up the channel and JVM + blockingStub.emptyCall(Empty.getDefaultInstance()); + TestServiceGrpc.newBlockingStub(channel) + .configureNewStub() + .setDeadlineAfter(50, TimeUnit.MILLISECONDS) + .build().emptyCall(Empty.getDefaultInstance()); + } + + @Test(timeout = 10000) + public void deadlineExceeded() { + serverDelayMillis.set(20); + // warm up the channel and JVM + blockingStub.emptyCall(Empty.getDefaultInstance()); + TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel) + .configureNewStub() + .setDeadlineAfter(10, TimeUnit.MILLISECONDS) + .build(); + try { + stub.emptyCall(Empty.getDefaultInstance()); + fail("Expected deadline to be exceeded"); + } catch (Throwable t) { + assertEquals(Status.DEADLINE_EXCEEDED, Status.fromThrowable(t)); + } + } + + @Test(timeout = 10000) + public void deadlineExceededServerStreaming() throws Exception { + serverDelayMillis.set(10); // applied to every message + // warm up the channel and JVM + blockingStub.emptyCall(Empty.getDefaultInstance()); + StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder() + .setResponseType(PayloadType.COMPRESSABLE) + .addResponseParameters(ResponseParameters.newBuilder() + .setSize(1)) + .addResponseParameters(ResponseParameters.newBuilder() + .setSize(1)) + .addResponseParameters(ResponseParameters.newBuilder() + .setSize(1)) + .addResponseParameters(ResponseParameters.newBuilder() + .setSize(1)) + .build(); + StreamRecorder recorder = StreamRecorder.create(); + TestServiceGrpc.newStub(channel) + .configureNewStub() + .setDeadlineAfter(30, TimeUnit.MILLISECONDS) + .build().streamingOutputCall(request, recorder); + recorder.awaitCompletion(); + assertEquals(Status.DEADLINE_EXCEEDED, Status.fromThrowable(recorder.getError())); + assertNotEquals(0, recorder.getValues().size()); + } + protected int unaryPayloadLength() { // 10MiB. return 10485760; diff --git a/netty/src/main/java/io/grpc/transport/netty/CancelStreamCommand.java b/netty/src/main/java/io/grpc/transport/netty/CancelStreamCommand.java index ab405264e5..b2fff2cd03 100644 --- a/netty/src/main/java/io/grpc/transport/netty/CancelStreamCommand.java +++ b/netty/src/main/java/io/grpc/transport/netty/CancelStreamCommand.java @@ -31,19 +31,35 @@ package io.grpc.transport.netty; +import static io.grpc.Status.Code.CANCELLED; +import static io.grpc.Status.Code.DEADLINE_EXCEEDED; + import com.google.common.base.Preconditions; +import io.grpc.Status; + +import java.util.EnumSet; + /** * Command sent from a Netty client stream to the handler to cancel the stream. */ class CancelStreamCommand { private final NettyClientStream stream; + private final Status reason; - CancelStreamCommand(NettyClientStream stream) { + CancelStreamCommand(NettyClientStream stream, Status reason) { this.stream = Preconditions.checkNotNull(stream, "stream"); + Preconditions.checkNotNull(reason); + Preconditions.checkArgument(EnumSet.of(CANCELLED, DEADLINE_EXCEEDED).contains(reason.getCode()), + "Invalid cancellation reason"); + this.reason = reason; } NettyClientStream stream() { return stream; } + + Status reason() { + return reason; + } } diff --git a/netty/src/main/java/io/grpc/transport/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/transport/netty/NettyClientHandler.java index 7ae2f6d13d..7e73a392d5 100644 --- a/netty/src/main/java/io/grpc/transport/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/transport/netty/NettyClientHandler.java @@ -318,7 +318,7 @@ class NettyClientHandler extends Http2ConnectionHandler { private void cancelStream(ChannelHandlerContext ctx, CancelStreamCommand cmd, ChannelPromise promise) { NettyClientStream stream = cmd.stream(); - stream.transportReportStatus(Status.CANCELLED, true, new Metadata.Trailers()); + stream.transportReportStatus(cmd.reason(), true, new Metadata.Trailers()); encoder().writeRstStream(ctx, stream.id(), Http2Error.CANCEL.code(), promise); } diff --git a/netty/src/main/java/io/grpc/transport/netty/NettyClientStream.java b/netty/src/main/java/io/grpc/transport/netty/NettyClientStream.java index d2c0949246..5b8c854bbb 100644 --- a/netty/src/main/java/io/grpc/transport/netty/NettyClientStream.java +++ b/netty/src/main/java/io/grpc/transport/netty/NettyClientStream.java @@ -35,6 +35,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import io.grpc.Status; import io.grpc.transport.ClientStreamListener; import io.grpc.transport.Http2ClientStream; import io.grpc.transport.WritableBuffer; @@ -118,9 +119,9 @@ class NettyClientStream extends Http2ClientStream { } @Override - protected void sendCancel() { + protected void sendCancel(Status reason) { // Send the cancel command to the handler. - writeQueue.enqueue(new CancelStreamCommand(this), true); + writeQueue.enqueue(new CancelStreamCommand(this, reason), true); } @Override diff --git a/netty/src/test/java/io/grpc/transport/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/transport/netty/NettyClientHandlerTest.java index 4cf286e358..21096acb25 100644 --- a/netty/src/test/java/io/grpc/transport/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/transport/netty/NettyClientHandlerTest.java @@ -181,7 +181,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { verify(stream).id(eq(3)); when(stream.id()).thenReturn(3); // Cancel the stream. - writeQueue.enqueue(new CancelStreamCommand(stream), true); + writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), true); assertTrue(createPromise.isSuccess()); verify(stream).transportReportStatus(eq(Status.CANCELLED), eq(true), @@ -216,7 +216,18 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { public void cancelShouldSucceed() throws Exception { createStream(); verify(channel, times(1)).flush(); - writeQueue.enqueue(new CancelStreamCommand(stream), true); + writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), true); + + ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code()); + verify(ctx).write(eq(expected), eq(promise)); + verify(channel, times(2)).flush(); + } + + @Test + public void cancelDeadlineExceededShouldSucceed() throws Exception { + createStream(); + verify(channel, times(1)).flush(); + writeQueue.enqueue(new CancelStreamCommand(stream, Status.DEADLINE_EXCEEDED), true); ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code()); verify(ctx).write(eq(expected), eq(promise)); @@ -233,7 +244,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { verify(stream).id(idCaptor.capture()); when(stream.id()).thenReturn(idCaptor.getValue()); ChannelPromise cancelPromise = mock(ChannelPromise.class); - writeQueue.enqueue(new CancelStreamCommand(stream), cancelPromise, true); + writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), cancelPromise, true); verify(cancelPromise).setSuccess(); verify(channel, times(2)).flush(); verifyNoMoreInteractions(ctx); @@ -248,14 +259,29 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { public void cancelTwiceShouldSucceed() throws Exception { createStream(); - writeQueue.enqueue(new CancelStreamCommand(stream), promise, true); + writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), promise, true); ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code()); verify(ctx).write(eq(expected), any(ChannelPromise.class)); promise = mock(ChannelPromise.class); - writeQueue.enqueue(new CancelStreamCommand(stream), promise, true); + writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), promise, true); + verify(promise).setSuccess(); + } + + @Test + public void cancelTwiceDifferentReasons() throws Exception { + createStream(); + + writeQueue.enqueue(new CancelStreamCommand(stream, Status.DEADLINE_EXCEEDED), promise, true); + + ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code()); + verify(ctx).write(eq(expected), any(ChannelPromise.class)); + + promise = mock(ChannelPromise.class); + + writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), promise, true); verify(promise).setSuccess(); } @@ -357,7 +383,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { writeQueue.enqueue(new CreateStreamCommand(grpcHeaders, stream), true); verify(stream).id(3); when(stream.id()).thenReturn(3); - writeQueue.enqueue(new CancelStreamCommand(stream), true); + writeQueue.enqueue(new CancelStreamCommand(stream, Status.CANCELLED), true); verify(stream).transportReportStatus(eq(Status.CANCELLED), eq(true), any(Metadata.Trailers.class)); } diff --git a/netty/src/test/java/io/grpc/transport/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/transport/netty/NettyClientStreamTest.java index d77855417c..4eec32bc82 100644 --- a/netty/src/test/java/io/grpc/transport/netty/NettyClientStreamTest.java +++ b/netty/src/test/java/io/grpc/transport/netty/NettyClientStreamTest.java @@ -42,6 +42,7 @@ import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.isA; import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -102,14 +103,28 @@ public class NettyClientStreamTest extends NettyStreamTestBase { public void cancelShouldSendCommand() { // Set stream id to indicate it has been created stream().id(STREAM_ID); - stream().cancel(); - verify(writeQueue).enqueue(any(CancelStreamCommand.class), eq(true)); + stream().cancel(Status.CANCELLED); + ArgumentCaptor commandCaptor = + ArgumentCaptor.forClass(CancelStreamCommand.class); + verify(writeQueue).enqueue(commandCaptor.capture(), eq(true)); + assertEquals(commandCaptor.getValue().reason(), Status.CANCELLED); + } + + @Test + public void deadlineExceededCancelShouldSendCommand() { + // Set stream id to indicate it has been created + stream().id(STREAM_ID); + stream().cancel(Status.DEADLINE_EXCEEDED); + ArgumentCaptor commandCaptor = + ArgumentCaptor.forClass(CancelStreamCommand.class); + verify(writeQueue).enqueue(commandCaptor.capture(), eq(true)); + assertEquals(commandCaptor.getValue().reason(), Status.DEADLINE_EXCEEDED); } @Test public void cancelShouldStillSendCommandIfStreamNotCreatedToCancelCreation() { - stream().cancel(); - verify(writeQueue).enqueue(any(CancelStreamCommand.class), eq(true)); + stream().cancel(Status.CANCELLED); + verify(writeQueue).enqueue(isA(CancelStreamCommand.class), eq(true)); } @Test @@ -340,7 +355,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase { @Override protected void closeStream() { - stream().cancel(); + stream().cancel(Status.CANCELLED); } private ByteBuf simpleGrpcFrame() { diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java index 0f0fb7c5fb..bedf6cdb6e 100644 --- a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java @@ -209,8 +209,8 @@ class OkHttpClientStream extends Http2ClientStream { } @Override - protected void sendCancel() { - transport.finishStream(id(), Status.CANCELLED, ErrorCode.CANCEL); + protected void sendCancel(Status reason) { + transport.finishStream(id(), reason, ErrorCode.CANCEL); } @Override diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java index 5d13fbfe24..b8f11c3f41 100644 --- a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java @@ -270,10 +270,10 @@ class OkHttpClientTransport implements ClientTransport { } catch (InterruptedException e) { // Restore the interrupt. Thread.currentThread().interrupt(); - clientStream.cancel(); + clientStream.cancel(Status.CANCELLED); throw new RuntimeException(e); } catch (ExecutionException e) { - clientStream.cancel(); + clientStream.cancel(Status.CANCELLED); throw new RuntimeException(e.getCause() != null ? e.getCause() : e); } } @@ -457,7 +457,8 @@ class OkHttpClientTransport implements ClientTransport { frameWriter.rstStream(streamId, ErrorCode.CANCEL); } if (status != null) { - boolean isCancelled = status.getCode() == Code.CANCELLED; + boolean isCancelled = (status.getCode() == Code.CANCELLED + || status.getCode() == Code.DEADLINE_EXCEEDED); stream.transportReportStatus(status, isCancelled, new Metadata.Trailers()); } if (!startPendingStreams()) { diff --git a/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java index 15160e1d9a..dff4db7357 100644 --- a/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java @@ -268,7 +268,7 @@ public class OkHttpClientTransportTest { clientTransport.newStream(method, new Metadata.Headers(), listener); OkHttpClientStream stream = streams.get(3); assertNotNull(stream); - stream.cancel(); + stream.cancel(Status.CANCELLED); verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); listener.waitUntilStreamClosed(); assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), @@ -286,7 +286,7 @@ public class OkHttpClientTransportTest { new Header(Header.TARGET_PATH, "/fakemethod"), userAgentHeader, CONTENT_TYPE_HEADER, TE_HEADER); verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders)); - streams.get(3).cancel(); + streams.get(3).cancel(Status.CANCELLED); } @Test @@ -303,7 +303,18 @@ public class OkHttpClientTransportTest { HttpUtil.getGrpcUserAgent("okhttp", userAgent)), CONTENT_TYPE_HEADER, TE_HEADER); verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders)); - streams.get(3).cancel(); + streams.get(3).cancel(Status.CANCELLED); + } + + @Test + public void cancelStreamForDeadlineExceeded() throws Exception { + MockStreamListener listener = new MockStreamListener(); + clientTransport.newStream(method, new Metadata.Headers(), listener); + OkHttpClientStream stream = streams.get(3); + assertNotNull(stream); + stream.cancel(Status.DEADLINE_EXCEEDED); + verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + listener.waitUntilStreamClosed(); } @Test @@ -320,7 +331,7 @@ public class OkHttpClientTransportTest { verify(frameWriter).data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH)); Buffer sentFrame = captor.getValue(); assertEquals(createMessageFrame(message), sentFrame); - stream.cancel(); + stream.cancel(Status.CANCELLED); } @Test @@ -364,13 +375,13 @@ public class OkHttpClientTransportTest { verify(frameWriter).windowUpdate(eq(5), eq((long) 2 * messageFrameLength)); verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength)); - stream1.cancel(); + stream1.cancel(Status.CANCELLED); verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); listener1.waitUntilStreamClosed(); assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), listener1.status.getCode()); - stream2.cancel(); + stream2.cancel(Status.CANCELLED); verify(frameWriter).rstStream(eq(5), eq(ErrorCode.CANCEL)); listener2.waitUntilStreamClosed(); assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), @@ -394,7 +405,7 @@ public class OkHttpClientTransportTest { // We return the bytes for the stream window as we read the message. verify(frameWriter).windowUpdate(eq(3), eq(messageFrameLength)); - stream.cancel(); + stream.cancel(Status.CANCELLED); verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); listener.waitUntilStreamClosed(); assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), @@ -429,7 +440,7 @@ public class OkHttpClientTransportTest { verify(frameWriter).data( eq(false), eq(3), any(Buffer.class), eq(messageLength + HEADER_LENGTH - partiallySentSize)); - stream.cancel(); + stream.cancel(Status.CANCELLED); listener.waitUntilStreamClosed(); } @@ -468,7 +479,7 @@ public class OkHttpClientTransportTest { frameHandler.windowUpdate(3, HEADER_LENGTH + 20); verify(frameWriter).data(eq(false), eq(3), any(Buffer.class), eq(HEADER_LENGTH + 20)); - stream.cancel(); + stream.cancel(Status.CANCELLED); listener.waitUntilStreamClosed(); } @@ -486,8 +497,8 @@ public class OkHttpClientTransportTest { assertEquals(2, streams.size()); verify(transportListener).transportShutdown(); - stream1.cancel(); - stream2.cancel(); + stream1.cancel(Status.CANCELLED); + stream2.cancel(Status.CANCELLED); listener1.waitUntilStreamClosed(); listener2.waitUntilStreamClosed(); assertEquals(0, streams.size()); @@ -563,7 +574,7 @@ public class OkHttpClientTransportTest { assertNewStreamFail(transport); - streams.get(startId).cancel(); + streams.get(startId).cancel(Status.CANCELLED); listener1.waitUntilStreamClosed(); verify(writer).rstStream(eq(startId), eq(ErrorCode.CANCEL)); verify(transportListener).transportShutdown(); @@ -592,14 +603,14 @@ public class OkHttpClientTransportTest { assertEquals(3, (int) stream1.id()); // Finish the first stream - stream1.cancel(); + stream1.cancel(Status.CANCELLED); assertTrue("newStream() call is still blocking", newStreamReturn.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)); assertEquals(1, streams.size()); assertEquals(0, clientTransport.getPendingStreamSize()); OkHttpClientStream stream2 = streams.get(5); assertNotNull(stream2); - stream2.cancel(); + stream2.cancel(Status.CANCELLED); } @Test @@ -693,7 +704,7 @@ public class OkHttpClientTransportTest { // Now finish stream1, stream2 should be started and exhaust the id, // so stream3 should be failed. - stream1.cancel(); + stream1.cancel(Status.CANCELLED); assertTrue("newStream() call for stream2 is still blocking", newStreamReturn2.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)); assertTrue("newStream() call for stream3 is still blocking", @@ -705,7 +716,7 @@ public class OkHttpClientTransportTest { assertEquals(1, streams.size()); OkHttpClientStream stream2 = streams.get(startId + 2); assertNotNull(stream2); - stream2.cancel(); + stream2.cancel(Status.CANCELLED); } @Test @@ -762,7 +773,7 @@ public class OkHttpClientTransportTest { } else { verify(frameWriter, times(0)).flush(); } - stream.cancel(); + stream.cancel(Status.CANCELLED); } @Test @@ -819,7 +830,7 @@ public class OkHttpClientTransportTest { public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener); - stream.cancel(); + stream.cancel(Status.CANCELLED); Buffer buffer = createMessageFrame( new byte[OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 2 + 1]); @@ -841,7 +852,7 @@ public class OkHttpClientTransportTest { public void receiveWindowUpdateForUnknownStream() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener); - stream.cancel(); + stream.cancel(Status.CANCELLED); // This should be ignored. frameHandler.windowUpdate(3, 73); listener.waitUntilStreamClosed(); @@ -859,7 +870,7 @@ public class OkHttpClientTransportTest { method,new Metadata.Headers(), listener); assertTrue(stream.isReady()); assertTrue(listener.isOnReadyCalled()); - stream.cancel(); + stream.cancel(Status.CANCELLED); assertFalse(stream.isReady()); } @@ -915,7 +926,7 @@ public class OkHttpClientTransportTest { stream.setOnReadyThreshold(HEADER_LENGTH + messageLength + 1); assertTrue(listener.isOnReadyCalled()); - stream.cancel(); + stream.cancel(Status.CANCELLED); } @Test diff --git a/testing/src/main/java/io/grpc/testing/TestUtils.java b/testing/src/main/java/io/grpc/testing/TestUtils.java index dbfe1e824b..c6ceb3ffb0 100644 --- a/testing/src/main/java/io/grpc/testing/TestUtils.java +++ b/testing/src/main/java/io/grpc/testing/TestUtils.java @@ -31,6 +31,8 @@ package io.grpc.testing; +import com.google.common.base.Throwables; + import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Metadata; import io.grpc.ServerCall; @@ -59,6 +61,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLContext; @@ -132,6 +135,35 @@ public class TestUtils { }; } + /** + * Delay each payload by the given number of milliseconds. Useful for simulating slow server + * responses. + * @param delayMillis the delay applied to each payload, in milliseconds. + */ + public static ServerInterceptor delayServerResponseInterceptor(final AtomicLong delayMillis) { + return new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(String method, + ServerCall call, + Metadata.Headers headers, + ServerCallHandler next) { + return next.startCall(method, new SimpleForwardingServerCall(call) { + @Override + public void sendPayload(RespT payload) { + if (delayMillis.get() != 0) { + try { + Thread.sleep(delayMillis.get()); + } catch (InterruptedException e) { + Throwables.propagate(e); + } + } + super.sendPayload(payload); + } + }, headers); + } + }; + } + /** * Picks an unused port. */