diff --git a/core/src/main/java/io/grpc/internal/DelayedClientCall.java b/core/src/main/java/io/grpc/internal/DelayedClientCall.java new file mode 100644 index 0000000000..fd3a4ccac6 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/DelayedClientCall.java @@ -0,0 +1,527 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import io.grpc.Attributes; +import io.grpc.ClientCall; +import io.grpc.Context; +import io.grpc.Deadline; +import io.grpc.Metadata; +import io.grpc.Status; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; + +/** + * A call that queues requests before the transport is available, and delegates to a real call + * implementation when the transport is available. + * + *

{@code ClientCall} itself doesn't require thread-safety. However, the state of {@code + * DelayedCall} may be internally altered by different threads, thus internal synchronization is + * necessary. + */ +final class DelayedClientCall extends ClientCall { + private static final Logger logger = Logger.getLogger(DelayedClientCall.class.getName()); + /** + * A timer to monitor the initial deadline. The timer must be cancelled on transition to the real + * call. + */ + @Nullable + private final ScheduledFuture initialDeadlineMonitor; + private final Executor callExecutor; + private final Context context; + /** {@code true} once realCall is valid and all pending calls have been drained. */ + private volatile boolean passThrough; + /** + * Non-{@code null} iff start has been called. Used to assert methods are called in appropriate + * order, but also used if an error occurs before {@code realCall} is set. + */ + private Listener listener; + // Must hold {@code this} lock when setting. + private ClientCall realCall; + @GuardedBy("this") + private Status error; + @GuardedBy("this") + private List pendingRunnables = new ArrayList<>(); + @GuardedBy("this") + private DelayedListener delayedListener; + + DelayedClientCall( + Executor callExecutor, ScheduledExecutorService scheduler, @Nullable Deadline deadline) { + this.callExecutor = checkNotNull(callExecutor, "callExecutor"); + checkNotNull(scheduler, "scheduler"); + context = Context.current(); + initialDeadlineMonitor = scheduleDeadlineIfNeeded(scheduler, deadline); + } + + @Nullable + private ScheduledFuture scheduleDeadlineIfNeeded( + ScheduledExecutorService scheduler, @Nullable Deadline deadline) { + Deadline contextDeadline = context.getDeadline(); + if (deadline == null && contextDeadline == null) { + return null; + } + long remainingNanos = Long.MAX_VALUE; + if (deadline != null) { + remainingNanos = Math.min(remainingNanos, deadline.timeRemaining(NANOSECONDS)); + } + if (contextDeadline != null && contextDeadline.timeRemaining(NANOSECONDS) < remainingNanos) { + remainingNanos = contextDeadline.timeRemaining(NANOSECONDS); + if (logger.isLoggable(Level.FINE)) { + StringBuilder builder = + new StringBuilder( + String.format( + "Call timeout set to '%d' ns, due to context deadline.", remainingNanos)); + if (deadline == null) { + builder.append(" Explicit call timeout was not set."); + } else { + long callTimeout = deadline.timeRemaining(TimeUnit.NANOSECONDS); + builder.append(String.format(" Explicit call timeout was '%d' ns.", callTimeout)); + } + logger.fine(builder.toString()); + } + } + long seconds = Math.abs(remainingNanos) / TimeUnit.SECONDS.toNanos(1); + long nanos = Math.abs(remainingNanos) % TimeUnit.SECONDS.toNanos(1); + final StringBuilder buf = new StringBuilder(); + if (remainingNanos < 0) { + buf.append("ClientCall started after deadline exceeded. Deadline exceeded after -"); + } else { + buf.append("Deadline exceeded after "); + } + buf.append(seconds); + buf.append(String.format(".%09d", nanos)); + buf.append("s. "); + /** Cancels the call if deadline exceeded prior to the real call being set. */ + class DeadlineExceededRunnable implements Runnable { + @Override + public void run() { + cancel( + Status.DEADLINE_EXCEEDED.withDescription(buf.toString()), + // We should not cancel the call if the realCall is set because there could be a + // race between cancel() and realCall.start(). The realCall will handle deadline by + // itself. + /* onlyCancelPendingCall= */ true); + } + } + + return scheduler.schedule(new DeadlineExceededRunnable(), remainingNanos, NANOSECONDS); + } + + /** + * Transfers all pending and future requests and mutations to the given call. + * + *

No-op if either this method or {@link #cancel} have already been called. + */ + // When this method returns, passThrough is guaranteed to be true + final void setCall(ClientCall call) { + synchronized (this) { + // If realCall != null, then either setCall() or cancel() has been called. + if (realCall != null) { + return; + } + setRealCall(checkNotNull(call, "call")); + } + drainPendingCalls(); + } + + @Override + public void start(Listener listener, final Metadata headers) { + checkState(this.listener == null, "already started"); + Status savedError; + boolean savedPassThrough; + synchronized (this) { + this.listener = checkNotNull(listener, "listener"); + // If error != null, then cancel() has been called and was unable to close the listener + savedError = error; + savedPassThrough = passThrough; + if (!savedPassThrough) { + listener = delayedListener = new DelayedListener<>(listener); + } + } + if (savedError != null) { + callExecutor.execute(new CloseListenerRunnable(listener, savedError)); + return; + } + if (savedPassThrough) { + realCall.start(listener, headers); + } else { + final Listener finalListener = listener; + delayOrExecute(new Runnable() { + @Override + public void run() { + realCall.start(finalListener, headers); + } + }); + } + } + + // When this method returns, passThrough is guaranteed to be true + @Override + public void cancel(@Nullable final String message, @Nullable final Throwable cause) { + Status status = Status.CANCELLED; + if (message != null) { + status = status.withDescription(message); + } else { + status = status.withDescription("Call cancelled without message"); + } + if (cause != null) { + status = status.withCause(cause); + } + cancel(status, false); + } + + /** + * Cancels the call unless {@code realCall} is set and {@code onlyCancelPendingCall} is true. + */ + private void cancel(final Status status, boolean onlyCancelPendingCall) { + boolean delegateToRealCall = true; + Listener listenerToClose = null; + synchronized (this) { + // If realCall != null, then either setCall() or cancel() has been called + if (realCall == null) { + @SuppressWarnings("unchecked") + ClientCall noopCall = (ClientCall) NOOP_CALL; + setRealCall(noopCall); + delegateToRealCall = false; + // If listener == null, then start() will later call listener with 'error' + listenerToClose = listener; + error = status; + } else if (onlyCancelPendingCall) { + return; + } + } + if (delegateToRealCall) { + delayOrExecute(new Runnable() { + @Override + public void run() { + realCall.cancel(status.getDescription(), status.getCause()); + } + }); + } else { + if (listenerToClose != null) { + callExecutor.execute(new CloseListenerRunnable(listenerToClose, status)); + } + drainPendingCalls(); + } + } + + private void delayOrExecute(Runnable runnable) { + synchronized (this) { + if (!passThrough) { + pendingRunnables.add(runnable); + return; + } + } + runnable.run(); + } + + /** + * Called to transition {@code passThrough} to {@code true}. This method is not safe to be called + * multiple times; the caller must ensure it will only be called once, ever. {@code this} lock + * should not be held when calling this method. + */ + private void drainPendingCalls() { + assert realCall != null; + assert !passThrough; + List toRun = new ArrayList<>(); + DelayedListener delayedListener ; + while (true) { + synchronized (this) { + if (pendingRunnables.isEmpty()) { + pendingRunnables = null; + passThrough = true; + delayedListener = this.delayedListener; + break; + } + // Since there were pendingCalls, we need to process them. To maintain ordering we can't set + // passThrough=true until we run all pendingCalls, but new Runnables may be added after we + // drop the lock. So we will have to re-check pendingCalls. + List tmp = toRun; + toRun = pendingRunnables; + pendingRunnables = tmp; + } + for (Runnable runnable : toRun) { + // Must not call transport while lock is held to prevent deadlocks. + // TODO(ejona): exception handling + runnable.run(); + } + toRun.clear(); + } + if (delayedListener != null) { + final DelayedListener listener = delayedListener; + class DrainListenerRunnable extends ContextRunnable { + DrainListenerRunnable() { + super(context); + } + + @Override + public void runInContext() { + listener.drainPendingCallbacks(); + } + } + + callExecutor.execute(new DrainListenerRunnable()); + } + } + + @GuardedBy("this") + private void setRealCall(ClientCall realCall) { + checkState(this.realCall == null, "realCall already set to %s", this.realCall); + if (initialDeadlineMonitor != null) { + initialDeadlineMonitor.cancel(false); + } + this.realCall = realCall; + } + + @VisibleForTesting + ClientCall getRealCall() { + return realCall; + } + + @Override + public void sendMessage(final ReqT message) { + if (passThrough) { + realCall.sendMessage(message); + } else { + delayOrExecute(new Runnable() { + @Override + public void run() { + realCall.sendMessage(message); + } + }); + } + } + + @Override + public void setMessageCompression(final boolean enable) { + if (passThrough) { + realCall.setMessageCompression(enable); + } else { + delayOrExecute(new Runnable() { + @Override + public void run() { + realCall.setMessageCompression(enable); + } + }); + } + } + + @Override + public void request(final int numMessages) { + if (passThrough) { + realCall.request(numMessages); + } else { + delayOrExecute(new Runnable() { + @Override + public void run() { + realCall.request(numMessages); + } + }); + } + } + + @Override + public void halfClose() { + delayOrExecute(new Runnable() { + @Override + public void run() { + realCall.halfClose(); + } + }); + } + + @Override + public boolean isReady() { + if (passThrough) { + return realCall.isReady(); + } else { + return false; + } + } + + @Override + public Attributes getAttributes() { + ClientCall savedRealCall; + synchronized (this) { + savedRealCall = realCall; + } + if (savedRealCall != null) { + return savedRealCall.getAttributes(); + } else { + return Attributes.EMPTY; + } + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("realCall", realCall) + .toString(); + } + + private final class CloseListenerRunnable extends ContextRunnable { + final Listener listener; + final Status status; + + CloseListenerRunnable(Listener listener, Status status) { + super(context); + this.listener = listener; + this.status = status; + } + + @Override + public void runInContext() { + listener.onClose(status, new Metadata()); + } + } + + private static final class DelayedListener extends Listener { + private final Listener realListener; + private volatile boolean passThrough; + @GuardedBy("this") + private List pendingCallbacks = new ArrayList<>(); + + public DelayedListener(Listener listener) { + this.realListener = listener; + } + + private void delayOrExecute(Runnable runnable) { + synchronized (this) { + if (!passThrough) { + pendingCallbacks.add(runnable); + return; + } + } + runnable.run(); + } + + @Override + public void onHeaders(final Metadata headers) { + if (passThrough) { + realListener.onHeaders(headers); + } else { + delayOrExecute(new Runnable() { + @Override + public void run() { + realListener.onHeaders(headers); + } + }); + } + } + + @Override + public void onMessage(final RespT message) { + if (passThrough) { + realListener.onMessage(message); + } else { + delayOrExecute(new Runnable() { + @Override + public void run() { + realListener.onMessage(message); + } + }); + } + } + + @Override + public void onClose(final Status status, final Metadata trailers) { + delayOrExecute(new Runnable() { + @Override + public void run() { + realListener.onClose(status, trailers); + } + }); + } + + @Override + public void onReady() { + if (passThrough) { + realListener.onReady(); + } else { + delayOrExecute(new Runnable() { + @Override + public void run() { + realListener.onReady(); + } + }); + } + } + + void drainPendingCallbacks() { + assert !passThrough; + List toRun = new ArrayList<>(); + while (true) { + synchronized (this) { + if (pendingCallbacks.isEmpty()) { + pendingCallbacks = null; + passThrough = true; + break; + } + // Since there were pendingCallbacks, we need to process them. To maintain ordering we + // can't set passThrough=true until we run all pendingCallbacks, but new Runnables may be + // added after we drop the lock. So we will have to re-check pendingCallbacks. + List tmp = toRun; + toRun = pendingCallbacks; + pendingCallbacks = tmp; + } + for (Runnable runnable : toRun) { + // Avoid calling listener while lock is held to prevent deadlocks. + // TODO(ejona): exception handling + runnable.run(); + } + toRun.clear(); + } + } + } + + private static final ClientCall NOOP_CALL = new ClientCall() { + @Override + public void start(Listener responseListener, Metadata headers) {} + + @Override + public void request(int numMessages) {} + + @Override + public void cancel(String message, Throwable cause) {} + + @Override + public void halfClose() {} + + @Override + public void sendMessage(Object message) {} + + // Always returns {@code false}, since this is only used when the startup of the call fails. + @Override + public boolean isReady() { + return false; + } + }; +} diff --git a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java new file mode 100644 index 0000000000..290e2b9de6 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java @@ -0,0 +1,122 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.truth.Truth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.ClientCall; +import io.grpc.ClientCall.Listener; +import io.grpc.Deadline; +import io.grpc.ForwardingTestUtil; +import io.grpc.Metadata; +import io.grpc.Status; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.concurrent.Executor; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Tests for {@link DelayedClientCall}. */ +@RunWith(JUnit4.class) +public class DelayedClientCallTest { + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + @Mock + private ClientCall mockRealCall; + @Mock + private ClientCall.Listener listener; + @Captor + ArgumentCaptor statusCaptor; + + private final FakeClock fakeClock = new FakeClock(); + private final Executor callExecutor = MoreExecutors.directExecutor(); + + @Test + public void allMethodsForwarded() throws Exception { + DelayedClientCall delayedClientCall = + new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null); + delayedClientCall.setCall(mockRealCall); + ForwardingTestUtil.testMethodsForwarded( + ClientCall.class, + mockRealCall, + delayedClientCall, + Arrays.asList(ClientCall.class.getMethod("toString")), + new ForwardingTestUtil.ArgumentProvider() { + @Override + public Object get(Method method, int argPos, Class clazz) { + if (!Modifier.isFinal(clazz.getModifiers())) { + return mock(clazz); + } + if (clazz.equals(String.class)) { + return "message"; + } + return null; + } + }); + } + + // Coverage for deadline exceeded before call started is enforced by + // AbstractInteropTest.deadlineInPast(). + @Test + public void deadlineExceededWhileCallIsStartedButStillPending() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), Deadline.after(10, SECONDS)); + + delayedClientCall.start(listener, new Metadata()); + fakeClock.forwardTime(10, SECONDS); + verify(listener).onClose(statusCaptor.capture(), any(Metadata.class)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.DEADLINE_EXCEEDED); + } + + @Test + public void listenerEventsPropagated() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), Deadline.after(10, SECONDS)); + delayedClientCall.start(listener, new Metadata()); + delayedClientCall.setCall(mockRealCall); + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("key", Metadata.ASCII_STRING_MARSHALLER), "value"); + realCallListener.onHeaders(metadata); + verify(listener).onHeaders(metadata); + realCallListener.onMessage(3); + verify(listener).onMessage(3); + realCallListener.onReady(); + verify(listener).onReady(); + Metadata trailer = new Metadata(); + trailer.put(Metadata.Key.of("key2", Metadata.ASCII_STRING_MARSHALLER), "value2"); + realCallListener.onClose(Status.DATA_LOSS, trailer); + verify(listener).onClose(statusCaptor.capture(), eq(trailer)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.DATA_LOSS); + } +}