Add static wrappers for Executor / ExecutorService using current cont… (#2988)

* Add static wrappers for Executor / ExecutorService using current context at invocation time.

* Missing javadoc

* Cleanup

* taskWrapping
This commit is contained in:
Anuraag Agrawal 2021-03-15 11:56:13 +09:00 committed by GitHub
parent 5f32df7b9b
commit 691e24fc19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 202 additions and 50 deletions

View File

@ -99,6 +99,36 @@ public interface Context {
return ArrayBasedContext.root();
}
/**
* Returns an {@link Executor} which delegates to the provided {@code executor}, wrapping all
* invocations of {@link Executor#execute(Runnable)} with the {@linkplain Context#current()
* current context} at the time of invocation.
*
* <p>This is generally used to create an {@link Executor} which will forward the {@link Context}
* during an invocation to another thread. For example, you may use something like {@code Executor
* dbExecutor = Context.wrapTasks(threadPool)} to ensure calls like {@code dbExecutor.execute(()
* -> database.query())} have {@link Context} available on the thread executing database queries.
*/
static Executor taskWrapping(Executor executor) {
return command -> executor.execute(Context.current().wrap(command));
}
/**
* Returns an {@link ExecutorService} which delegates to the provided {@code executorService},
* wrapping all invocations of {@link ExecutorService} methods such as {@link
* ExecutorService#execute(Runnable)} or {@link ExecutorService#submit(Runnable)} with the
* {@linkplain Context#current() current context} at the time of invocation.
*
* <p>This is generally used to create an {@link ExecutorService} which will forward the {@link
* Context} during an invocation to another thread. For example, you may use something like {@code
* ExecutorService dbExecutor = Context.wrapTasks(threadPool)} to ensure calls like {@code
* dbExecutor.execute(() -> database.query())} have {@link Context} available on the thread
* executing database queries.
*/
static ExecutorService taskWrapping(ExecutorService executorService) {
return new CurrentContextExecutorService(executorService);
}
/**
* Returns the value stored in this {@link Context} for the given {@link ContextKey}, or {@code
* null} if there is no value for the key in this context.

View File

@ -5,7 +5,6 @@
package io.opentelemetry.context;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
@ -15,99 +14,61 @@ import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
class ContextExecutorService implements ExecutorService {
class ContextExecutorService extends ForwardingExecutorService {
private final Context context;
private final ExecutorService delegate;
ContextExecutorService(Context context, ExecutorService delegate) {
super(delegate);
this.context = context;
this.delegate = delegate;
}
final Context context() {
return context;
}
ExecutorService delegate() {
return delegate;
}
@Override
public void shutdown() {
delegate.shutdown();
}
@Override
public List<Runnable> shutdownNow() {
return delegate.shutdownNow();
}
@Override
public boolean isShutdown() {
return delegate.isShutdown();
}
@Override
public boolean isTerminated() {
return delegate.isTerminated();
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return delegate.awaitTermination(timeout, unit);
}
@Override
public <T> Future<T> submit(Callable<T> task) {
return delegate.submit(context.wrap(task));
return delegate().submit(context.wrap(task));
}
@Override
public <T> Future<T> submit(Runnable task, T result) {
return delegate.submit(context.wrap(task), result);
return delegate().submit(context.wrap(task), result);
}
@Override
public Future<?> submit(Runnable task) {
return delegate.submit(context.wrap(task));
return delegate().submit(context.wrap(task));
}
@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks)
throws InterruptedException {
return delegate.invokeAll(wrap(tasks));
return delegate().invokeAll(wrap(context, tasks));
}
@Override
public <T> List<Future<T>> invokeAll(
Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
throws InterruptedException {
return delegate.invokeAll(wrap(tasks), timeout, unit);
return delegate().invokeAll(wrap(context, tasks), timeout, unit);
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks)
throws InterruptedException, ExecutionException {
return delegate.invokeAny(wrap(tasks));
return delegate().invokeAny(wrap(context, tasks));
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException, TimeoutException {
return delegate.invokeAny(wrap(tasks), timeout, unit);
return delegate().invokeAny(wrap(context, tasks), timeout, unit);
}
@Override
public void execute(Runnable command) {
delegate.execute(context.wrap(command));
}
private <T> Collection<? extends Callable<T>> wrap(Collection<? extends Callable<T>> tasks) {
List<Callable<T>> wrapped = new ArrayList<>();
for (Callable<T> task : tasks) {
wrapped.add(context.wrap(task));
}
return wrapped;
delegate().execute(context.wrap(command));
}
}

View File

@ -0,0 +1,67 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/
package io.opentelemetry.context;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
final class CurrentContextExecutorService extends ForwardingExecutorService {
CurrentContextExecutorService(ExecutorService delegate) {
super(delegate);
}
@Override
public <T> Future<T> submit(Callable<T> task) {
return delegate().submit(Context.current().wrap(task));
}
@Override
public <T> Future<T> submit(Runnable task, T result) {
return delegate().submit(Context.current().wrap(task), result);
}
@Override
public Future<?> submit(Runnable task) {
return delegate().submit(Context.current().wrap(task));
}
@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks)
throws InterruptedException {
return delegate().invokeAll(wrap(Context.current(), tasks));
}
@Override
public <T> List<Future<T>> invokeAll(
Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
throws InterruptedException {
return delegate().invokeAll(wrap(Context.current(), tasks), timeout, unit);
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks)
throws InterruptedException, ExecutionException {
return delegate().invokeAny(wrap(Context.current(), tasks));
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException, TimeoutException {
return delegate().invokeAny(wrap(Context.current(), tasks), timeout, unit);
}
@Override
public void execute(Runnable command) {
delegate().execute(Context.current().wrap(command));
}
}

View File

@ -0,0 +1,61 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/
package io.opentelemetry.context;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
/** A {@link ExecutorService} that implements methods that don't need {@link Context}. */
abstract class ForwardingExecutorService implements ExecutorService {
private final ExecutorService delegate;
protected ForwardingExecutorService(ExecutorService delegate) {
this.delegate = delegate;
}
ExecutorService delegate() {
return delegate;
}
@Override
public final void shutdown() {
delegate.shutdown();
}
@Override
public final List<Runnable> shutdownNow() {
return delegate.shutdownNow();
}
@Override
public final boolean isShutdown() {
return delegate.isShutdown();
}
@Override
public final boolean isTerminated() {
return delegate.isTerminated();
}
@Override
public final boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return delegate.awaitTermination(timeout, unit);
}
protected static <T> Collection<? extends Callable<T>> wrap(
Context context, Collection<? extends Callable<T>> tasks) {
List<Callable<T>> wrapped = new ArrayList<>();
for (Callable<T> task : tasks) {
wrapped.add(context.wrap(task));
}
return wrapped;
}
}

