Don't hold a lock in DelayedStream when calling realStream

Our current lock ordering rules prohibit holding a lock when calling the
channel and stream. This change avoids the lock for both
DelayedClientTransport and DelayedStream. It is effectively a rewrite of
DelayedStream.

The fixes to ClientCallImpl were to ensure sane state in DelayedStream.

Fixes #1510
This commit is contained in:
Eric Anderson 2016-03-05 12:19:07 -08:00
parent b9c12327eb
commit eccd231131
5 changed files with 314 additions and 235 deletions

View File

@ -31,6 +31,7 @@
package io.grpc.internal;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
@ -208,11 +209,11 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
stream.setAuthority(callOptions.getAuthority());
}
stream.setCompressor(compressor);
stream.start(new ClientStreamListenerImpl(observer));
if (compressor != Codec.Identity.NONE) {
stream.setMessageCompression(true);
}
stream.start(new ClientStreamListenerImpl(observer));
// Delay any sources of cancellation after start(), because most of the transports are broken if
// they receive cancel before start. Issue #1343 has more details
@ -269,6 +270,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
@Override
public void request(int numMessages) {
Preconditions.checkState(stream != null, "Not started");
checkArgument(numMessages >= 0, "Number requested must be non-negative");
stream.request(numMessages);
}

View File

