Don't hold a lock in DelayedStream when calling realStream

Our current lock ordering rules prohibit holding a lock when calling the
channel and stream. This change avoids the lock for both
DelayedClientTransport and DelayedStream. It is effectively a rewrite of
DelayedStream.

The fixes to ClientCallImpl were to ensure sane state in DelayedStream.

Fixes #1510
This commit is contained in:
Eric Anderson 2016-03-05 12:19:07 -08:00
parent b9c12327eb
commit eccd231131
5 changed files with 314 additions and 235 deletions

View File

@ -31,6 +31,7 @@
package io.grpc.internal; package io.grpc.internal;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
@ -208,11 +209,11 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
stream.setAuthority(callOptions.getAuthority()); stream.setAuthority(callOptions.getAuthority());
} }
stream.setCompressor(compressor); stream.setCompressor(compressor);
stream.start(new ClientStreamListenerImpl(observer));
if (compressor != Codec.Identity.NONE) { if (compressor != Codec.Identity.NONE) {
stream.setMessageCompression(true); stream.setMessageCompression(true);
} }
stream.start(new ClientStreamListenerImpl(observer));
// Delay any sources of cancellation after start(), because most of the transports are broken if // Delay any sources of cancellation after start(), because most of the transports are broken if
// they receive cancel before start. Issue #1343 has more details // they receive cancel before start. Issue #1343 has more details
@ -269,6 +270,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
@Override @Override
public void request(int numMessages) { public void request(int numMessages) {
Preconditions.checkState(stream != null, "Not started"); Preconditions.checkState(stream != null, "Not started");
checkArgument(numMessages >= 0, "Number requested must be non-negative");
stream.request(numMessages); stream.request(numMessages);
} }

View File

