diff --git a/context/src/main/java/io/opentelemetry/context/Context.java b/context/src/main/java/io/opentelemetry/context/Context.java
index 10f56b8875..6f3d9a82e5 100644
--- a/context/src/main/java/io/opentelemetry/context/Context.java
+++ b/context/src/main/java/io/opentelemetry/context/Context.java
@@ -99,6 +99,36 @@ public interface Context {
return ArrayBasedContext.root();
}
+ /**
+ * Returns an {@link Executor} which delegates to the provided {@code executor}, wrapping all
+ * invocations of {@link Executor#execute(Runnable)} with the {@linkplain Context#current()
+ * current context} at the time of invocation.
+ *
+ *
This is generally used to create an {@link Executor} which will forward the {@link Context}
+ * during an invocation to another thread. For example, you may use something like {@code Executor
+ * dbExecutor = Context.wrapTasks(threadPool)} to ensure calls like {@code dbExecutor.execute(()
+ * -> database.query())} have {@link Context} available on the thread executing database queries.
+ */
+ static Executor taskWrapping(Executor executor) {
+ return command -> executor.execute(Context.current().wrap(command));
+ }
+
+ /**
+ * Returns an {@link ExecutorService} which delegates to the provided {@code executorService},
+ * wrapping all invocations of {@link ExecutorService} methods such as {@link
+ * ExecutorService#execute(Runnable)} or {@link ExecutorService#submit(Runnable)} with the
+ * {@linkplain Context#current() current context} at the time of invocation.
+ *
+ *
This is generally used to create an {@link ExecutorService} which will forward the {@link
+ * Context} during an invocation to another thread. For example, you may use something like {@code
+ * ExecutorService dbExecutor = Context.wrapTasks(threadPool)} to ensure calls like {@code
+ * dbExecutor.execute(() -> database.query())} have {@link Context} available on the thread
+ * executing database queries.
+ */
+ static ExecutorService taskWrapping(ExecutorService executorService) {
+ return new CurrentContextExecutorService(executorService);
+ }
+
/**
* Returns the value stored in this {@link Context} for the given {@link ContextKey}, or {@code
* null} if there is no value for the key in this context.
diff --git a/context/src/main/java/io/opentelemetry/context/ContextExecutorService.java b/context/src/main/java/io/opentelemetry/context/ContextExecutorService.java
index dd49e8c0c2..7b9153d33f 100644
--- a/context/src/main/java/io/opentelemetry/context/ContextExecutorService.java
+++ b/context/src/main/java/io/opentelemetry/context/ContextExecutorService.java
@@ -5,7 +5,6 @@
package io.opentelemetry.context;
-import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
@@ -15,99 +14,61 @@ import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
-class ContextExecutorService implements ExecutorService {
+class ContextExecutorService extends ForwardingExecutorService {
private final Context context;
- private final ExecutorService delegate;
ContextExecutorService(Context context, ExecutorService delegate) {
+ super(delegate);
this.context = context;
- this.delegate = delegate;
}
final Context context() {
return context;
}
- ExecutorService delegate() {
- return delegate;
- }
-
- @Override
- public void shutdown() {
- delegate.shutdown();
- }
-
- @Override
- public List shutdownNow() {
- return delegate.shutdownNow();
- }
-
- @Override
- public boolean isShutdown() {
- return delegate.isShutdown();
- }
-
- @Override
- public boolean isTerminated() {
- return delegate.isTerminated();
- }
-
- @Override
- public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
- return delegate.awaitTermination(timeout, unit);
- }
-
@Override
public Future submit(Callable task) {
- return delegate.submit(context.wrap(task));
+ return delegate().submit(context.wrap(task));
}
@Override
public Future submit(Runnable task, T result) {
- return delegate.submit(context.wrap(task), result);
+ return delegate().submit(context.wrap(task), result);
}
@Override
public Future> submit(Runnable task) {
- return delegate.submit(context.wrap(task));
+ return delegate().submit(context.wrap(task));
}
@Override
public List> invokeAll(Collection extends Callable> tasks)
throws InterruptedException {
- return delegate.invokeAll(wrap(tasks));
+ return delegate().invokeAll(wrap(context, tasks));
}
@Override
public List> invokeAll(
Collection extends Callable> tasks, long timeout, TimeUnit unit)
throws InterruptedException {
- return delegate.invokeAll(wrap(tasks), timeout, unit);
+ return delegate().invokeAll(wrap(context, tasks), timeout, unit);
}
@Override
public T invokeAny(Collection extends Callable> tasks)
throws InterruptedException, ExecutionException {
- return delegate.invokeAny(wrap(tasks));
+ return delegate().invokeAny(wrap(context, tasks));
}
@Override
public T invokeAny(Collection extends Callable> tasks, long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException, TimeoutException {
- return delegate.invokeAny(wrap(tasks), timeout, unit);
+ return delegate().invokeAny(wrap(context, tasks), timeout, unit);
}
@Override
public void execute(Runnable command) {
- delegate.execute(context.wrap(command));
- }
-
- private Collection extends Callable> wrap(Collection extends Callable> tasks) {
- List> wrapped = new ArrayList<>();
- for (Callable task : tasks) {
- wrapped.add(context.wrap(task));
- }
- return wrapped;
+ delegate().execute(context.wrap(command));
}
}
diff --git a/context/src/main/java/io/opentelemetry/context/CurrentContextExecutorService.java b/context/src/main/java/io/opentelemetry/context/CurrentContextExecutorService.java
new file mode 100644
index 0000000000..346b849878
--- /dev/null
+++ b/context/src/main/java/io/opentelemetry/context/CurrentContextExecutorService.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright The OpenTelemetry Authors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package io.opentelemetry.context;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+final class CurrentContextExecutorService extends ForwardingExecutorService {
+
+ CurrentContextExecutorService(ExecutorService delegate) {
+ super(delegate);
+ }
+
+ @Override
+ public Future submit(Callable task) {
+ return delegate().submit(Context.current().wrap(task));
+ }
+
+ @Override
+ public Future submit(Runnable task, T result) {
+ return delegate().submit(Context.current().wrap(task), result);
+ }
+
+ @Override
+ public Future> submit(Runnable task) {
+ return delegate().submit(Context.current().wrap(task));
+ }
+
+ @Override
+ public List> invokeAll(Collection extends Callable> tasks)
+ throws InterruptedException {
+ return delegate().invokeAll(wrap(Context.current(), tasks));
+ }
+
+ @Override
+ public List> invokeAll(
+ Collection extends Callable> tasks, long timeout, TimeUnit unit)
+ throws InterruptedException {
+ return delegate().invokeAll(wrap(Context.current(), tasks), timeout, unit);
+ }
+
+ @Override
+ public T invokeAny(Collection extends Callable> tasks)
+ throws InterruptedException, ExecutionException {
+ return delegate().invokeAny(wrap(Context.current(), tasks));
+ }
+
+ @Override
+ public T invokeAny(Collection extends Callable> tasks, long timeout, TimeUnit unit)
+ throws InterruptedException, ExecutionException, TimeoutException {
+ return delegate().invokeAny(wrap(Context.current(), tasks), timeout, unit);
+ }
+
+ @Override
+ public void execute(Runnable command) {
+ delegate().execute(Context.current().wrap(command));
+ }
+}
diff --git a/context/src/main/java/io/opentelemetry/context/ForwardingExecutorService.java b/context/src/main/java/io/opentelemetry/context/ForwardingExecutorService.java
new file mode 100644
index 0000000000..b035a1bf5c
--- /dev/null
+++ b/context/src/main/java/io/opentelemetry/context/ForwardingExecutorService.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright The OpenTelemetry Authors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package io.opentelemetry.context;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link ExecutorService} that implements methods that don't need {@link Context}. */
+abstract class ForwardingExecutorService implements ExecutorService {
+
+ private final ExecutorService delegate;
+
+ protected ForwardingExecutorService(ExecutorService delegate) {
+ this.delegate = delegate;
+ }
+
+ ExecutorService delegate() {
+ return delegate;
+ }
+
+ @Override
+ public final void shutdown() {
+ delegate.shutdown();
+ }
+
+ @Override
+ public final List shutdownNow() {
+ return delegate.shutdownNow();
+ }
+
+ @Override
+ public final boolean isShutdown() {
+ return delegate.isShutdown();
+ }
+
+ @Override
+ public final boolean isTerminated() {
+ return delegate.isTerminated();
+ }
+
+ @Override
+ public final boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
+ return delegate.awaitTermination(timeout, unit);
+ }
+
+ protected static Collection extends Callable> wrap(
+ Context context, Collection extends Callable> tasks) {
+ List> wrapped = new ArrayList<>();
+ for (Callable task : tasks) {
+ wrapped.add(context.wrap(task));
+ }
+ return wrapped;
+ }
+}
diff --git a/context/src/test/java/io/opentelemetry/context/ContextTest.java b/context/src/test/java/io/opentelemetry/context/ContextTest.java
index 60aec8d1d3..f42c830f72 100644
--- a/context/src/test/java/io/opentelemetry/context/ContextTest.java
+++ b/context/src/test/java/io/opentelemetry/context/ContextTest.java
@@ -217,6 +217,11 @@ class ContextTest {
executor.execute(callback);
assertThat(value).hasValue(null);
+
+ try (Scope ignored = CAT.makeCurrent()) {
+ Context.taskWrapping(executor).execute(callback);
+ assertThat(value).hasValue("cat");
+ }
}
@Nested
@@ -227,10 +232,14 @@ class ContextTest {
protected ExecutorService wrapped;
protected AtomicReference value;
+ protected ExecutorService wrap(ExecutorService executorService) {
+ return CAT.wrap(executorService);
+ }
+
@BeforeAll
void initExecutor() {
executor = Executors.newSingleThreadScheduledExecutor();
- wrapped = CAT.wrap((ExecutorService) executor);
+ wrapped = wrap(executor);
}
@AfterAll
@@ -358,6 +367,30 @@ class ContextTest {
}
}
+ @Nested
+ @TestInstance(Lifecycle.PER_CLASS)
+ class CurrentContextWrappingExecutorService extends WrapExecutorService {
+ @Override
+ protected ExecutorService wrap(ExecutorService executorService) {
+ return Context.taskWrapping(executorService);
+ }
+
+ private Scope scope;
+
+ @BeforeEach
+ // Closed in AfterEach
+ @SuppressWarnings("MustBeClosedChecker")
+ void makeCurrent() {
+ scope = CAT.makeCurrent();
+ }
+
+ @AfterEach
+ void close() {
+ scope.close();
+ scope = null;
+ }
+ }
+
@Test
void keyToString() {
assertThat(ANIMAL.toString()).isEqualTo("animal");