@ -80,44 +80,41 @@ class DelayedClientTransport implements ManagedClientTransport {
@Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) {
Supplier<ClientTransport> supplier = transportSupplier;
if (supplier == null) {
synchronized (lock) {
// Check again, since it may have changed while waiting for lock
supplier = transportSupplier;
if (supplier == null && !shutdown) {
PendingStream pendingStream = new PendingStream(method, headers);
pendingStreams.add(pendingStream);
return pendingStream;
}
}
}
if (supplier != null) {
return supplier.get().newStream(method, headers);
}
synchronized (lock) {
// Check again, since it may have changed while waiting for lock
supplier = transportSupplier;
if (supplier != null) {
return supplier.get().newStream(method, headers);
}
if (!shutdown) {
PendingStream pendingStream = new PendingStream(method, headers);
pendingStreams.add(pendingStream);
return pendingStream;
}
}
return new FailingClientStream(Status.UNAVAILABLE.withDescription("transport shutdown"));
}
@Override
public void ping(final PingCallback callback, Executor executor) {
Supplier<ClientTransport> supplier = transportSupplier;
if (supplier == null) {
synchronized (lock) {
// Check again, since it may have changed while waiting for lock
supplier = transportSupplier;
if (supplier == null && !shutdown) {
PendingPing pendingPing = new PendingPing(callback, executor);
pendingPings.add(pendingPing);
return;
}
}
}
if (supplier != null) {
supplier.get().ping(callback, executor);
return;
}
synchronized (lock) {
// Check again, since it may have changed while waiting for lock
supplier = transportSupplier;
if (supplier != null) {
supplier.get().ping(callback, executor);
return;
}
if (!shutdown) {
PendingPing pendingPing = new PendingPing(callback, executor);
pendingPings.add(pendingPing);
return;
}
}
executor.execute(new Runnable() {
@Override public void run() {
callback.onFailure(

View File

@ -34,15 +34,13 @@ package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.base.Preconditions;
import io.grpc.Compressor;
import io.grpc.Decompressor;
import io.grpc.Metadata;
import io.grpc.Status;
import java.io.InputStream;
import java.util.LinkedList;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.concurrent.GuardedBy;
@ -56,263 +54,251 @@ import javax.annotation.concurrent.GuardedBy;
* necessary.
*/
class DelayedStream implements ClientStream {
// set to non null once both listener and realStream are valid. After this point it is safe
// to call methods on startedRealStream. Note: this may be true even after the delayed stream is
// cancelled. This should be okay.
private volatile ClientStream startedRealStream;
@GuardedBy("this")
private String authority;
@GuardedBy("this")
/** {@code true} once realStream is valid and all pending calls have been drained. */
private volatile boolean passThrough;
/**
* Non-{@code null} iff start has been called. Used to assert methods are called in appropriate
* order, but also used if an error occurrs before {@code realStream} is set.
*/
private ClientStreamListener listener;
@GuardedBy("this")
/** Must hold {@code this} lock when setting. */
private ClientStream realStream;
@GuardedBy("this")
private Status error;
@GuardedBy("this")
private final List<PendingMessage> pendingMessages = new LinkedList<PendingMessage>();
private boolean messageCompressionEnabled;
@GuardedBy("this")
private boolean pendingHalfClose;
@GuardedBy("this")
private int pendingFlowControlRequests;
@GuardedBy("this")
private boolean pendingFlush;
@GuardedBy("this")
private Compressor compressor;
@GuardedBy("this")
private Decompressor decompressor;
static final class PendingMessage {
final InputStream message;
final boolean shouldBeCompressed;
public PendingMessage(InputStream message, boolean shouldBeCompressed) {
this.message = message;
this.shouldBeCompressed = shouldBeCompressed;
}
}
@Override
public synchronized void setAuthority(String authority) {
checkState(listener == null, "must be called before start");
checkNotNull(authority, "authority");
if (realStream == null) {
this.authority = authority;
} else {
realStream.setAuthority(authority);
}
}
@Override
public void start(ClientStreamListener listener) {
synchronized (this) {
// start may be called at most once.
checkState(this.listener == null, "already started");
this.listener = checkNotNull(listener, "listener");
// Check error first rather than success.
if (error != null) {
listener.closed(error, new Metadata());
}
// In the event that an error happened, realStream will be a noop stream. We still call
// start stream in order to drain references to pending messages.
if (realStream != null) {
startStream();
}
}
}
@GuardedBy("this")
private void startStream() {
checkState(realStream != null, "realStream");
checkState(listener != null, "listener");
if (authority != null) {
realStream.setAuthority(authority);
}
realStream.start(listener);
if (decompressor != null) {
realStream.setDecompressor(decompressor);
}
if (compressor != null) {
realStream.setCompressor(compressor);
}
for (PendingMessage message : pendingMessages) {
realStream.setMessageCompression(message.shouldBeCompressed);
realStream.writeMessage(message.message);
}
// Set this again, incase no messages were sent.
realStream.setMessageCompression(messageCompressionEnabled);
pendingMessages.clear();
if (pendingHalfClose) {
realStream.halfClose();
pendingHalfClose = false;
}
if (pendingFlowControlRequests > 0) {
realStream.request(pendingFlowControlRequests);
pendingFlowControlRequests = 0;
}
if (pendingFlush) {
realStream.flush();
pendingFlush = false;
}
// Ensures visibility.
startedRealStream = realStream;
}
private List<Runnable> pendingCalls = new ArrayList<Runnable>();
/**
* Transfers all pending and future requests and mutations to the given stream.
*
* <p>No-op if either this method or {@link #cancel} have already been called.
*/
// When this method returns, passThrough is guaranteed to be true
final void setStream(ClientStream stream) {
synchronized (this) {
if (error != null || realStream != null) {
// If realStream != null, then either setStream() or cancel() has been called.
if (realStream != null) {
return;
}
realStream = checkNotNull(stream, "stream");
// listener can only be non-null if start has already been called.
if (listener != null) {
startStream();
}
drainPendingCalls();
}
/**
* Called to transition {@code passThrough} to {@code true}. This method is not safe to be called
* multiple times; the caller must ensure it will only be called once, ever. {@code this} lock
* should not be held when calling this method.
*/
private void drainPendingCalls() {
assert realStream != null;
assert !passThrough;
List<Runnable> toRun = new ArrayList<Runnable>();
while (true) {
synchronized (this) {
if (pendingCalls.isEmpty()) {
pendingCalls = null;
passThrough = true;
break;
}
// Since there were pendingCalls, we need to process them. To maintain ordering we can't set
// passThrough=true until we run all pendingCalls, but new Runnables may be added after we
// drop the lock. So we will have to re-check pendingCalls.
List<Runnable> tmp = toRun;
toRun = pendingCalls;
pendingCalls = tmp;
}
for (Runnable runnable : toRun) {
// Must not call transport while lock is held to prevent deadlocks.
// TODO(ejona): exception handling
runnable.run();
}
toRun.clear();
}
}
@Override
public void writeMessage(InputStream message) {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
pendingMessages.add(new PendingMessage(message, messageCompressionEnabled));
return;
}
/**
* Enqueue the runnable or execute it now. Call sites that may be called many times may want avoid
* this method if {@code passThrough == true}.
*
* <p>Note that this method is no more thread-safe than {@code runnable}. It is thread-safe if and
* only if {@code runnable} is thread-safe.
*/
private void delayOrExecute(Runnable runnable) {
synchronized (this) {
if (!passThrough) {
pendingCalls.add(runnable);
return;
}
}
startedRealStream.writeMessage(message);
runnable.run();
}
@Override
public void setAuthority(final String authority) {
checkState(listener == null, "May only be called before start");
checkNotNull(authority, "authority");
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.setAuthority(authority);
}
});
}
@Override
public void start(final ClientStreamListener listener) {
checkState(this.listener == null, "already started");
Status savedError;
synchronized (this) {
this.listener = checkNotNull(listener, "listener");
// If error != null, then cancel() has been called and was unable to close the listener
savedError = error;
}
if (savedError != null) {
listener.closed(savedError, new Metadata());
return;
}
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.start(listener);
}
});
}
@Override
public void writeMessage(final InputStream message) {
checkNotNull(message, "message");
if (passThrough) {
realStream.writeMessage(message);
} else {
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.writeMessage(message);
}
});
}
}
@Override
public void flush() {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
pendingFlush = true;
return;
if (passThrough) {
realStream.flush();
} else {
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.flush();
}
}
});
}
startedRealStream.flush();
}
// When this method returns, passThrough is guaranteed to be true
@Override
public void cancel(Status reason) {
// At least one of them is null.
ClientStream streamToBeCancelled = startedRealStream;
ClientStreamListener listenerToBeCalled = null;
if (streamToBeCancelled == null) {
synchronized (this) {
if (realStream != null) {
// realStream already set. Just cancel it.
streamToBeCancelled = realStream;
} else if (error == null) {
// Neither realStream and error are set. Will set the error and call the listener if
// it's set.
error = checkNotNull(reason);
realStream = NoopClientStream.INSTANCE;
if (listener != null) {
// call startStream anyways to drain pending messages.
startStream();
listenerToBeCalled = listener;
}
} // else: error already set, do nothing.
public void cancel(final Status reason) {
checkNotNull(reason, "reason");
boolean delegateToRealStream = true;
ClientStreamListener listenerToClose = null;
synchronized (this) {
// If realStream != null, then either setStream() or cancel() has been called
if (realStream == null) {
realStream = NoopClientStream.INSTANCE;
delegateToRealStream = false;
// If listener == null, then start() will later call listener with 'error'
listenerToClose = listener;
error = reason;
}
}
if (listenerToBeCalled != null) {
Preconditions.checkState(streamToBeCancelled == null, "unexpected streamToBeCancelled");
listenerToBeCalled.closed(reason, new Metadata());
}
if (streamToBeCancelled != null) {
streamToBeCancelled.cancel(reason);
if (delegateToRealStream) {
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.cancel(reason);
}
});
} else {
if (listenerToClose != null) {
listenerToClose.closed(reason, new Metadata());
}
drainPendingCalls();
}
}
@Override
public void halfClose() {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
pendingHalfClose = true;
return;
}
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.halfClose();
}
}
startedRealStream.halfClose();
});
}
@Override
public void request(int numMessages) {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
pendingFlowControlRequests += numMessages;
return;
public void request(final int numMessages) {
if (passThrough) {
realStream.request(numMessages);
} else {
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.request(numMessages);
}
}
});
}
startedRealStream.request(numMessages);
}
@Override
public void setCompressor(Compressor compressor) {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
this.compressor = compressor;
return;
}
public void setCompressor(final Compressor compressor) {
checkNotNull(compressor, "compressor");
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.setCompressor(compressor);
}
}
startedRealStream.setCompressor(compressor);
});
}
@Override
public void setDecompressor(Decompressor decompressor) {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
this.decompressor = decompressor;
return;
}
}
}
startedRealStream.setDecompressor(decompressor);
checkNotNull(decompressor, "decompressor");
// This method being called only makes sense after setStream() has been called (but not
// necessarily returned), but there is not necessarily a happens-before relationship. This
// synchronized block creates one.
synchronized (this) { }
checkState(realStream != null, "How did we receive a reply before the request is sent?");
// ClientStreamListenerImpl (in ClientCallImpl) requires setDecompressor to be set immediately,
// since messages may be processed immediately after this method returns.
realStream.setDecompressor(decompressor);
}
@Override
public boolean isReady() {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
return false;
}
}
if (passThrough) {
return realStream.isReady();
} else {
return false;
}
return startedRealStream.isReady();
}
@Override
public void setMessageCompression(boolean enable) {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
messageCompressionEnabled = enable;
return;
public void setMessageCompression(final boolean enable) {
if (passThrough) {
realStream.setMessageCompression(enable);
} else {
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.setMessageCompression(enable);
}
}
});
}
startedRealStream.setMessageCompression(enable);
}
}