@ -80,44 +80,41 @@ class DelayedClientTransport implements ManagedClientTransport {
@Override @Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) { public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) {
Supplier<ClientTransport> supplier = transportSupplier; Supplier<ClientTransport> supplier = transportSupplier;
if (supplier != null) { if (supplier == null) {
return supplier.get().newStream(method, headers);
}
synchronized (lock) { synchronized (lock) {
// Check again, since it may have changed while waiting for lock // Check again, since it may have changed while waiting for lock
supplier = transportSupplier; supplier = transportSupplier;
if (supplier != null) { if (supplier == null && !shutdown) {
return supplier.get().newStream(method, headers);
}
if (!shutdown) {
PendingStream pendingStream = new PendingStream(method, headers); PendingStream pendingStream = new PendingStream(method, headers);
pendingStreams.add(pendingStream); pendingStreams.add(pendingStream);
return pendingStream; return pendingStream;
} }
} }
}
if (supplier != null) {
return supplier.get().newStream(method, headers);
}
return new FailingClientStream(Status.UNAVAILABLE.withDescription("transport shutdown")); return new FailingClientStream(Status.UNAVAILABLE.withDescription("transport shutdown"));
} }
@Override @Override
public void ping(final PingCallback callback, Executor executor) { public void ping(final PingCallback callback, Executor executor) {
Supplier<ClientTransport> supplier = transportSupplier; Supplier<ClientTransport> supplier = transportSupplier;
if (supplier != null) { if (supplier == null) {
supplier.get().ping(callback, executor);
return;
}
synchronized (lock) { synchronized (lock) {
// Check again, since it may have changed while waiting for lock // Check again, since it may have changed while waiting for lock
supplier = transportSupplier; supplier = transportSupplier;
if (supplier != null) { if (supplier == null && !shutdown) {
supplier.get().ping(callback, executor);
return;
}
if (!shutdown) {
PendingPing pendingPing = new PendingPing(callback, executor); PendingPing pendingPing = new PendingPing(callback, executor);
pendingPings.add(pendingPing); pendingPings.add(pendingPing);
return; return;
} }
} }
}
if (supplier != null) {
supplier.get().ping(callback, executor);
return;
}
executor.execute(new Runnable() { executor.execute(new Runnable() {
@Override public void run() { @Override public void run() {
callback.onFailure( callback.onFailure(

View File

@ -34,15 +34,13 @@ package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import com.google.common.base.Preconditions;
import io.grpc.Compressor; import io.grpc.Compressor;
import io.grpc.Decompressor; import io.grpc.Decompressor;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import java.io.InputStream; import java.io.InputStream;
import java.util.LinkedList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.GuardedBy;
@ -56,263 +54,251 @@ import javax.annotation.concurrent.GuardedBy;
* necessary. * necessary.
*/ */
class DelayedStream implements ClientStream { class DelayedStream implements ClientStream {
/** {@code true} once realStream is valid and all pending calls have been drained. */
// set to non null once both listener and realStream are valid. After this point it is safe private volatile boolean passThrough;
// to call methods on startedRealStream. Note: this may be true even after the delayed stream is /**
// cancelled. This should be okay. * Non-{@code null} iff start has been called. Used to assert methods are called in appropriate
private volatile ClientStream startedRealStream; * order, but also used if an error occurrs before {@code realStream} is set.
@GuardedBy("this") */
private String authority;
@GuardedBy("this")
private ClientStreamListener listener; private ClientStreamListener listener;
@GuardedBy("this") /** Must hold {@code this} lock when setting. */
private ClientStream realStream; private ClientStream realStream;
@GuardedBy("this") @GuardedBy("this")
private Status error; private Status error;
@GuardedBy("this") @GuardedBy("this")
private final List<PendingMessage> pendingMessages = new LinkedList<PendingMessage>(); private List<Runnable> pendingCalls = new ArrayList<Runnable>();
private boolean messageCompressionEnabled;
@GuardedBy("this")
private boolean pendingHalfClose;
@GuardedBy("this")
private int pendingFlowControlRequests;
@GuardedBy("this")
private boolean pendingFlush;
@GuardedBy("this")
private Compressor compressor;
@GuardedBy("this")
private Decompressor decompressor;
static final class PendingMessage {
final InputStream message;
final boolean shouldBeCompressed;
public PendingMessage(InputStream message, boolean shouldBeCompressed) {
this.message = message;
this.shouldBeCompressed = shouldBeCompressed;
}
}
@Override
public synchronized void setAuthority(String authority) {
checkState(listener == null, "must be called before start");
checkNotNull(authority, "authority");
if (realStream == null) {
this.authority = authority;
} else {
realStream.setAuthority(authority);
}
}
@Override
public void start(ClientStreamListener listener) {
synchronized (this) {
// start may be called at most once.
checkState(this.listener == null, "already started");
this.listener = checkNotNull(listener, "listener");
// Check error first rather than success.
if (error != null) {
listener.closed(error, new Metadata());
}
// In the event that an error happened, realStream will be a noop stream. We still call
// start stream in order to drain references to pending messages.
if (realStream != null) {
startStream();
}
}
}
@GuardedBy("this")
private void startStream() {
checkState(realStream != null, "realStream");
checkState(listener != null, "listener");
if (authority != null) {
realStream.setAuthority(authority);
}
realStream.start(listener);
if (decompressor != null) {
realStream.setDecompressor(decompressor);
}
if (compressor != null) {
realStream.setCompressor(compressor);
}
for (PendingMessage message : pendingMessages) {
realStream.setMessageCompression(message.shouldBeCompressed);
realStream.writeMessage(message.message);
}
// Set this again, incase no messages were sent.
realStream.setMessageCompression(messageCompressionEnabled);
pendingMessages.clear();
if (pendingHalfClose) {
realStream.halfClose();
pendingHalfClose = false;
}
if (pendingFlowControlRequests > 0) {
realStream.request(pendingFlowControlRequests);
pendingFlowControlRequests = 0;
}
if (pendingFlush) {
realStream.flush();
pendingFlush = false;
}
// Ensures visibility.
startedRealStream = realStream;
}
/** /**
* Transfers all pending and future requests and mutations to the given stream. * Transfers all pending and future requests and mutations to the given stream.
* *
* <p>No-op if either this method or {@link #cancel} have already been called. * <p>No-op if either this method or {@link #cancel} have already been called.
*/ */
// When this method returns, passThrough is guaranteed to be true
final void setStream(ClientStream stream) { final void setStream(ClientStream stream) {
synchronized (this) { synchronized (this) {
if (error != null || realStream != null) { // If realStream != null, then either setStream() or cancel() has been called.
if (realStream != null) {
return; return;
} }
realStream = checkNotNull(stream, "stream"); realStream = checkNotNull(stream, "stream");
// listener can only be non-null if start has already been called.
if (listener != null) {
startStream();
} }
drainPendingCalls();
}
/**
* Called to transition {@code passThrough} to {@code true}. This method is not safe to be called
* multiple times; the caller must ensure it will only be called once, ever. {@code this} lock
* should not be held when calling this method.
*/
private void drainPendingCalls() {
assert realStream != null;
assert !passThrough;
List<Runnable> toRun = new ArrayList<Runnable>();
while (true) {
synchronized (this) {
if (pendingCalls.isEmpty()) {
pendingCalls = null;
passThrough = true;
break;
}
// Since there were pendingCalls, we need to process them. To maintain ordering we can't set
// passThrough=true until we run all pendingCalls, but new Runnables may be added after we
// drop the lock. So we will have to re-check pendingCalls.
List<Runnable> tmp = toRun;
toRun = pendingCalls;
pendingCalls = tmp;
}
for (Runnable runnable : toRun) {
// Must not call transport while lock is held to prevent deadlocks.
// TODO(ejona): exception handling
runnable.run();
}
toRun.clear();
} }
} }
@Override /**
public void writeMessage(InputStream message) { * Enqueue the runnable or execute it now. Call sites that may be called many times may want avoid
if (startedRealStream == null) { * this method if {@code passThrough == true}.
*
* <p>Note that this method is no more thread-safe than {@code runnable}. It is thread-safe if and
* only if {@code runnable} is thread-safe.
*/
private void delayOrExecute(Runnable runnable) {
synchronized (this) { synchronized (this) {
if (startedRealStream == null) { if (!passThrough) {
pendingMessages.add(new PendingMessage(message, messageCompressionEnabled)); pendingCalls.add(runnable);
return; return;
} }
} }
runnable.run();
}
@Override
public void setAuthority(final String authority) {
checkState(listener == null, "May only be called before start");
checkNotNull(authority, "authority");
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.setAuthority(authority);
}
});
}
@Override
public void start(final ClientStreamListener listener) {
checkState(this.listener == null, "already started");
Status savedError;
synchronized (this) {
this.listener = checkNotNull(listener, "listener");
// If error != null, then cancel() has been called and was unable to close the listener
savedError = error;
}
if (savedError != null) {
listener.closed(savedError, new Metadata());
return;
}
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.start(listener);
}
});
}
@Override
public void writeMessage(final InputStream message) {
checkNotNull(message, "message");
if (passThrough) {
realStream.writeMessage(message);
} else {
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.writeMessage(message);
}
});
} }
startedRealStream.writeMessage(message);
} }
@Override @Override
public void flush() { public void flush() {
if (startedRealStream == null) { if (passThrough) {
synchronized (this) { realStream.flush();
if (startedRealStream == null) { } else {
pendingFlush = true; delayOrExecute(new Runnable() {
return; @Override
public void run() {
realStream.flush();
} }
});
} }
} }
startedRealStream.flush();
}
// When this method returns, passThrough is guaranteed to be true
@Override @Override
public void cancel(Status reason) { public void cancel(final Status reason) {
// At least one of them is null. checkNotNull(reason, "reason");
ClientStream streamToBeCancelled = startedRealStream; boolean delegateToRealStream = true;
ClientStreamListener listenerToBeCalled = null; ClientStreamListener listenerToClose = null;
if (streamToBeCancelled == null) {
synchronized (this) { synchronized (this) {
if (realStream != null) { // If realStream != null, then either setStream() or cancel() has been called
// realStream already set. Just cancel it. if (realStream == null) {
streamToBeCancelled = realStream;
} else if (error == null) {
// Neither realStream and error are set. Will set the error and call the listener if
// it's set.
error = checkNotNull(reason);
realStream = NoopClientStream.INSTANCE; realStream = NoopClientStream.INSTANCE;
if (listener != null) { delegateToRealStream = false;
// call startStream anyways to drain pending messages.
startStream(); // If listener == null, then start() will later call listener with 'error'
listenerToBeCalled = listener; listenerToClose = listener;
} error = reason;
} // else: error already set, do nothing.
} }
} }
if (listenerToBeCalled != null) { if (delegateToRealStream) {
Preconditions.checkState(streamToBeCancelled == null, "unexpected streamToBeCancelled"); delayOrExecute(new Runnable() {
listenerToBeCalled.closed(reason, new Metadata()); @Override
public void run() {
realStream.cancel(reason);
} }
if (streamToBeCancelled != null) { });
streamToBeCancelled.cancel(reason); } else {
if (listenerToClose != null) {
listenerToClose.closed(reason, new Metadata());
}
drainPendingCalls();
} }
} }
@Override @Override
public void halfClose() { public void halfClose() {
if (startedRealStream == null) { delayOrExecute(new Runnable() {
synchronized (this) { @Override
if (startedRealStream == null) { public void run() {
pendingHalfClose = true; realStream.halfClose();
return;
} }
} });
}
startedRealStream.halfClose();
} }
@Override @Override
public void request(int numMessages) { public void request(final int numMessages) {
if (startedRealStream == null) { if (passThrough) {
synchronized (this) { realStream.request(numMessages);
if (startedRealStream == null) { } else {
pendingFlowControlRequests += numMessages; delayOrExecute(new Runnable() {
return; @Override
public void run() {
realStream.request(numMessages);
} }
});
} }
} }
startedRealStream.request(numMessages);
}
@Override @Override
public void setCompressor(Compressor compressor) { public void setCompressor(final Compressor compressor) {
if (startedRealStream == null) { checkNotNull(compressor, "compressor");
synchronized (this) { delayOrExecute(new Runnable() {
if (startedRealStream == null) { @Override
this.compressor = compressor; public void run() {
return; realStream.setCompressor(compressor);
} }
} });
}
startedRealStream.setCompressor(compressor);
} }
@Override @Override
public void setDecompressor(Decompressor decompressor) { public void setDecompressor(Decompressor decompressor) {
if (startedRealStream == null) { checkNotNull(decompressor, "decompressor");
synchronized (this) { // This method being called only makes sense after setStream() has been called (but not
if (startedRealStream == null) { // necessarily returned), but there is not necessarily a happens-before relationship. This
this.decompressor = decompressor; // synchronized block creates one.
return; synchronized (this) { }
} checkState(realStream != null, "How did we receive a reply before the request is sent?");
} // ClientStreamListenerImpl (in ClientCallImpl) requires setDecompressor to be set immediately,
} // since messages may be processed immediately after this method returns.
startedRealStream.setDecompressor(decompressor); realStream.setDecompressor(decompressor);
} }
@Override @Override
public boolean isReady() { public boolean isReady() {
if (startedRealStream == null) { if (passThrough) {
synchronized (this) { return realStream.isReady();
if (startedRealStream == null) { } else {
return false; return false;
} }
} }
}
return startedRealStream.isReady();
}
@Override @Override
public void setMessageCompression(boolean enable) { public void setMessageCompression(final boolean enable) {
if (startedRealStream == null) { if (passThrough) {
synchronized (this) { realStream.setMessageCompression(enable);
if (startedRealStream == null) { } else {
messageCompressionEnabled = enable; delayOrExecute(new Runnable() {
return; @Override
public void run() {
realStream.setMessageCompression(enable);
} }
});
} }
} }
startedRealStream.setMessageCompression(enable);
}
} }

