diff --git a/core/src/main/java/io/grpc/internal/DelayedClientCall.java b/core/src/main/java/io/grpc/internal/DelayedClientCall.java index fbb24633d7..89442e4bf9 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientCall.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientCall.java @@ -141,15 +141,20 @@ public class DelayedClientCall extends ClientCall { *

No-op if either this method or {@link #cancel} have already been called. */ // When this method returns, passThrough is guaranteed to be true - public final void setCall(ClientCall call) { + public final Runnable setCall(ClientCall call) { synchronized (this) { // If realCall != null, then either setCall() or cancel() has been called. if (realCall != null) { - return; + return null; } setRealCall(checkNotNull(call, "call")); } - drainPendingCalls(); + return new Runnable() { + @Override + public void run() { + drainPendingCalls(); + } + }; } @Override diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 601c7740ca..b9ad2a1840 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -1099,22 +1099,25 @@ final class ManagedChannelImpl extends ManagedChannel implements /** Called when it's ready to create a real call and reprocess the pending call. */ void reprocess() { - getCallExecutor(callOptions).execute( - new Runnable() { - @Override - public void run() { - ClientCall realCall; - Context previous = context.attach(); - try { - realCall = newClientCall(method, callOptions); - } finally { - context.detach(previous); - } - setCall(realCall); - syncContext.execute(new PendingCallRemoval()); - } + ClientCall realCall; + Context previous = context.attach(); + try { + realCall = newClientCall(method, callOptions); + } finally { + context.detach(previous); + } + Runnable toRun = setCall(realCall); + if (toRun == null) { + syncContext.execute(new PendingCallRemoval()); + } else { + getCallExecutor(callOptions).execute(new Runnable() { + @Override + public void run() { + toRun.run(); + syncContext.execute(new PendingCallRemoval()); } - ); + }); + } } @Override diff --git a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java index 290e2b9de6..7b653f8213 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java @@ -21,6 +21,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import com.google.common.util.concurrent.MoreExecutors; @@ -30,6 +31,7 @@ import io.grpc.Deadline; import io.grpc.ForwardingTestUtil; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.StatusException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Arrays; @@ -63,12 +65,13 @@ public class DelayedClientCallTest { public void allMethodsForwarded() throws Exception { DelayedClientCall delayedClientCall = new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null); - delayedClientCall.setCall(mockRealCall); + callMeMaybe(delayedClientCall.setCall(mockRealCall)); ForwardingTestUtil.testMethodsForwarded( ClientCall.class, mockRealCall, delayedClientCall, - Arrays.asList(ClientCall.class.getMethod("toString")), + Arrays.asList(ClientCall.class.getMethod("toString"), + ClientCall.class.getMethod("start", Listener.class, Metadata.class)), new ForwardingTestUtil.ArgumentProvider() { @Override public Object get(Method method, int argPos, Class clazz) { @@ -101,7 +104,7 @@ public class DelayedClientCallTest { DelayedClientCall delayedClientCall = new DelayedClientCall<>( callExecutor, fakeClock.getScheduledExecutorService(), Deadline.after(10, SECONDS)); delayedClientCall.start(listener, new Metadata()); - delayedClientCall.setCall(mockRealCall); + callMeMaybe(delayedClientCall.setCall(mockRealCall)); ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); Listener realCallListener = listenerCaptor.getValue(); @@ -119,4 +122,78 @@ public class DelayedClientCallTest { verify(listener).onClose(statusCaptor.capture(), eq(trailer)); assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.DATA_LOSS); } + + @Test + public void setCallThenStart() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + callMeMaybe(delayedClientCall.setCall(mockRealCall)); + delayedClientCall.start(listener, new Metadata()); + delayedClientCall.request(1); + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + verify(mockRealCall).request(1); + realCallListener.onMessage(1); + verify(listener).onMessage(1); + } + + @Test + public void startThenSetCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + delayedClientCall.start(listener, new Metadata()); + delayedClientCall.request(1); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + verify(mockRealCall).request(1); + realCallListener.onMessage(1); + verify(listener).onMessage(1); + } + + @Test + @SuppressWarnings("unchecked") + public void cancelThenSetCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + delayedClientCall.start(listener, new Metadata()); + delayedClientCall.request(1); + delayedClientCall.cancel("cancel", new StatusException(Status.CANCELLED)); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNull(); + verify(mockRealCall, never()).start(any(Listener.class), any(Metadata.class)); + verify(mockRealCall, never()).request(1); + verify(mockRealCall, never()).cancel(any(), any()); + verify(listener).onClose(any(), any()); + } + + @Test + @SuppressWarnings("unchecked") + public void setCallThenCancel() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + delayedClientCall.start(listener, new Metadata()); + delayedClientCall.request(1); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + delayedClientCall.cancel("cancel", new StatusException(Status.CANCELLED)); + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + verify(mockRealCall).request(1); + verify(mockRealCall).cancel(any(), any()); + realCallListener.onClose(Status.CANCELLED, null); + verify(listener).onClose(Status.CANCELLED, null); + } + + private void callMeMaybe(Runnable r) { + if (r != null) { + r.run(); + } + } } diff --git a/xds/src/main/java/io/grpc/xds/FaultFilter.java b/xds/src/main/java/io/grpc/xds/FaultFilter.java index f0c37b20df..2abef58d93 100644 --- a/xds/src/main/java/io/grpc/xds/FaultFilter.java +++ b/xds/src/main/java/io/grpc/xds/FaultFilter.java @@ -400,7 +400,10 @@ final class FaultFilter implements Filter, ClientInterceptorBuilder { activeFaultCounter.decrementAndGet(); } } - setCall(callSupplier.get()); + Runnable toRun = setCall(callSupplier.get()); + if (toRun != null) { + toRun.run(); + } } }, delayNanos,