View File

@ -217,6 +217,11 @@ class ContextTest {
executor.execute(callback);
assertThat(value).hasValue(null);
try (Scope ignored = CAT.makeCurrent()) {
Context.taskWrapping(executor).execute(callback);
assertThat(value).hasValue("cat");
}
}
@Nested
@ -227,10 +232,14 @@ class ContextTest {
protected ExecutorService wrapped;
protected AtomicReference<String> value;
protected ExecutorService wrap(ExecutorService executorService) {
return CAT.wrap(executorService);
}
@BeforeAll
void initExecutor() {
executor = Executors.newSingleThreadScheduledExecutor();
wrapped = CAT.wrap((ExecutorService) executor);
wrapped = wrap(executor);
}
@AfterAll
@ -358,6 +367,30 @@ class ContextTest {
}
}
@Nested
@TestInstance(Lifecycle.PER_CLASS)
class CurrentContextWrappingExecutorService extends WrapExecutorService {
@Override
protected ExecutorService wrap(ExecutorService executorService) {
return Context.taskWrapping(executorService);
}
private Scope scope;
@BeforeEach
// Closed in AfterEach
@SuppressWarnings("MustBeClosedChecker")
void makeCurrent() {
scope = CAT.makeCurrent();
}
@AfterEach
void close() {
scope.close();
scope = null;
}
}
@Test
void keyToString() {
assertThat(ANIMAL.toString()).isEqualTo("animal");