View File

@ -219,6 +219,18 @@ public class DelayedClientTransportTest {
verify(mockRealTransport).ping(same(pingCallback), same(mockExecutor)); verify(mockRealTransport).ping(same(pingCallback), same(mockExecutor));
} }
@Test public void shutdownThenPing() {
delayedTransport.shutdown();
verify(transportListener).transportShutdown(any(Status.class));
verify(transportListener).transportTerminated();
delayedTransport.ping(pingCallback, mockExecutor);
verifyNoMoreInteractions(pingCallback);
ArgumentCaptor<Runnable> runnableCaptor = ArgumentCaptor.forClass(Runnable.class);
verify(mockExecutor).execute(runnableCaptor.capture());
runnableCaptor.getValue().run();
verify(pingCallback).onFailure(any(Throwable.class));
}
@Test public void shutdownThenNewStream() { @Test public void shutdownThenNewStream() {
delayedTransport.shutdown(); delayedTransport.shutdown();
verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportShutdown(any(Status.class));

View File

@ -31,12 +31,18 @@
package io.grpc.internal; package io.grpc.internal;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import io.grpc.Codec; import io.grpc.Codec;
import io.grpc.Metadata; import io.grpc.Metadata;
@ -83,11 +89,28 @@ public class DelayedStreamTest {
inOrder.verify(realStream).start(listener); inOrder.verify(realStream).start(listener);
} }
@Test(expected = IllegalStateException.class)
public void setAuthority_afterStart() {
stream.start(listener);
stream.setAuthority("notgonnawork");
}
@Test(expected = IllegalStateException.class)
public void start_afterStart() {
stream.start(listener);
stream.start(mock(ClientStreamListener.class));
}
@Test(expected = IllegalStateException.class)
public void setDecompressor_beforeSetStream() {
stream.start(listener);
stream.setDecompressor(Codec.Identity.NONE);
}
@Test @Test
public void setStream_sendsAllMessages() { public void setStream_sendsAllMessages() {
stream.start(listener); stream.start(listener);
stream.setCompressor(Codec.Identity.NONE); stream.setCompressor(Codec.Identity.NONE);
stream.setDecompressor(Codec.Identity.NONE);
stream.setMessageCompression(true); stream.setMessageCompression(true);
InputStream message = new ByteArrayInputStream(new byte[]{'a'}); InputStream message = new ByteArrayInputStream(new byte[]{'a'});
@ -96,17 +119,19 @@ public class DelayedStreamTest {
stream.writeMessage(message); stream.writeMessage(message);
stream.setStream(realStream); stream.setStream(realStream);
stream.setDecompressor(Codec.Identity.NONE);
verify(realStream).setCompressor(Codec.Identity.NONE); verify(realStream).setCompressor(Codec.Identity.NONE);
verify(realStream).setDecompressor(Codec.Identity.NONE); verify(realStream).setDecompressor(Codec.Identity.NONE);
// Verify that the order was correct, even though they should be interleaved with the
// writeMessage calls
verify(realStream).setMessageCompression(true); verify(realStream).setMessageCompression(true);
verify(realStream, times(2)).setMessageCompression(false); verify(realStream).setMessageCompression(false);
verify(realStream, times(2)).writeMessage(message); verify(realStream, times(2)).writeMessage(message);
verify(realStream).start(listener); verify(realStream).start(listener);
stream.writeMessage(message);
verify(realStream, times(3)).writeMessage(message);
} }
@Test @Test
@ -123,8 +148,10 @@ public class DelayedStreamTest {
stream.start(listener); stream.start(listener);
stream.flush(); stream.flush();
stream.setStream(realStream); stream.setStream(realStream);
verify(realStream).flush(); verify(realStream).flush();
stream.flush();
verify(realStream, times(2)).flush();
} }
@Test @Test
@ -132,17 +159,45 @@ public class DelayedStreamTest {
stream.start(listener); stream.start(listener);
stream.request(1); stream.request(1);
stream.request(2); stream.request(2);
stream.setStream(realStream); stream.setStream(realStream);
verify(realStream).request(1);
verify(realStream).request(2);
stream.request(3);
verify(realStream).request(3); verify(realStream).request(3);
} }
@Test
public void setStream_setMessageCompression() {
stream.start(listener);
stream.setMessageCompression(false);
stream.setStream(realStream);
verify(realStream).setMessageCompression(false);
stream.setMessageCompression(true);
verify(realStream).setMessageCompression(true);
}
@Test
public void setStream_isReady() {
stream.start(listener);
assertFalse(stream.isReady());
stream.setStream(realStream);
verify(realStream, never()).isReady();
assertFalse(stream.isReady());
verify(realStream).isReady();
when(realStream.isReady()).thenReturn(true);
assertTrue(stream.isReady());
verify(realStream, times(2)).isReady();
}
@Test @Test
public void startThenCancelled() { public void startThenCancelled() {
stream.start(listener); stream.start(listener);
stream.cancel(Status.CANCELLED); stream.cancel(Status.CANCELLED);
verify(listener).closed(eq(Status.CANCELLED), isA(Metadata.class)); verify(listener).closed(eq(Status.CANCELLED), any(Metadata.class));
} }
@Test @Test
@ -170,10 +225,37 @@ public class DelayedStreamTest {
verify(realStream).cancel(same(Status.CANCELLED)); verify(realStream).cancel(same(Status.CANCELLED));
} }
@Test
public void setStreamTwice() {
stream.start(listener);
stream.setStream(realStream);
verify(realStream).start(listener);
stream.setStream(mock(ClientStream.class));
stream.flush();
verify(realStream).flush();
}
@Test
public void cancelThenSetStream() {
stream.cancel(Status.CANCELLED);
stream.setStream(realStream);
stream.start(listener);
stream.isReady();
verifyNoMoreInteractions(realStream);
}
@Test
public void cancel_beforeStart() {
Status status = Status.CANCELLED.withDescription("that was quick");
stream.cancel(status);
stream.start(listener);
verify(listener).closed(same(status), any(Metadata.class));
}
@Test @Test
public void cancelledThenStart() { public void cancelledThenStart() {
stream.cancel(Status.CANCELLED); stream.cancel(Status.CANCELLED);
stream.start(listener); stream.start(listener);
verify(listener).closed(eq(Status.CANCELLED), isA(Metadata.class)); verify(listener).closed(eq(Status.CANCELLED), any(Metadata.class));
} }
} }