core: Fix onReady race by adding DelayedStreamListener

onReady/isReady previously could disagree causing a sort of deadlock
where the application isn't sending because grpc said not to, but won't
be informed to send via onReady later.

This is a stack trace from inprocessTransportOutboundFlowControl. The
line numbers are from this commit but with the changes to DelayedStream
disabled:

at io.grpc.internal.DelayedStream.isReady(DelayedStream.java:306)
  (That is isReady returning false because fallThrough == false)
at io.grpc.internal.ClientCallImpl.isReady(ClientCallImpl.java:382)
at io.grpc.stub.ClientCalls$CallToStreamObserverAdapter.isReady(ClientCalls.java:289)
at io.grpc.stub.ClientCallsTest$8$1.run(ClientCallsTest.java:403)
  (And yet that was the onReady callback, and it won't be called again)
at io.grpc.stub.ClientCalls$StreamObserverToCallListenerAdapter.onReady(ClientCalls.java:377)
at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$4.runInContext(ClientCallImpl.java:481)
at io.grpc.internal.ContextRunnable.run(ContextRunnable.java:52)
at io.grpc.internal.SerializeReentrantCallsDirectExecutor.execute(SerializeReentrantCallsDirectExecutor.java:65)
at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl.onReady(ClientCallImpl.java:478)
at io.grpc.internal.DelayedStream$DelayedStreamListener.onReady(DelayedStream.java:366)
at io.grpc.inprocess.InProcessTransport$InProcessStream$InProcessServerStream.request(InProcessTransport.java:284)
at io.grpc.internal.ServerCallImpl.request(ServerCallImpl.java:99)
at io.grpc.stub.ServerCalls$ServerCallStreamObserverImpl.request(ServerCalls.java:345)
at io.grpc.stub.ClientCallsTest.inprocessTransportOutboundFlowControl(ClientCallsTest.java:432)

Fixes #1932
This commit is contained in:
Eric Anderson 2016-06-16 10:02:22 -07:00
parent 88a0378912
commit 29776ca947
6 changed files with 326 additions and 66 deletions

View File

@ -69,6 +69,8 @@ class DelayedStream implements ClientStream {
private Status error;
@GuardedBy("this")
private List<Runnable> pendingCalls = new ArrayList<Runnable>();
@GuardedBy("this")
private DelayedStreamListener delayedListener;
/**
* Transfers all pending and future requests and mutations to the given stream.
@ -97,11 +99,13 @@ class DelayedStream implements ClientStream {
assert realStream != null;
assert !passThrough;
List<Runnable> toRun = new ArrayList<Runnable>();
DelayedStreamListener delayedListener = null;
while (true) {
synchronized (this) {
if (pendingCalls.isEmpty()) {
pendingCalls = null;
passThrough = true;
delayedListener = this.delayedListener;
break;
}
// Since there were pendingCalls, we need to process them. To maintain ordering we can't set
@ -118,6 +122,9 @@ class DelayedStream implements ClientStream {
}
toRun.clear();
}
if (delayedListener != null) {
delayedListener.drainPendingCallbacks();
}
}
/**
@ -150,26 +157,36 @@ class DelayedStream implements ClientStream {
}
@Override
public void start(final ClientStreamListener listener) {
public void start(ClientStreamListener listener) {
checkState(this.listener == null, "already started");
Status savedError;
boolean savedPassThrough;
synchronized (this) {
this.listener = checkNotNull(listener, "listener");
// If error != null, then cancel() has been called and was unable to close the listener
savedError = error;
savedPassThrough = passThrough;
if (!savedPassThrough) {
listener = delayedListener = new DelayedStreamListener(listener);
}
}
if (savedError != null) {
listener.closed(savedError, new Metadata());
return;
}
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.start(listener);
}
});
if (savedPassThrough) {
realStream.start(listener);
} else {
final ClientStreamListener finalListener = listener;
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.start(finalListener);
}
});
}
}
@Override
@ -308,4 +325,99 @@ class DelayedStream implements ClientStream {
ClientStream getRealStream() {
return realStream;
}
private static class DelayedStreamListener implements ClientStreamListener {
private final ClientStreamListener realListener;
private volatile boolean passThrough;
@GuardedBy("this")
private List<Runnable> pendingCallbacks = new ArrayList<Runnable>();
public DelayedStreamListener(ClientStreamListener listener) {
this.realListener = listener;
}
private void delayOrExecute(Runnable runnable) {
synchronized (this) {
if (!passThrough) {
pendingCallbacks.add(runnable);
return;
}
}
runnable.run();
}
@Override
public void messageRead(final InputStream message) {
if (passThrough) {
realListener.messageRead(message);
} else {
delayOrExecute(new Runnable() {
@Override
public void run() {
realListener.messageRead(message);
}
});
}
}
@Override
public void onReady() {
if (passThrough) {
realListener.onReady();
} else {
delayOrExecute(new Runnable() {
@Override
public void run() {
realListener.onReady();
}
});
}
}
@Override
public void headersRead(final Metadata headers) {
delayOrExecute(new Runnable() {
@Override
public void run() {
realListener.headersRead(headers);
}
});
}
@Override
public void closed(final Status status, final Metadata trailers) {
delayOrExecute(new Runnable() {
@Override
public void run() {
realListener.closed(status, trailers);
}
});
}
public void drainPendingCallbacks() {
assert !passThrough;
List<Runnable> toRun = new ArrayList<Runnable>();
while (true) {
synchronized (this) {
if (pendingCallbacks.isEmpty()) {
pendingCallbacks = null;
passThrough = true;
break;
}
// Since there were pendingCallbacks, we need to process them. To maintain ordering we
// can't set passThrough=true until we run all pendingCallbacks, but new Runnables may be
// added after we drop the lock. So we will have to re-check pendingCallbacks.
List<Runnable> tmp = toRun;
toRun = pendingCallbacks;
pendingCallbacks = tmp;
}
for (Runnable runnable : toRun) {
// Avoid calling listener while lock is held to prevent deadlocks.
// TODO(ejona): exception handling
runnable.run();
}
toRun.clear();
}
}
}
}

View File

@ -56,8 +56,6 @@ import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.io.InputStream;
/**
* Test for {@link AbstractClientStream}. This class tries to test functionality in
* AbstractClientStream, but not in any super classes.
@ -85,7 +83,7 @@ public class AbstractClientStreamTest {
@Test
public void cancel_doNotAcceptOk() {
for (Code code : Code.values()) {
ClientStreamListener listener = new BaseClientStreamListener();
ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream<Integer> stream = new BaseAbstractClientStream<Integer>(allocator);
stream.start(listener);
if (code != Code.OK) {
@ -103,7 +101,7 @@ public class AbstractClientStreamTest {
@Test
public void cancel_failsOnNull() {
ClientStreamListener listener = new BaseClientStreamListener();
ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream<Integer> stream = new BaseAbstractClientStream<Integer>(allocator);
stream.start(listener);
thrown.expect(NullPointerException.class);
@ -162,7 +160,7 @@ public class AbstractClientStreamTest {
@Test
public void inboundDataReceived_failsOnNullFrame() {
ClientStreamListener listener = new BaseClientStreamListener();
ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream<Integer> stream = new BaseAbstractClientStream<Integer>(allocator);
stream.start(listener);
thrown.expect(NullPointerException.class);
@ -282,21 +280,4 @@ public class AbstractClientStreamTest {
@Override
protected void returnProcessedBytes(int processedBytes) {}
}
/**
* No-op base class for testing.
*/
static class BaseClientStreamListener implements ClientStreamListener {
@Override
public void messageRead(InputStream message) {}
@Override
public void onReady() {}
@Override
public void headersRead(Metadata headers) {}
@Override
public void closed(Status status, Metadata trailers) {}
}
}

View File

@ -79,6 +79,7 @@ public class DelayedClientTransportTest {
@Mock private ClientTransport.PingCallback pingCallback;
@Mock private Executor mockExecutor;
@Captor private ArgumentCaptor<Status> statusCaptor;
@Captor private ArgumentCaptor<ClientStreamListener> listenerCaptor;
private final MethodDescriptor<String, Integer> method = MethodDescriptor.create(
MethodDescriptor.MethodType.UNKNOWN, "/service/method",
@ -141,7 +142,11 @@ public class DelayedClientTransportTest {
assertFalse(delayedTransport.hasPendingStreams());
assertEquals(1, fakeExecutor.runDueTasks());
verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions));
verify(mockRealStream).start(same(streamListener));
verify(mockRealStream).start(listenerCaptor.capture());
verifyNoMoreInteractions(streamListener);
listenerCaptor.getValue().onReady();
verify(streamListener).onReady();
verifyNoMoreInteractions(streamListener);
}
@Test public void newStreamThenSetTransportThenShutdown() {

View File

@ -47,6 +47,7 @@ import static org.mockito.Mockito.when;
import io.grpc.Codec;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.NoopClientStream;
import org.junit.Before;
import org.junit.Rule;
@ -54,6 +55,8 @@ import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
@ -71,6 +74,7 @@ public class DelayedStreamTest {
@Mock private ClientStreamListener listener;
@Mock private ClientStream realStream;
@Captor private ArgumentCaptor<ClientStreamListener> listenerCaptor;
private DelayedStream stream = new DelayedStream();
@Before
@ -86,7 +90,7 @@ public class DelayedStreamTest {
stream.setStream(realStream);
InOrder inOrder = inOrder(realStream);
inOrder.verify(realStream).setAuthority(authority);
inOrder.verify(realStream).start(listener);
inOrder.verify(realStream).start(any(ClientStreamListener.class));
}
@Test(expected = IllegalStateException.class)
@ -128,10 +132,14 @@ public class DelayedStreamTest {
verify(realStream).setMessageCompression(false);
verify(realStream, times(2)).writeMessage(message);
verify(realStream).start(listener);
verify(realStream).start(listenerCaptor.capture());
stream.writeMessage(message);
verify(realStream, times(3)).writeMessage(message);
verifyNoMoreInteractions(listener);
listenerCaptor.getValue().onReady();
verify(listener).onReady();
}
@Test
@ -205,7 +213,7 @@ public class DelayedStreamTest {
stream.start(listener);
stream.setStream(realStream);
stream.cancel(Status.CANCELLED);
verify(realStream).start(same(listener));
verify(realStream).start(any(ClientStreamListener.class));
verify(realStream).cancel(same(Status.CANCELLED));
}
@ -229,7 +237,7 @@ public class DelayedStreamTest {
public void setStreamTwice() {
stream.start(listener);
stream.setStream(realStream);
verify(realStream).start(listener);
verify(realStream).start(any(ClientStreamListener.class));
stream.setStream(mock(ClientStream.class));
stream.flush();
verify(realStream).flush();
@ -258,4 +266,87 @@ public class DelayedStreamTest {
stream.start(listener);
verify(listener).closed(eq(Status.CANCELLED), any(Metadata.class));
}
@Test
public void listener_onReadyDelayedUntilPassthrough() {
class IsReadyListener extends NoopClientStreamListener {
boolean onReadyCalled;
@Override
public void onReady() {
// If onReady was not delayed, then passthrough==false and isReady will return false.
assertTrue(stream.isReady());
onReadyCalled = true;
}
}
IsReadyListener isReadyListener = new IsReadyListener();
stream.start(isReadyListener);
stream.setStream(new NoopClientStream() {
@Override
public void start(ClientStreamListener listener) {
// This call to the listener should end up being delayed.
listener.onReady();
}
@Override
public boolean isReady() {
return true;
}
});
assertTrue(isReadyListener.onReadyCalled);
}
@Test
public void listener_allQueued() {
final Metadata headers = new Metadata();
final InputStream message1 = mock(InputStream.class);
final InputStream message2 = mock(InputStream.class);
final Metadata trailers = new Metadata();
final Status status = Status.UNKNOWN.withDescription("unique status");
final InOrder inOrder = inOrder(listener);
stream.start(listener);
stream.setStream(new NoopClientStream() {
@Override
public void start(ClientStreamListener passedListener) {
passedListener.onReady();
passedListener.headersRead(headers);
passedListener.messageRead(message1);
passedListener.onReady();
passedListener.messageRead(message2);
passedListener.closed(status, trailers);
verifyNoMoreInteractions(listener);
}
});
inOrder.verify(listener).onReady();
inOrder.verify(listener).headersRead(headers);
inOrder.verify(listener).messageRead(message1);
inOrder.verify(listener).onReady();
inOrder.verify(listener).messageRead(message2);
inOrder.verify(listener).closed(status, trailers);
}
@Test
public void listener_noQueued() {
final Metadata headers = new Metadata();
final InputStream message = mock(InputStream.class);
final Metadata trailers = new Metadata();
final Status status = Status.UNKNOWN.withDescription("unique status");
stream.start(listener);
stream.setStream(realStream);
verify(realStream).start(listenerCaptor.capture());
ClientStreamListener delayedListener = listenerCaptor.getValue();
delayedListener.onReady();
verify(listener).onReady();
delayedListener.headersRead(headers);
verify(listener).headersRead(headers);
delayedListener.messageRead(message);
verify(listener).messageRead(message);
delayedListener.closed(status, trailers);
verify(listener).closed(status, trailers);
}
}

View File

@ -0,0 +1,54 @@
/*
* Copyright 2015, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal;
import io.grpc.Metadata;
import io.grpc.Status;
import java.io.InputStream;
/**
* No-op base class for testing.
*/
class NoopClientStreamListener implements ClientStreamListener {
@Override
public void messageRead(InputStream message) {}
@Override
public void onReady() {}
@Override
public void headersRead(Metadata headers) {}
@Override
public void closed(Status status, Metadata trailers) {}
}

View File

@ -40,20 +40,23 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.ManagedChannelImpl;
import io.grpc.stub.ServerCalls.NoopStreamObserver;
import io.grpc.stub.ServerCallsTest.IntegerMarshaller;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -83,6 +86,9 @@ public class ClientCallsTest {
"some/method",
new IntegerMarshaller(), new IntegerMarshaller());
private Server server;
private ManagedChannel channel;
@Mock
private ClientCall<Integer, String> call;
@ -91,6 +97,16 @@ public class ClientCallsTest {
MockitoAnnotations.initMocks(this);
}
@After
public void tearDown() {
if (server != null) {
server.shutdownNow();
}
if (channel != null) {
channel.shutdownNow();
}
}
@Test
public void unaryFutureCallSuccess() throws Exception {
Integer req = 2;
@ -257,7 +273,7 @@ public class ClientCallsTest {
@Test
public void inprocessTransportInboundFlowControl() throws Exception {
final Semaphore semaphore = new Semaphore(1);
final Semaphore semaphore = new Semaphore(0);
ServerServiceDefinition service = ServerServiceDefinition.builder(
new ServiceDescriptor("some", STREAMING_METHOD))
.addMethod(STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall(
@ -288,13 +304,13 @@ public class ClientCallsTest {
}))
.build();
long tag = System.nanoTime();
InProcessServerBuilder.forName("go-with-the-flow" + tag).addService(service).build().start();
ManagedChannelImpl channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).build();
server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor()
.addService(service).build().start();
channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build();
final ClientCall<Integer, Integer> clientCall = channel.newCall(STREAMING_METHOD,
CallOptions.DEFAULT);
final CountDownLatch latch = new CountDownLatch(1);
final List<Integer> receivedMessages = new ArrayList<Integer>(6);
semaphore.acquire();
final List<Object> receivedMessages = new ArrayList<Object>(6);
ClientResponseObserver<Integer, Integer> responseObserver =
new ClientResponseObserver<Integer, Integer>() {
@ -310,6 +326,7 @@ public class ClientCallsTest {
@Override
public void onError(Throwable t) {
receivedMessages.add(t);
latch.countDown();
}
@ -327,17 +344,17 @@ public class ClientCallsTest {
integerStreamObserver.request(3);
integerStreamObserver.onCompleted();
assertTrue(latch.await(5, TimeUnit.SECONDS));
// Very that number of messages produced in each onReady handler call matches the number
// Verify that number of messages produced in each onReady handler call matches the number
// requested by the client. Note that ClientCalls.asyncBidiStreamingCall will request(1)
assertEquals(Arrays.asList(0, 1, 1, 2, 2, 2), receivedMessages);
}
@org.junit.Ignore
@Test
public void inprocessTransportOutboundFlowControl() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final Semaphore semaphore = new Semaphore(1);
final List<Integer> receivedMessages = new ArrayList<Integer>(6);
final Semaphore semaphore = new Semaphore(0);
final List<Object> receivedMessages = new ArrayList<Object>(6);
final SettableFuture<ServerCallStreamObserver<Integer>> observerFuture
= SettableFuture.create();
ServerServiceDefinition service = ServerServiceDefinition.builder(
new ServiceDescriptor("some", STREAMING_METHOD))
.addMethod(STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall(
@ -347,42 +364,34 @@ public class ClientCallsTest {
final ServerCallStreamObserver<Integer> serverCallObserver =
(ServerCallStreamObserver<Integer>) responseObserver;
serverCallObserver.disableAutoInboundFlowControl();
new Thread(new Runnable() {
@Override
public void run() {
try {
serverCallObserver.request(1);
semaphore.acquire();
serverCallObserver.request(2);
semaphore.acquire();
serverCallObserver.request(3);
} catch (Throwable t) {
throw new RuntimeException(t);
}
}
}).start();
return new ServerCalls.NoopStreamObserver<Integer>() {
observerFuture.set(serverCallObserver);
return new StreamObserver<Integer>() {
@Override
public void onNext(Integer value) {
receivedMessages.add(value);
}
@Override
public void onError(Throwable t) {
receivedMessages.add(t);
}
@Override
public void onCompleted() {
serverCallObserver.onCompleted();
latch.countDown();
}
};
}
}))
.build();
long tag = System.nanoTime();
InProcessServerBuilder.forName("go-with-the-flow" + tag).addService(service).build().start();
ManagedChannelImpl channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).build();
server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor()
.addService(service).build().start();
channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build();
final ClientCall<Integer, Integer> clientCall = channel.newCall(STREAMING_METHOD,
CallOptions.DEFAULT);
semaphore.acquire();
final SettableFuture<Void> future = SettableFuture.create();
ClientResponseObserver<Integer, Integer> responseObserver =
new ClientResponseObserver<Integer, Integer>() {
@Override
@ -409,16 +418,24 @@ public class ClientCallsTest {
@Override
public void onError(Throwable t) {
future.setException(t);
}
@Override
public void onCompleted() {
future.set(null);
}
};
ClientCalls.asyncBidiStreamingCall(clientCall, responseObserver);
assertTrue(latch.await(5, TimeUnit.SECONDS));
// Very that number of messages produced in each onReady handler call matches the number
ServerCallStreamObserver<Integer> serverCallObserver = observerFuture.get(5, TimeUnit.SECONDS);
serverCallObserver.request(1);
assertTrue(semaphore.tryAcquire(5, TimeUnit.SECONDS));
serverCallObserver.request(2);
assertTrue(semaphore.tryAcquire(5, TimeUnit.SECONDS));
serverCallObserver.request(3);
future.get(5, TimeUnit.SECONDS);
// Verify that number of messages produced in each onReady handler call matches the number
// requested by the client.
assertEquals(Arrays.asList(0, 1, 1, 2, 2, 2), receivedMessages);
}