View File

@ -219,6 +219,18 @@ public class DelayedClientTransportTest {
verify(mockRealTransport).ping(same(pingCallback), same(mockExecutor));
}
@Test public void shutdownThenPing() {
delayedTransport.shutdown();
verify(transportListener).transportShutdown(any(Status.class));
verify(transportListener).transportTerminated();
delayedTransport.ping(pingCallback, mockExecutor);
verifyNoMoreInteractions(pingCallback);
ArgumentCaptor<Runnable> runnableCaptor = ArgumentCaptor.forClass(Runnable.class);
verify(mockExecutor).execute(runnableCaptor.capture());
runnableCaptor.getValue().run();
verify(pingCallback).onFailure(any(Throwable.class));
}
@Test public void shutdownThenNewStream() {
delayedTransport.shutdown();
verify(transportListener).transportShutdown(any(Status.class));

View File

@ -31,12 +31,18 @@
package io.grpc.internal;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import io.grpc.Codec;
import io.grpc.Metadata;
@ -83,11 +89,28 @@ public class DelayedStreamTest {
inOrder.verify(realStream).start(listener);
}
@Test(expected = IllegalStateException.class)
public void setAuthority_afterStart() {
stream.start(listener);
stream.setAuthority("notgonnawork");
}
@Test(expected = IllegalStateException.class)
public void start_afterStart() {
stream.start(listener);
stream.start(mock(ClientStreamListener.class));
}
@Test(expected = IllegalStateException.class)
public void setDecompressor_beforeSetStream() {
stream.start(listener);
stream.setDecompressor(Codec.Identity.NONE);
}
@Test
public void setStream_sendsAllMessages() {
stream.start(listener);
stream.setCompressor(Codec.Identity.NONE);
stream.setDecompressor(Codec.Identity.NONE);
stream.setMessageCompression(true);
InputStream message = new ByteArrayInputStream(new byte[]{'a'});
@ -96,17 +119,19 @@ public class DelayedStreamTest {
stream.writeMessage(message);
stream.setStream(realStream);
stream.setDecompressor(Codec.Identity.NONE);
verify(realStream).setCompressor(Codec.Identity.NONE);
verify(realStream).setDecompressor(Codec.Identity.NONE);
// Verify that the order was correct, even though they should be interleaved with the
// writeMessage calls
verify(realStream).setMessageCompression(true);
verify(realStream, times(2)).setMessageCompression(false);
verify(realStream).setMessageCompression(false);
verify(realStream, times(2)).writeMessage(message);
verify(realStream).start(listener);
stream.writeMessage(message);
verify(realStream, times(3)).writeMessage(message);
}
@Test
@ -123,8 +148,10 @@ public class DelayedStreamTest {
stream.start(listener);
stream.flush();
stream.setStream(realStream);
verify(realStream).flush();
stream.flush();
verify(realStream, times(2)).flush();
}
@Test
@ -132,17 +159,45 @@ public class DelayedStreamTest {
stream.start(listener);
stream.request(1);
stream.request(2);
stream.setStream(realStream);
verify(realStream).request(1);
verify(realStream).request(2);
stream.request(3);
verify(realStream).request(3);
}
@Test
public void setStream_setMessageCompression() {
stream.start(listener);
stream.setMessageCompression(false);
stream.setStream(realStream);
verify(realStream).setMessageCompression(false);
stream.setMessageCompression(true);
verify(realStream).setMessageCompression(true);
}
@Test
public void setStream_isReady() {
stream.start(listener);
assertFalse(stream.isReady());
stream.setStream(realStream);
verify(realStream, never()).isReady();
assertFalse(stream.isReady());
verify(realStream).isReady();
when(realStream.isReady()).thenReturn(true);
assertTrue(stream.isReady());
verify(realStream, times(2)).isReady();
}
@Test
public void startThenCancelled() {
stream.start(listener);
stream.cancel(Status.CANCELLED);
verify(listener).closed(eq(Status.CANCELLED), isA(Metadata.class));
verify(listener).closed(eq(Status.CANCELLED), any(Metadata.class));
}
@Test
@ -170,10 +225,37 @@ public class DelayedStreamTest {
verify(realStream).cancel(same(Status.CANCELLED));
}
@Test
public void setStreamTwice() {
stream.start(listener);
stream.setStream(realStream);
verify(realStream).start(listener);
stream.setStream(mock(ClientStream.class));
stream.flush();
verify(realStream).flush();
}
@Test
public void cancelThenSetStream() {
stream.cancel(Status.CANCELLED);
stream.setStream(realStream);
stream.start(listener);
stream.isReady();
verifyNoMoreInteractions(realStream);
}
@Test
public void cancel_beforeStart() {
Status status = Status.CANCELLED.withDescription("that was quick");
stream.cancel(status);
stream.start(listener);
verify(listener).closed(same(status), any(Metadata.class));
}
@Test
public void cancelledThenStart() {
stream.cancel(Status.CANCELLED);
stream.start(listener);
verify(listener).closed(eq(Status.CANCELLED), isA(Metadata.class));
verify(listener).closed(eq(Status.CANCELLED), any(Metadata.class));
}
}