diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index 6408288d37..a2c5b7fc00 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -31,6 +31,7 @@ package io.grpc.internal; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.grpc.internal.GrpcUtil.AUTHORITY_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; @@ -40,6 +41,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Joiner; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; @@ -55,15 +58,12 @@ import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; import java.io.InputStream; -import java.util.LinkedList; -import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Implementation of {@link ClientCall}. @@ -181,13 +181,14 @@ final class ClientCallImpl extends ClientCall ClientStreamListener listener = new ClientStreamListenerImpl(observer); ListenableFuture transportFuture = clientTransportProvider.get(callOptions); + if (transportFuture.isDone()) { // Try to skip DelayedStream when possible to avoid the overhead of a volatile read in the // fast path. If that fails, stream will stay null and DelayedStream will be created. ClientTransport transport; try { transport = transportFuture.get(); - if (transport != null && updateTimeoutHeader(headers)) { + if (transport != null && updateTimeoutHeader(callOptions.getDeadlineNanoTime(), headers)) { stream = transport.newStream(method, headers, listener); } } catch (InterruptedException e) { @@ -196,9 +197,11 @@ final class ClientCallImpl extends ClientCall // Fall through to DelayedStream } } - if (stream == null) { - stream = new DelayedStream(transportFuture, headers, listener); + DelayedStream delayed; + stream = delayed = new DelayedStream(listener, callExecutor); + addListener(transportFuture, + new StreamCreationTask(delayed, headers, method, callOptions, listener)); } stream.setDecompressionRegistry(decompressorRegistry); @@ -210,7 +213,7 @@ final class ClientCallImpl extends ClientCall } // Start the deadline timer after stream creation because it will close the stream - Long timeoutMicros = getRemainingTimeoutMicros(); + Long timeoutMicros = getRemainingTimeoutMicros(callOptions.getDeadlineNanoTime()); if (timeoutMicros != null) { deadlineCancellationFuture = startDeadlineTimer(timeoutMicros); } @@ -223,12 +226,13 @@ final class ClientCallImpl extends ClientCall * * @return {@code false} if deadline already exceeded */ - private boolean updateTimeoutHeader(Metadata headers) { + static boolean updateTimeoutHeader(@Nullable Long deadlineNanoTime, Metadata headers) { // Fill out timeout on the headers + // TODO(someone): Find out if this should always remove the timeout, even when returning false. headers.removeAll(TIMEOUT_KEY); // Convert the deadline to timeout. Timeout is more favorable than deadline on the wire // because timeout tolerates the clock difference between machines. - Long timeoutMicros = getRemainingTimeoutMicros(); + Long timeoutMicros = getRemainingTimeoutMicros(deadlineNanoTime); if (timeoutMicros != null) { if (timeoutMicros <= 0) { return false; @@ -239,13 +243,12 @@ final class ClientCallImpl extends ClientCall } /** - * Return the remaining amout of microseconds before the deadline is reached. + * Return the remaining amount of microseconds before the deadline is reached. * *

{@code null} if deadline is not set. Negative value if already expired. */ @Nullable - private Long getRemainingTimeoutMicros() { - Long deadlineNanoTime = callOptions.getDeadlineNanoTime(); + private static Long getRemainingTimeoutMicros(@Nullable Long deadlineNanoTime) { if (deadlineNanoTime == null) { return null; } @@ -377,7 +380,7 @@ final class ClientCallImpl extends ClientCall @Override public void closed(Status status, Metadata trailers) { - Long timeoutMicros = getRemainingTimeoutMicros(); + Long timeoutMicros = getRemainingTimeoutMicros(callOptions.getDeadlineNanoTime()); if (status.getCode() == Status.Code.CANCELLED && timeoutMicros != null) { // When the server's deadline expires, it can only reset the stream with CANCEL and no // description. Since our timer may be delayed in firing, we double-check the deadline and @@ -419,282 +422,47 @@ final class ClientCallImpl extends ClientCall } } - private static final class PendingMessage { - private final InputStream message; - private final boolean shouldBeCompressed; - - public PendingMessage(InputStream message, boolean shouldBeCompressed) { - this.message = message; - this.shouldBeCompressed = shouldBeCompressed; - } + private void addListener(ListenableFuture future, FutureCallback callback) { + Executor executor = future.isDone() ? directExecutor() : callExecutor; + Futures.addCallback(future, callback, executor); } /** - * A stream that queues requests before the transport is available, and delegates to a real stream - * implementation when the transport is available. - * - *

{@code ClientStream} itself doesn't require thread-safety. However, the state of {@code - * DelayedStream} may be internally altered by different threads, thus internal synchronization is - * necessary. + * Wakes up delayed stream when the transport is ready or failed. */ - private class DelayedStream implements ClientStream { - final Metadata headers; - final ClientStreamListener listener; + @VisibleForTesting + static final class StreamCreationTask implements FutureCallback { + private final DelayedStream stream; + private final MethodDescriptor method; + private final Metadata headers; + private final ClientStreamListener listener; + private final CallOptions callOptions; - // Volatile to be readable without synchronization in the fast path. - // Writes are also done within synchronized(this). - volatile ClientStream realStream; - - @GuardedBy("this") - Compressor compressor; - // Can be either a Decompressor or a String - @GuardedBy("this") - Object decompressor; - @GuardedBy("this") - DecompressorRegistry decompressionRegistry; - @GuardedBy("this") - final List pendingMessages = new LinkedList(); - boolean messageCompressionEnabled; - @GuardedBy("this") - boolean pendingHalfClose; - @GuardedBy("this") - int pendingFlowControlRequests; - @GuardedBy("this") - boolean pendingFlush; - - /** - * Get a transport and try to create a stream on it. - */ - private class StreamCreationTask extends ContextRunnable { - final ListenableFuture transportFuture; - - StreamCreationTask(Context context, ListenableFuture transportFuture) { - super(context); - this.transportFuture = Preconditions.checkNotNull(transportFuture); - } - - @Override - public void runInContext() { - if (transportFuture.isDone()) { - ClientTransport transport; - try { - transport = transportFuture.get(); - } catch (Exception e) { - maybeClosePrematurely(Status.fromThrowable(e)); - if (e instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - return; - } - if (transport == null) { - maybeClosePrematurely(Status.UNAVAILABLE.withDescription("Channel is shutdown")); - return; - } - createStream(transport); - } else { - transportFuture.addListener(this, callExecutor); - } - } - } - - DelayedStream(ListenableFuture initialTransportFuture, Metadata headers, - ClientStreamListener listener) { + StreamCreationTask(DelayedStream stream, Metadata headers, MethodDescriptor method, + CallOptions callOptions, ClientStreamListener listener) { + this.stream = stream; this.headers = headers; + this.method = method; + this.callOptions = callOptions; this.listener = listener; - new StreamCreationTask(context, initialTransportFuture).run(); - } - - /** - * Creates a stream on a presumably usable transport. - */ - private void createStream(ClientTransport transport) { - synchronized (this) { - if (realStream == NOOP_CLIENT_STREAM) { - // Already cancelled - return; - } - Preconditions.checkState(realStream == null, "Stream already created: %s", realStream); - if (!updateTimeoutHeader(headers)) { - maybeClosePrematurely(Status.DEADLINE_EXCEEDED); - return; - } - realStream = transport.newStream(method, headers, listener); - Preconditions.checkNotNull(realStream, transport.toString() + " returned null stream"); - if (compressor != null) { - realStream.setCompressor(compressor); - } - if (this.decompressionRegistry != null) { - realStream.setDecompressionRegistry(this.decompressionRegistry); - } - for (PendingMessage message : pendingMessages) { - realStream.setMessageCompression(message.shouldBeCompressed); - realStream.writeMessage(message.message); - } - // Set this again, incase no messages were sent. - realStream.setMessageCompression(messageCompressionEnabled); - pendingMessages.clear(); - if (pendingHalfClose) { - realStream.halfClose(); - pendingHalfClose = false; - } - if (pendingFlowControlRequests > 0) { - realStream.request(pendingFlowControlRequests); - pendingFlowControlRequests = 0; - } - if (pendingFlush) { - realStream.flush(); - pendingFlush = false; - } - } - } - - private void maybeClosePrematurely(final Status reason) { - synchronized (this) { - if (realStream == null) { - realStream = NOOP_CLIENT_STREAM; - callExecutor.execute(new ContextRunnable(context) { - @Override - public void runInContext() { - try { - listener.closed(reason, new Metadata()); - } finally { - context.removeListener(ClientCallImpl.this); - } - } - }); - } - } } @Override - public void writeMessage(InputStream message) { - if (realStream == null) { - synchronized (this) { - if (realStream == null) { - pendingMessages.add(new PendingMessage(message, messageCompressionEnabled)); - return; - } - } + public void onSuccess(ClientTransport transport) { + if (transport == null) { + stream.maybeClosePrematurely(Status.UNAVAILABLE.withDescription("Channel is shutdown")); + return; } - realStream.writeMessage(message); + if (!updateTimeoutHeader(callOptions.getDeadlineNanoTime(), headers)) { + stream.maybeClosePrematurely(Status.DEADLINE_EXCEEDED); + return; + } + stream.setStream(transport.newStream(method, headers, listener)); } @Override - public void flush() { - if (realStream == null) { - synchronized (this) { - if (realStream == null) { - pendingFlush = true; - return; - } - } - } - realStream.flush(); - } - - @Override - public void cancel(Status reason) { - maybeClosePrematurely(reason); - realStream.cancel(reason); - } - - @Override - public void halfClose() { - if (realStream == null) { - synchronized (this) { - if (realStream == null) { - pendingHalfClose = true; - return; - } - } - } - realStream.halfClose(); - } - - @Override - public void request(int numMessages) { - if (realStream == null) { - synchronized (this) { - if (realStream == null) { - pendingFlowControlRequests += numMessages; - return; - } - } - } - realStream.request(numMessages); - } - - @Override - public synchronized void setCompressor(Compressor c) { - compressor = c; - if (realStream != null) { - realStream.setCompressor(c); - } - } - - @Override - public synchronized void setDecompressionRegistry(DecompressorRegistry registry) { - this.decompressionRegistry = registry; - if (realStream != null) { - realStream.setDecompressionRegistry(registry); - } - } - - @Override - public boolean isReady() { - if (realStream == null) { - synchronized (this) { - if (realStream == null) { - return false; - } - } - } - return realStream.isReady(); - } - - @Override - public synchronized void setMessageCompression(boolean enable) { - if (realStream != null) { - realStream.setMessageCompression(enable); - } else { - messageCompressionEnabled = enable; - } + public void onFailure(Throwable t) { + stream.maybeClosePrematurely(Status.fromThrowable(t)); } } - - private static final ClientStream NOOP_CLIENT_STREAM = new ClientStream() { - @Override public void writeMessage(InputStream message) {} - - @Override public void flush() {} - - @Override public void cancel(Status reason) {} - - @Override public void halfClose() {} - - @Override public void request(int numMessages) {} - - @Override public void setCompressor(Compressor c) {} - - @Override - public void setMessageCompression(boolean enable) { - // noop - } - - /** - * Always returns {@code false}, since this is only used when the startup of the {@link - * ClientCall} fails (i.e. the {@link ClientCall} is closed). - */ - @Override public boolean isReady() { - return false; - } - - @Override - public void setDecompressionRegistry(DecompressorRegistry registry) {} - - @Override - public String toString() { - return "NOOP_CLIENT_STREAM"; - } - }; } - diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java new file mode 100644 index 0000000000..87fd2cf9a6 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -0,0 +1,292 @@ +/* + * Copyright 2015, Google Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package io.grpc.internal; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; + +import io.grpc.Compressor; +import io.grpc.DecompressorRegistry; +import io.grpc.Metadata; +import io.grpc.Status; + +import java.io.InputStream; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.Executor; + +import javax.annotation.concurrent.GuardedBy; + +/** + * A stream that queues requests before the transport is available, and delegates to a real stream + * implementation when the transport is available. + * + *

{@code ClientStream} itself doesn't require thread-safety. However, the state of {@code + * DelayedStream} may be internally altered by different threads, thus internal synchronization is + * necessary. + */ +class DelayedStream implements ClientStream { + private final Executor callExecutor; + + private final ClientStreamListener listener; + + private final Object lock = new Object(); + + // Volatile to be readable without synchronization in the fast path. + // Writes are also done within synchronized(this). + private volatile ClientStream realStream; + + @GuardedBy("lock") + private Compressor compressor; + // Can be either a Decompressor or a String + @GuardedBy("lock") + private Object decompressor; + @GuardedBy("lock") + private DecompressorRegistry decompressionRegistry; + @GuardedBy("lock") + private final List pendingMessages = new LinkedList(); + private boolean messageCompressionEnabled; + @GuardedBy("lock") + private boolean pendingHalfClose; + @GuardedBy("lock") + private int pendingFlowControlRequests; + @GuardedBy("lock") + private boolean pendingFlush; + + static final class PendingMessage { + final InputStream message; + final boolean shouldBeCompressed; + + public PendingMessage(InputStream message, boolean shouldBeCompressed) { + this.message = message; + this.shouldBeCompressed = shouldBeCompressed; + } + } + + DelayedStream( + ClientStreamListener listener, + Executor callExecutor) { + this.listener = listener; + this.callExecutor = callExecutor; + } + + /** + * Creates a stream on a presumably usable transport. + */ + void setStream(ClientStream stream) { + synchronized (lock) { + if (realStream == NOOP_CLIENT_STREAM) { + // Already cancelled + return; + } + checkState(realStream == null, "Stream already created: %s", realStream); + realStream = stream; + if (compressor != null) { + realStream.setCompressor(compressor); + } + if (this.decompressionRegistry != null) { + realStream.setDecompressionRegistry(this.decompressionRegistry); + } + for (PendingMessage message : pendingMessages) { + realStream.setMessageCompression(message.shouldBeCompressed); + realStream.writeMessage(message.message); + } + // Set this again, incase no messages were sent. + realStream.setMessageCompression(messageCompressionEnabled); + pendingMessages.clear(); + if (pendingHalfClose) { + realStream.halfClose(); + pendingHalfClose = false; + } + if (pendingFlowControlRequests > 0) { + realStream.request(pendingFlowControlRequests); + pendingFlowControlRequests = 0; + } + if (pendingFlush) { + realStream.flush(); + pendingFlush = false; + } + } + } + + void maybeClosePrematurely(final Status reason) { + synchronized (lock) { + if (realStream == null) { + realStream = NOOP_CLIENT_STREAM; + callExecutor.execute(new Runnable() { + @Override + public void run() { + listener.closed(reason, new Metadata()); + } + }); + } + } + } + + @Override + public void writeMessage(InputStream message) { + if (realStream == null) { + synchronized (lock) { + if (realStream == null) { + pendingMessages.add(new PendingMessage(message, messageCompressionEnabled)); + return; + } + } + } + realStream.writeMessage(message); + } + + @Override + public void flush() { + if (realStream == null) { + synchronized (lock) { + if (realStream == null) { + pendingFlush = true; + return; + } + } + } + realStream.flush(); + } + + @Override + public void cancel(Status reason) { + maybeClosePrematurely(reason); + realStream.cancel(reason); + } + + @Override + public void halfClose() { + if (realStream == null) { + synchronized (lock) { + if (realStream == null) { + pendingHalfClose = true; + return; + } + } + } + realStream.halfClose(); + } + + @Override + public void request(int numMessages) { + if (realStream == null) { + synchronized (lock) { + if (realStream == null) { + pendingFlowControlRequests += numMessages; + return; + } + } + } + realStream.request(numMessages); + } + + @Override + public void setCompressor(Compressor c) { + synchronized (lock) { + compressor = c; + if (realStream != null) { + realStream.setCompressor(c); + } + } + } + + @Override + public void setDecompressionRegistry(DecompressorRegistry registry) { + synchronized (lock) { + this.decompressionRegistry = registry; + if (realStream != null) { + realStream.setDecompressionRegistry(registry); + } + } + } + + @Override + public boolean isReady() { + if (realStream == null) { + synchronized (lock) { + if (realStream == null) { + return false; + } + } + } + return realStream.isReady(); + } + + @Override + public void setMessageCompression(boolean enable) { + synchronized (lock) { + if (realStream != null) { + realStream.setMessageCompression(enable); + } else { + messageCompressionEnabled = enable; + } + } + } + + @VisibleForTesting + static final ClientStream NOOP_CLIENT_STREAM = new ClientStream() { + @Override public void writeMessage(InputStream message) {} + + @Override public void flush() {} + + @Override public void cancel(Status reason) {} + + @Override public void halfClose() {} + + @Override public void request(int numMessages) {} + + @Override public void setCompressor(Compressor c) {} + + @Override + public void setMessageCompression(boolean enable) { + // noop + } + + /** + * Always returns {@code false}, since this is only used when the startup of the {@link + * ClientCall} fails (i.e. the {@link ClientCall} is closed). + */ + @Override public boolean isReady() { + return false; + } + + @Override + public void setDecompressionRegistry(DecompressorRegistry registry) {} + + @Override + public String toString() { + return "NOOP_CLIENT_STREAM"; + } + }; +} diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index f67bdfca4b..d6c2597c01 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -287,7 +287,7 @@ public class AbstractClientStreamTest { /** * No-op base class for testing. */ - private static class BaseClientStreamListener implements ClientStreamListener { + static class BaseClientStreamListener implements ClientStreamListener { @Override public void messageRead(InputStream message) {} diff --git a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java index b90685d494..8e27b732c2 100644 --- a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java @@ -37,9 +37,12 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.isA; -import static org.mockito.Mockito.any; +import static org.mockito.Matchers.isA; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; @@ -64,6 +67,7 @@ import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; import io.grpc.internal.ClientCallImpl.ClientTransportProvider; +import io.grpc.internal.ClientCallImpl.StreamCreationTask; import org.junit.After; import org.junit.Before; @@ -101,6 +105,16 @@ public class ClientCallImplTest { Executors.newScheduledThreadPool(0); private final DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance(); + private final MethodDescriptor method = MethodDescriptor.create( + MethodType.UNARY, + "service/method", + new TestMarshaller(), + new TestMarshaller()); + + @Mock private ClientStreamListener streamListener; + @Mock private ClientTransport clientTransport; + @Mock private DelayedStream delayedStream; + @Captor private ArgumentCaptor statusCaptor; @Mock private ClientTransport transport; @@ -121,6 +135,7 @@ public class ClientCallImplTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); + decompressorRegistry.register(new Codec.Gzip(), true); } @@ -131,11 +146,20 @@ public class ClientCallImplTest { @Test public void advertisedEncodingsAreSent() { + final ClientTransport transport = mock(ClientTransport.class); + final ClientStream stream = mock(ClientStream.class); + ClientTransportProvider provider = new ClientTransportProvider() { + @Override + public ListenableFuture get(CallOptions callOptions) { + return Futures.immediateFuture(transport); + } + }; + when(transport.newStream(any(MethodDescriptor.class), any(Metadata.class), any(ClientStreamListener.class))).thenReturn(stream); ClientCallImpl call = new ClientCallImpl( - DESCRIPTOR, + method, MoreExecutors.directExecutor(), CallOptions.DEFAULT, provider, @@ -146,7 +170,7 @@ public class ClientCallImplTest { ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); verify(transport).newStream( - eq(DESCRIPTOR), metadataCaptor.capture(), isA(ClientStreamListener.class)); + eq(method), metadataCaptor.capture(), isA(ClientStreamListener.class)); Metadata actual = metadataCaptor.getValue(); Set acceptedEncodings = @@ -424,6 +448,56 @@ public class ClientCallImplTest { } } + @Test + public void streamCreationTask_failure() { + StreamCreationTask task = new StreamCreationTask( + delayedStream, new Metadata(), method, CallOptions.DEFAULT, streamListener); + + task.onFailure(Status.CANCELLED.asException()); + + verify(delayedStream).maybeClosePrematurely(statusCaptor.capture()); + assertEquals(Status.Code.CANCELLED, statusCaptor.getValue().getCode()); + } + + @Test + public void streamCreationTask_transportShutdown() { + StreamCreationTask task = new StreamCreationTask( + delayedStream, new Metadata(), method, CallOptions.DEFAULT, streamListener); + + // null means no transport available + task.onSuccess(null); + + verify(delayedStream).maybeClosePrematurely(statusCaptor.capture()); + assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); + } + + @Test + public void streamCreationTask_deadlineExceeded() { + Metadata headers = new Metadata(); + headers.put(GrpcUtil.TIMEOUT_KEY, 1L); + CallOptions callOptions = CallOptions.DEFAULT.withDeadlineNanoTime(System.nanoTime() - 1); + StreamCreationTask task = + new StreamCreationTask(delayedStream, headers, method, callOptions, streamListener); + + task.onSuccess(clientTransport); + + verify(delayedStream).maybeClosePrematurely(statusCaptor.capture()); + assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode()); + } + + @Test + public void streamCreationTask_success() { + Metadata headers = new Metadata(); + StreamCreationTask task = + new StreamCreationTask(delayedStream, headers, method, CallOptions.DEFAULT, streamListener); + when(clientTransport.newStream(method, headers, streamListener)) + .thenReturn(DelayedStream.NOOP_CLIENT_STREAM); + + task.onSuccess(clientTransport); + + verify(clientTransport).newStream(method, headers, streamListener); + } + private static class TestMarshaller implements Marshaller { @Override public InputStream stream(T value) { diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java new file mode 100644 index 0000000000..ff7f4d5b07 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -0,0 +1,164 @@ +/* + * Copyright 2015, Google Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package io.grpc.internal; + +import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.isA; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.google.common.util.concurrent.MoreExecutors; + +import io.grpc.Codec; +import io.grpc.DecompressorRegistry; +import io.grpc.IntegerMarshaller; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.Status; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.concurrent.Executor; + +/** + * Tests for {@link DelayedStream}. Most of the state checking is enforced by + * {@link ClientCallImpl} so we don't check it here. + */ +@RunWith(JUnit4.class) +public class DelayedStreamTest { + private static final Executor executor = MoreExecutors.directExecutor(); + + @Rule public final ExpectedException thrown = ExpectedException.none(); + + @Mock private ClientStreamListener listener; + @Mock private ClientTransport transport; + @Mock private ClientStream realStream; + @Captor private ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + private DelayedStream stream; + private Metadata headers = new Metadata(); + + private MethodDescriptor method = MethodDescriptor.create( + MethodType.UNARY, "service/method", new IntegerMarshaller(), new IntegerMarshaller()); + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + + stream = new DelayedStream(listener, executor); + } + + @Test + public void setStream_sendsAllMessages() { + stream.setCompressor(Codec.Identity.NONE); + + DecompressorRegistry registry = DecompressorRegistry.newEmptyInstance(); + stream.setDecompressionRegistry(registry); + + stream.setMessageCompression(true); + InputStream message = new ByteArrayInputStream(new byte[]{'a'}); + stream.writeMessage(message); + stream.setMessageCompression(false); + stream.writeMessage(message); + + stream.setStream(realStream); + + + verify(realStream).setCompressor(Codec.Identity.NONE); + verify(realStream).setDecompressionRegistry(registry); + + // Verify that the order was correct, even though they should be interleaved with the + // writeMessage calls + verify(realStream).setMessageCompression(true); + verify(realStream, times(2)).setMessageCompression(false); + + verify(realStream, times(2)).writeMessage(message); + } + + @Test + public void setStream_halfClose() { + stream.halfClose(); + stream.setStream(realStream); + + verify(realStream).halfClose(); + } + + @Test + public void setStream_flush() { + stream.flush(); + stream.setStream(realStream); + + verify(realStream).flush(); + } + + @Test + public void setStream_flowControl() { + stream.request(1); + stream.request(2); + + stream.setStream(realStream); + + verify(realStream).request(3); + } + + @Test + public void setStream_cantCreateTwice() { + // The first call will be a success + stream.setStream(realStream); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Stream already created"); + + stream.setStream(realStream); + } + + @Test + public void streamCancelled() { + stream.cancel(Status.CANCELLED); + + // Should be a no op, and not fail due to transport not returning a newStream + stream.setStream(realStream); + + verify(listener).closed(eq(Status.CANCELLED), isA(Metadata.class)); + } +}