diff --git a/core/src/main/java/io/grpc/CallOptions.java b/core/src/main/java/io/grpc/CallOptions.java index e5a1136575..81f771a4fc 100644 --- a/core/src/main/java/io/grpc/CallOptions.java +++ b/core/src/main/java/io/grpc/CallOptions.java @@ -33,6 +33,7 @@ package io.grpc; import com.google.common.base.MoreObjects; +import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; @@ -54,6 +55,7 @@ public final class CallOptions { // them outside of constructor. Otherwise the constructor will have a potentially long list of // unnamed arguments, which is undesirable. private Long deadlineNanoTime; + private Executor executor; @Nullable private String authority; @@ -144,6 +146,21 @@ public final class CallOptions { return authority; } + /** + * Returns a new {@code CallOptions} with {@code executor} to be used instead of the default + * executor specified with {@link ManagedChannelBuilder#executor}. + */ + public CallOptions withExecutor(Executor executor) { + CallOptions newOptions = new CallOptions(this); + newOptions.executor = executor; + return newOptions; + } + + @Nullable + public Executor getExecutor() { + return executor; + } + private CallOptions() { } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 4f0a128a97..9c2ba6d0fe 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -323,6 +323,10 @@ public final class ManagedChannelImpl extends ManagedChannel { @Override public ClientCall newCall(MethodDescriptor method, CallOptions callOptions) { + Executor executor = callOptions.getExecutor(); + if (executor == null) { + executor = ManagedChannelImpl.this.executor; + } return new ClientCallImpl( method, executor, diff --git a/core/src/test/java/io/grpc/CallOptionsTest.java b/core/src/test/java/io/grpc/CallOptionsTest.java index b73b623a48..a725f3a1bb 100644 --- a/core/src/test/java/io/grpc/CallOptionsTest.java +++ b/core/src/test/java/io/grpc/CallOptionsTest.java @@ -37,11 +37,13 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import com.google.common.base.Objects; +import com.google.common.util.concurrent.MoreExecutors; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; /** Unit tests for {@link CallOptions}. */ @@ -60,6 +62,7 @@ public class CallOptionsTest { assertNull(CallOptions.DEFAULT.getDeadlineNanoTime()); assertNull(CallOptions.DEFAULT.getAuthority()); assertNull(CallOptions.DEFAULT.getRequestKey()); + assertNull(CallOptions.DEFAULT.getExecutor()); } @Test @@ -89,6 +92,17 @@ public class CallOptionsTest { assertNull(options2.getDeadlineNanoTime()); } + @Test + public void mutateExecutor() { + Executor executor = MoreExecutors.directExecutor(); + CallOptions options1 = CallOptions.DEFAULT.withExecutor(executor); + assertNull(CallOptions.DEFAULT.getExecutor()); + assertSame(executor, options1.getExecutor()); + CallOptions options2 = options1.withExecutor(null); + assertSame(executor, options1.getExecutor()); + assertNull(options2.getExecutor()); + } + @Test public void testWithDeadlineAfter() { long deadline = CallOptions.DEFAULT diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index f6b01c1fde..a1e1410bba 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -88,6 +88,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicLong; @@ -314,6 +315,37 @@ public class ManagedChannelImplTest { transportListener.transportTerminated(); } + @Test + public void callOptionsExecutor() { + Metadata headers = new Metadata(); + ClientStream mockStream = mock(ClientStream.class); + when(mockTransport.newStream(same(method), same(headers))).thenReturn(mockStream); + + final List runnables = new ArrayList(); + Executor executor = new Executor() { + @Override + public void execute(Runnable r) { + runnables.add(r); + } + }; + ManagedChannel channel = createChannel( + new FakeNameResolverFactory(true), NO_INTERCEPTOR); + ClientCall call = + channel.newCall(method, CallOptions.DEFAULT.withExecutor(executor)); + call.start(mockCallListener, headers); + verify(mockTransport, timeout(1000)).newStream(same(method), same(headers)); + verify(mockStream).start(streamListenerCaptor.capture()); + ClientStreamListener streamListener = streamListenerCaptor.getValue(); + Metadata trailers = new Metadata(); + streamListener.closed(Status.CANCELLED, trailers); + assertFalse(runnables.isEmpty()); + verify(mockCallListener, never()).onClose(same(Status.CANCELLED), same(trailers)); + for (Runnable r : runnables) { + r.run(); + } + verify(mockCallListener).onClose(same(Status.CANCELLED), same(trailers)); + } + @Test public void nameResolutionFailed() { Status error = Status.UNAVAILABLE.withCause(new Throwable("fake name resolution error"));