From 691e24fc19f0caec88c98e05e455804dca2fc470 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Mon, 15 Mar 2021 11:56:13 +0900 Subject: [PATCH] =?UTF-8?q?Add=20static=20wrappers=20for=20Executor=20/=20?= =?UTF-8?q?ExecutorService=20using=20current=20cont=E2=80=A6=20(#2988)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add static wrappers for Executor / ExecutorService using current context at invocation time. * Missing javadoc * Cleanup * taskWrapping --- .../io/opentelemetry/context/Context.java | 30 +++++++++ .../context/ContextExecutorService.java | 59 +++------------- .../CurrentContextExecutorService.java | 67 +++++++++++++++++++ .../context/ForwardingExecutorService.java | 61 +++++++++++++++++ .../io/opentelemetry/context/ContextTest.java | 35 +++++++++- 5 files changed, 202 insertions(+), 50 deletions(-) create mode 100644 context/src/main/java/io/opentelemetry/context/CurrentContextExecutorService.java create mode 100644 context/src/main/java/io/opentelemetry/context/ForwardingExecutorService.java diff --git a/context/src/main/java/io/opentelemetry/context/Context.java b/context/src/main/java/io/opentelemetry/context/Context.java index 10f56b8875..6f3d9a82e5 100644 --- a/context/src/main/java/io/opentelemetry/context/Context.java +++ b/context/src/main/java/io/opentelemetry/context/Context.java @@ -99,6 +99,36 @@ public interface Context { return ArrayBasedContext.root(); } + /** + * Returns an {@link Executor} which delegates to the provided {@code executor}, wrapping all + * invocations of {@link Executor#execute(Runnable)} with the {@linkplain Context#current() + * current context} at the time of invocation. + * + *

This is generally used to create an {@link Executor} which will forward the {@link Context} + * during an invocation to another thread. For example, you may use something like {@code Executor + * dbExecutor = Context.wrapTasks(threadPool)} to ensure calls like {@code dbExecutor.execute(() + * -> database.query())} have {@link Context} available on the thread executing database queries. + */ + static Executor taskWrapping(Executor executor) { + return command -> executor.execute(Context.current().wrap(command)); + } + + /** + * Returns an {@link ExecutorService} which delegates to the provided {@code executorService}, + * wrapping all invocations of {@link ExecutorService} methods such as {@link + * ExecutorService#execute(Runnable)} or {@link ExecutorService#submit(Runnable)} with the + * {@linkplain Context#current() current context} at the time of invocation. + * + *

This is generally used to create an {@link ExecutorService} which will forward the {@link + * Context} during an invocation to another thread. For example, you may use something like {@code + * ExecutorService dbExecutor = Context.wrapTasks(threadPool)} to ensure calls like {@code + * dbExecutor.execute(() -> database.query())} have {@link Context} available on the thread + * executing database queries. + */ + static ExecutorService taskWrapping(ExecutorService executorService) { + return new CurrentContextExecutorService(executorService); + } + /** * Returns the value stored in this {@link Context} for the given {@link ContextKey}, or {@code * null} if there is no value for the key in this context. diff --git a/context/src/main/java/io/opentelemetry/context/ContextExecutorService.java b/context/src/main/java/io/opentelemetry/context/ContextExecutorService.java index dd49e8c0c2..7b9153d33f 100644 --- a/context/src/main/java/io/opentelemetry/context/ContextExecutorService.java +++ b/context/src/main/java/io/opentelemetry/context/ContextExecutorService.java @@ -5,7 +5,6 @@ package io.opentelemetry.context; -import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.concurrent.Callable; @@ -15,99 +14,61 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -class ContextExecutorService implements ExecutorService { +class ContextExecutorService extends ForwardingExecutorService { private final Context context; - private final ExecutorService delegate; ContextExecutorService(Context context, ExecutorService delegate) { + super(delegate); this.context = context; - this.delegate = delegate; } final Context context() { return context; } - ExecutorService delegate() { - return delegate; - } - - @Override - public void shutdown() { - delegate.shutdown(); - } - - @Override - public List shutdownNow() { - return delegate.shutdownNow(); - } - - @Override - public boolean isShutdown() { - return delegate.isShutdown(); - } - - @Override - public boolean isTerminated() { - return delegate.isTerminated(); - } - - @Override - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - return delegate.awaitTermination(timeout, unit); - } - @Override public Future submit(Callable task) { - return delegate.submit(context.wrap(task)); + return delegate().submit(context.wrap(task)); } @Override public Future submit(Runnable task, T result) { - return delegate.submit(context.wrap(task), result); + return delegate().submit(context.wrap(task), result); } @Override public Future submit(Runnable task) { - return delegate.submit(context.wrap(task)); + return delegate().submit(context.wrap(task)); } @Override public List> invokeAll(Collection> tasks) throws InterruptedException { - return delegate.invokeAll(wrap(tasks)); + return delegate().invokeAll(wrap(context, tasks)); } @Override public List> invokeAll( Collection> tasks, long timeout, TimeUnit unit) throws InterruptedException { - return delegate.invokeAll(wrap(tasks), timeout, unit); + return delegate().invokeAll(wrap(context, tasks), timeout, unit); } @Override public T invokeAny(Collection> tasks) throws InterruptedException, ExecutionException { - return delegate.invokeAny(wrap(tasks)); + return delegate().invokeAny(wrap(context, tasks)); } @Override public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { - return delegate.invokeAny(wrap(tasks), timeout, unit); + return delegate().invokeAny(wrap(context, tasks), timeout, unit); } @Override public void execute(Runnable command) { - delegate.execute(context.wrap(command)); - } - - private Collection> wrap(Collection> tasks) { - List> wrapped = new ArrayList<>(); - for (Callable task : tasks) { - wrapped.add(context.wrap(task)); - } - return wrapped; + delegate().execute(context.wrap(command)); } } diff --git a/context/src/main/java/io/opentelemetry/context/CurrentContextExecutorService.java b/context/src/main/java/io/opentelemetry/context/CurrentContextExecutorService.java new file mode 100644 index 0000000000..346b849878 --- /dev/null +++ b/context/src/main/java/io/opentelemetry/context/CurrentContextExecutorService.java @@ -0,0 +1,67 @@ +/* + * Copyright The OpenTelemetry Authors + * SPDX-License-Identifier: Apache-2.0 + */ + +package io.opentelemetry.context; + +import java.util.Collection; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +final class CurrentContextExecutorService extends ForwardingExecutorService { + + CurrentContextExecutorService(ExecutorService delegate) { + super(delegate); + } + + @Override + public Future submit(Callable task) { + return delegate().submit(Context.current().wrap(task)); + } + + @Override + public Future submit(Runnable task, T result) { + return delegate().submit(Context.current().wrap(task), result); + } + + @Override + public Future submit(Runnable task) { + return delegate().submit(Context.current().wrap(task)); + } + + @Override + public List> invokeAll(Collection> tasks) + throws InterruptedException { + return delegate().invokeAll(wrap(Context.current(), tasks)); + } + + @Override + public List> invokeAll( + Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException { + return delegate().invokeAll(wrap(Context.current(), tasks), timeout, unit); + } + + @Override + public T invokeAny(Collection> tasks) + throws InterruptedException, ExecutionException { + return delegate().invokeAny(wrap(Context.current(), tasks)); + } + + @Override + public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return delegate().invokeAny(wrap(Context.current(), tasks), timeout, unit); + } + + @Override + public void execute(Runnable command) { + delegate().execute(Context.current().wrap(command)); + } +} diff --git a/context/src/main/java/io/opentelemetry/context/ForwardingExecutorService.java b/context/src/main/java/io/opentelemetry/context/ForwardingExecutorService.java new file mode 100644 index 0000000000..b035a1bf5c --- /dev/null +++ b/context/src/main/java/io/opentelemetry/context/ForwardingExecutorService.java @@ -0,0 +1,61 @@ +/* + * Copyright The OpenTelemetry Authors + * SPDX-License-Identifier: Apache-2.0 + */ + +package io.opentelemetry.context; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +/** A {@link ExecutorService} that implements methods that don't need {@link Context}. */ +abstract class ForwardingExecutorService implements ExecutorService { + + private final ExecutorService delegate; + + protected ForwardingExecutorService(ExecutorService delegate) { + this.delegate = delegate; + } + + ExecutorService delegate() { + return delegate; + } + + @Override + public final void shutdown() { + delegate.shutdown(); + } + + @Override + public final List shutdownNow() { + return delegate.shutdownNow(); + } + + @Override + public final boolean isShutdown() { + return delegate.isShutdown(); + } + + @Override + public final boolean isTerminated() { + return delegate.isTerminated(); + } + + @Override + public final boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return delegate.awaitTermination(timeout, unit); + } + + protected static Collection> wrap( + Context context, Collection> tasks) { + List> wrapped = new ArrayList<>(); + for (Callable task : tasks) { + wrapped.add(context.wrap(task)); + } + return wrapped; + } +} diff --git a/context/src/test/java/io/opentelemetry/context/ContextTest.java b/context/src/test/java/io/opentelemetry/context/ContextTest.java index 60aec8d1d3..f42c830f72 100644 --- a/context/src/test/java/io/opentelemetry/context/ContextTest.java +++ b/context/src/test/java/io/opentelemetry/context/ContextTest.java @@ -217,6 +217,11 @@ class ContextTest { executor.execute(callback); assertThat(value).hasValue(null); + + try (Scope ignored = CAT.makeCurrent()) { + Context.taskWrapping(executor).execute(callback); + assertThat(value).hasValue("cat"); + } } @Nested @@ -227,10 +232,14 @@ class ContextTest { protected ExecutorService wrapped; protected AtomicReference value; + protected ExecutorService wrap(ExecutorService executorService) { + return CAT.wrap(executorService); + } + @BeforeAll void initExecutor() { executor = Executors.newSingleThreadScheduledExecutor(); - wrapped = CAT.wrap((ExecutorService) executor); + wrapped = wrap(executor); } @AfterAll @@ -358,6 +367,30 @@ class ContextTest { } } + @Nested + @TestInstance(Lifecycle.PER_CLASS) + class CurrentContextWrappingExecutorService extends WrapExecutorService { + @Override + protected ExecutorService wrap(ExecutorService executorService) { + return Context.taskWrapping(executorService); + } + + private Scope scope; + + @BeforeEach + // Closed in AfterEach + @SuppressWarnings("MustBeClosedChecker") + void makeCurrent() { + scope = CAT.makeCurrent(); + } + + @AfterEach + void close() { + scope.close(); + scope = null; + } + } + @Test void keyToString() { assertThat(ANIMAL.toString()).isEqualTo("animal");