From eccd2311311f8c480f396fd76bf8fe9b5931a035 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Sat, 5 Mar 2016 12:19:07 -0800 Subject: [PATCH] Don't hold a lock in DelayedStream when calling realStream Our current lock ordering rules prohibit holding a lock when calling the channel and stream. This change avoids the lock for both DelayedClientTransport and DelayedStream. It is effectively a rewrite of DelayedStream. The fixes to ClientCallImpl were to ensure sane state in DelayedStream. Fixes #1510 --- .../java/io/grpc/internal/ClientCallImpl.java | 6 +- .../grpc/internal/DelayedClientTransport.java | 47 +-- .../java/io/grpc/internal/DelayedStream.java | 384 +++++++++--------- .../internal/DelayedClientTransportTest.java | 12 + .../io/grpc/internal/DelayedStreamTest.java | 100 ++++- 5 files changed, 314 insertions(+), 235 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index e5ce53ac71..26f664fbe8 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -31,6 +31,7 @@ package io.grpc.internal; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; @@ -208,11 +209,11 @@ final class ClientCallImpl extends ClientCall stream.setAuthority(callOptions.getAuthority()); } stream.setCompressor(compressor); + + stream.start(new ClientStreamListenerImpl(observer)); if (compressor != Codec.Identity.NONE) { stream.setMessageCompression(true); } - - stream.start(new ClientStreamListenerImpl(observer)); // Delay any sources of cancellation after start(), because most of the transports are broken if // they receive cancel before start. Issue #1343 has more details @@ -269,6 +270,7 @@ final class ClientCallImpl extends ClientCall @Override public void request(int numMessages) { Preconditions.checkState(stream != null, "Not started"); + checkArgument(numMessages >= 0, "Number requested must be non-negative"); stream.request(numMessages); } diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 83cfc0cc08..537023148e 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -80,44 +80,41 @@ class DelayedClientTransport implements ManagedClientTransport { @Override public ClientStream newStream(MethodDescriptor method, Metadata headers) { Supplier supplier = transportSupplier; + if (supplier == null) { + synchronized (lock) { + // Check again, since it may have changed while waiting for lock + supplier = transportSupplier; + if (supplier == null && !shutdown) { + PendingStream pendingStream = new PendingStream(method, headers); + pendingStreams.add(pendingStream); + return pendingStream; + } + } + } if (supplier != null) { return supplier.get().newStream(method, headers); } - synchronized (lock) { - // Check again, since it may have changed while waiting for lock - supplier = transportSupplier; - if (supplier != null) { - return supplier.get().newStream(method, headers); - } - if (!shutdown) { - PendingStream pendingStream = new PendingStream(method, headers); - pendingStreams.add(pendingStream); - return pendingStream; - } - } return new FailingClientStream(Status.UNAVAILABLE.withDescription("transport shutdown")); } @Override public void ping(final PingCallback callback, Executor executor) { Supplier supplier = transportSupplier; + if (supplier == null) { + synchronized (lock) { + // Check again, since it may have changed while waiting for lock + supplier = transportSupplier; + if (supplier == null && !shutdown) { + PendingPing pendingPing = new PendingPing(callback, executor); + pendingPings.add(pendingPing); + return; + } + } + } if (supplier != null) { supplier.get().ping(callback, executor); return; } - synchronized (lock) { - // Check again, since it may have changed while waiting for lock - supplier = transportSupplier; - if (supplier != null) { - supplier.get().ping(callback, executor); - return; - } - if (!shutdown) { - PendingPing pendingPing = new PendingPing(callback, executor); - pendingPings.add(pendingPing); - return; - } - } executor.execute(new Runnable() { @Override public void run() { callback.onFailure( diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index c6152936b2..bc22a6ee62 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -34,15 +34,13 @@ package io.grpc.internal; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import com.google.common.base.Preconditions; - import io.grpc.Compressor; import io.grpc.Decompressor; import io.grpc.Metadata; import io.grpc.Status; import java.io.InputStream; -import java.util.LinkedList; +import java.util.ArrayList; import java.util.List; import javax.annotation.concurrent.GuardedBy; @@ -56,263 +54,251 @@ import javax.annotation.concurrent.GuardedBy; * necessary. */ class DelayedStream implements ClientStream { - - // set to non null once both listener and realStream are valid. After this point it is safe - // to call methods on startedRealStream. Note: this may be true even after the delayed stream is - // cancelled. This should be okay. - private volatile ClientStream startedRealStream; - @GuardedBy("this") - private String authority; - @GuardedBy("this") + /** {@code true} once realStream is valid and all pending calls have been drained. */ + private volatile boolean passThrough; + /** + * Non-{@code null} iff start has been called. Used to assert methods are called in appropriate + * order, but also used if an error occurrs before {@code realStream} is set. + */ private ClientStreamListener listener; - @GuardedBy("this") + /** Must hold {@code this} lock when setting. */ private ClientStream realStream; @GuardedBy("this") private Status error; - @GuardedBy("this") - private final List pendingMessages = new LinkedList(); - private boolean messageCompressionEnabled; - @GuardedBy("this") - private boolean pendingHalfClose; - @GuardedBy("this") - private int pendingFlowControlRequests; - @GuardedBy("this") - private boolean pendingFlush; - @GuardedBy("this") - private Compressor compressor; - @GuardedBy("this") - private Decompressor decompressor; - - static final class PendingMessage { - final InputStream message; - final boolean shouldBeCompressed; - - public PendingMessage(InputStream message, boolean shouldBeCompressed) { - this.message = message; - this.shouldBeCompressed = shouldBeCompressed; - } - } - - @Override - public synchronized void setAuthority(String authority) { - checkState(listener == null, "must be called before start"); - checkNotNull(authority, "authority"); - if (realStream == null) { - this.authority = authority; - } else { - realStream.setAuthority(authority); - } - } - - @Override - public void start(ClientStreamListener listener) { - synchronized (this) { - // start may be called at most once. - checkState(this.listener == null, "already started"); - this.listener = checkNotNull(listener, "listener"); - - // Check error first rather than success. - if (error != null) { - listener.closed(error, new Metadata()); - } - // In the event that an error happened, realStream will be a noop stream. We still call - // start stream in order to drain references to pending messages. - if (realStream != null) { - startStream(); - } - } - } - - @GuardedBy("this") - private void startStream() { - checkState(realStream != null, "realStream"); - checkState(listener != null, "listener"); - if (authority != null) { - realStream.setAuthority(authority); - } - realStream.start(listener); - - if (decompressor != null) { - realStream.setDecompressor(decompressor); - } - if (compressor != null) { - realStream.setCompressor(compressor); - } - - for (PendingMessage message : pendingMessages) { - realStream.setMessageCompression(message.shouldBeCompressed); - realStream.writeMessage(message.message); - } - // Set this again, incase no messages were sent. - realStream.setMessageCompression(messageCompressionEnabled); - pendingMessages.clear(); - if (pendingHalfClose) { - realStream.halfClose(); - pendingHalfClose = false; - } - if (pendingFlowControlRequests > 0) { - realStream.request(pendingFlowControlRequests); - pendingFlowControlRequests = 0; - } - if (pendingFlush) { - realStream.flush(); - pendingFlush = false; - } - // Ensures visibility. - startedRealStream = realStream; - } + private List pendingCalls = new ArrayList(); /** * Transfers all pending and future requests and mutations to the given stream. * *

No-op if either this method or {@link #cancel} have already been called. */ + // When this method returns, passThrough is guaranteed to be true final void setStream(ClientStream stream) { synchronized (this) { - if (error != null || realStream != null) { + // If realStream != null, then either setStream() or cancel() has been called. + if (realStream != null) { return; } realStream = checkNotNull(stream, "stream"); - // listener can only be non-null if start has already been called. - if (listener != null) { - startStream(); + } + + drainPendingCalls(); + } + + /** + * Called to transition {@code passThrough} to {@code true}. This method is not safe to be called + * multiple times; the caller must ensure it will only be called once, ever. {@code this} lock + * should not be held when calling this method. + */ + private void drainPendingCalls() { + assert realStream != null; + assert !passThrough; + List toRun = new ArrayList(); + while (true) { + synchronized (this) { + if (pendingCalls.isEmpty()) { + pendingCalls = null; + passThrough = true; + break; + } + // Since there were pendingCalls, we need to process them. To maintain ordering we can't set + // passThrough=true until we run all pendingCalls, but new Runnables may be added after we + // drop the lock. So we will have to re-check pendingCalls. + List tmp = toRun; + toRun = pendingCalls; + pendingCalls = tmp; } + for (Runnable runnable : toRun) { + // Must not call transport while lock is held to prevent deadlocks. + // TODO(ejona): exception handling + runnable.run(); + } + toRun.clear(); } } - @Override - public void writeMessage(InputStream message) { - if (startedRealStream == null) { - synchronized (this) { - if (startedRealStream == null) { - pendingMessages.add(new PendingMessage(message, messageCompressionEnabled)); - return; - } + /** + * Enqueue the runnable or execute it now. Call sites that may be called many times may want avoid + * this method if {@code passThrough == true}. + * + *

Note that this method is no more thread-safe than {@code runnable}. It is thread-safe if and + * only if {@code runnable} is thread-safe. + */ + private void delayOrExecute(Runnable runnable) { + synchronized (this) { + if (!passThrough) { + pendingCalls.add(runnable); + return; } } - startedRealStream.writeMessage(message); + runnable.run(); + } + + @Override + public void setAuthority(final String authority) { + checkState(listener == null, "May only be called before start"); + checkNotNull(authority, "authority"); + delayOrExecute(new Runnable() { + @Override + public void run() { + realStream.setAuthority(authority); + } + }); + } + + @Override + public void start(final ClientStreamListener listener) { + checkState(this.listener == null, "already started"); + + Status savedError; + synchronized (this) { + this.listener = checkNotNull(listener, "listener"); + // If error != null, then cancel() has been called and was unable to close the listener + savedError = error; + } + if (savedError != null) { + listener.closed(savedError, new Metadata()); + return; + } + + delayOrExecute(new Runnable() { + @Override + public void run() { + realStream.start(listener); + } + }); + } + + @Override + public void writeMessage(final InputStream message) { + checkNotNull(message, "message"); + if (passThrough) { + realStream.writeMessage(message); + } else { + delayOrExecute(new Runnable() { + @Override + public void run() { + realStream.writeMessage(message); + } + }); + } } @Override public void flush() { - if (startedRealStream == null) { - synchronized (this) { - if (startedRealStream == null) { - pendingFlush = true; - return; + if (passThrough) { + realStream.flush(); + } else { + delayOrExecute(new Runnable() { + @Override + public void run() { + realStream.flush(); } - } + }); } - startedRealStream.flush(); } + // When this method returns, passThrough is guaranteed to be true @Override - public void cancel(Status reason) { - // At least one of them is null. - ClientStream streamToBeCancelled = startedRealStream; - ClientStreamListener listenerToBeCalled = null; - if (streamToBeCancelled == null) { - synchronized (this) { - if (realStream != null) { - // realStream already set. Just cancel it. - streamToBeCancelled = realStream; - } else if (error == null) { - // Neither realStream and error are set. Will set the error and call the listener if - // it's set. - error = checkNotNull(reason); - realStream = NoopClientStream.INSTANCE; - if (listener != null) { - // call startStream anyways to drain pending messages. - startStream(); - listenerToBeCalled = listener; - } - } // else: error already set, do nothing. + public void cancel(final Status reason) { + checkNotNull(reason, "reason"); + boolean delegateToRealStream = true; + ClientStreamListener listenerToClose = null; + synchronized (this) { + // If realStream != null, then either setStream() or cancel() has been called + if (realStream == null) { + realStream = NoopClientStream.INSTANCE; + delegateToRealStream = false; + + // If listener == null, then start() will later call listener with 'error' + listenerToClose = listener; + error = reason; } } - if (listenerToBeCalled != null) { - Preconditions.checkState(streamToBeCancelled == null, "unexpected streamToBeCancelled"); - listenerToBeCalled.closed(reason, new Metadata()); - } - if (streamToBeCancelled != null) { - streamToBeCancelled.cancel(reason); + if (delegateToRealStream) { + delayOrExecute(new Runnable() { + @Override + public void run() { + realStream.cancel(reason); + } + }); + } else { + if (listenerToClose != null) { + listenerToClose.closed(reason, new Metadata()); + } + drainPendingCalls(); } } @Override public void halfClose() { - if (startedRealStream == null) { - synchronized (this) { - if (startedRealStream == null) { - pendingHalfClose = true; - return; - } + delayOrExecute(new Runnable() { + @Override + public void run() { + realStream.halfClose(); } - } - startedRealStream.halfClose(); + }); } @Override - public void request(int numMessages) { - if (startedRealStream == null) { - synchronized (this) { - if (startedRealStream == null) { - pendingFlowControlRequests += numMessages; - return; + public void request(final int numMessages) { + if (passThrough) { + realStream.request(numMessages); + } else { + delayOrExecute(new Runnable() { + @Override + public void run() { + realStream.request(numMessages); } - } + }); } - startedRealStream.request(numMessages); } @Override - public void setCompressor(Compressor compressor) { - if (startedRealStream == null) { - synchronized (this) { - if (startedRealStream == null) { - this.compressor = compressor; - return; - } + public void setCompressor(final Compressor compressor) { + checkNotNull(compressor, "compressor"); + delayOrExecute(new Runnable() { + @Override + public void run() { + realStream.setCompressor(compressor); } - } - startedRealStream.setCompressor(compressor); + }); } @Override public void setDecompressor(Decompressor decompressor) { - if (startedRealStream == null) { - synchronized (this) { - if (startedRealStream == null) { - this.decompressor = decompressor; - return; - } - } - } - startedRealStream.setDecompressor(decompressor); + checkNotNull(decompressor, "decompressor"); + // This method being called only makes sense after setStream() has been called (but not + // necessarily returned), but there is not necessarily a happens-before relationship. This + // synchronized block creates one. + synchronized (this) { } + checkState(realStream != null, "How did we receive a reply before the request is sent?"); + // ClientStreamListenerImpl (in ClientCallImpl) requires setDecompressor to be set immediately, + // since messages may be processed immediately after this method returns. + realStream.setDecompressor(decompressor); } @Override public boolean isReady() { - if (startedRealStream == null) { - synchronized (this) { - if (startedRealStream == null) { - return false; - } - } + if (passThrough) { + return realStream.isReady(); + } else { + return false; } - return startedRealStream.isReady(); } @Override - public void setMessageCompression(boolean enable) { - if (startedRealStream == null) { - synchronized (this) { - if (startedRealStream == null) { - messageCompressionEnabled = enable; - return; + public void setMessageCompression(final boolean enable) { + if (passThrough) { + realStream.setMessageCompression(enable); + } else { + delayOrExecute(new Runnable() { + @Override + public void run() { + realStream.setMessageCompression(enable); } - } + }); } - startedRealStream.setMessageCompression(enable); } } diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index de5141d2fc..d9d5af16b4 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -219,6 +219,18 @@ public class DelayedClientTransportTest { verify(mockRealTransport).ping(same(pingCallback), same(mockExecutor)); } + @Test public void shutdownThenPing() { + delayedTransport.shutdown(); + verify(transportListener).transportShutdown(any(Status.class)); + verify(transportListener).transportTerminated(); + delayedTransport.ping(pingCallback, mockExecutor); + verifyNoMoreInteractions(pingCallback); + ArgumentCaptor runnableCaptor = ArgumentCaptor.forClass(Runnable.class); + verify(mockExecutor).execute(runnableCaptor.capture()); + runnableCaptor.getValue().run(); + verify(pingCallback).onFailure(any(Throwable.class)); + } + @Test public void shutdownThenNewStream() { delayedTransport.shutdown(); verify(transportListener).transportShutdown(any(Status.class)); diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index 229a643f55..8c390364d3 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -31,12 +31,18 @@ package io.grpc.internal; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; -import static org.mockito.Matchers.isA; import static org.mockito.Matchers.same; import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; import io.grpc.Codec; import io.grpc.Metadata; @@ -83,11 +89,28 @@ public class DelayedStreamTest { inOrder.verify(realStream).start(listener); } + @Test(expected = IllegalStateException.class) + public void setAuthority_afterStart() { + stream.start(listener); + stream.setAuthority("notgonnawork"); + } + + @Test(expected = IllegalStateException.class) + public void start_afterStart() { + stream.start(listener); + stream.start(mock(ClientStreamListener.class)); + } + + @Test(expected = IllegalStateException.class) + public void setDecompressor_beforeSetStream() { + stream.start(listener); + stream.setDecompressor(Codec.Identity.NONE); + } + @Test public void setStream_sendsAllMessages() { stream.start(listener); stream.setCompressor(Codec.Identity.NONE); - stream.setDecompressor(Codec.Identity.NONE); stream.setMessageCompression(true); InputStream message = new ByteArrayInputStream(new byte[]{'a'}); @@ -96,17 +119,19 @@ public class DelayedStreamTest { stream.writeMessage(message); stream.setStream(realStream); + stream.setDecompressor(Codec.Identity.NONE); verify(realStream).setCompressor(Codec.Identity.NONE); verify(realStream).setDecompressor(Codec.Identity.NONE); - // Verify that the order was correct, even though they should be interleaved with the - // writeMessage calls verify(realStream).setMessageCompression(true); - verify(realStream, times(2)).setMessageCompression(false); + verify(realStream).setMessageCompression(false); verify(realStream, times(2)).writeMessage(message); verify(realStream).start(listener); + + stream.writeMessage(message); + verify(realStream, times(3)).writeMessage(message); } @Test @@ -123,8 +148,10 @@ public class DelayedStreamTest { stream.start(listener); stream.flush(); stream.setStream(realStream); - verify(realStream).flush(); + + stream.flush(); + verify(realStream, times(2)).flush(); } @Test @@ -132,17 +159,45 @@ public class DelayedStreamTest { stream.start(listener); stream.request(1); stream.request(2); - stream.setStream(realStream); + verify(realStream).request(1); + verify(realStream).request(2); + stream.request(3); verify(realStream).request(3); } + @Test + public void setStream_setMessageCompression() { + stream.start(listener); + stream.setMessageCompression(false); + stream.setStream(realStream); + verify(realStream).setMessageCompression(false); + + stream.setMessageCompression(true); + verify(realStream).setMessageCompression(true); + } + + @Test + public void setStream_isReady() { + stream.start(listener); + assertFalse(stream.isReady()); + stream.setStream(realStream); + verify(realStream, never()).isReady(); + + assertFalse(stream.isReady()); + verify(realStream).isReady(); + + when(realStream.isReady()).thenReturn(true); + assertTrue(stream.isReady()); + verify(realStream, times(2)).isReady(); + } + @Test public void startThenCancelled() { stream.start(listener); stream.cancel(Status.CANCELLED); - verify(listener).closed(eq(Status.CANCELLED), isA(Metadata.class)); + verify(listener).closed(eq(Status.CANCELLED), any(Metadata.class)); } @Test @@ -170,10 +225,37 @@ public class DelayedStreamTest { verify(realStream).cancel(same(Status.CANCELLED)); } + @Test + public void setStreamTwice() { + stream.start(listener); + stream.setStream(realStream); + verify(realStream).start(listener); + stream.setStream(mock(ClientStream.class)); + stream.flush(); + verify(realStream).flush(); + } + + @Test + public void cancelThenSetStream() { + stream.cancel(Status.CANCELLED); + stream.setStream(realStream); + stream.start(listener); + stream.isReady(); + verifyNoMoreInteractions(realStream); + } + + @Test + public void cancel_beforeStart() { + Status status = Status.CANCELLED.withDescription("that was quick"); + stream.cancel(status); + stream.start(listener); + verify(listener).closed(same(status), any(Metadata.class)); + } + @Test public void cancelledThenStart() { stream.cancel(Status.CANCELLED); stream.start(listener); - verify(listener).closed(eq(Status.CANCELLED), isA(Metadata.class)); + verify(listener).closed(eq(Status.CANCELLED), any(Metadata.class)); } }