diff --git a/instrumentation/cassandra/cassandra-4.0/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/cassandra/v4_0/TracingCqlSession.java b/instrumentation/cassandra/cassandra-4.0/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/cassandra/v4_0/TracingCqlSession.java index b6126c51f0..676868dffd 100644 --- a/instrumentation/cassandra/cassandra-4.0/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/cassandra/v4_0/TracingCqlSession.java +++ b/instrumentation/cassandra/cassandra-4.0/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/cassandra/v4_0/TracingCqlSession.java @@ -26,7 +26,9 @@ import com.datastax.oss.driver.api.core.type.reflect.GenericType; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.function.Supplier; import org.checkerframework.checker.nullness.qual.Nullable; public class TracingCqlSession implements CqlSession { @@ -188,29 +190,50 @@ public class TracingCqlSession implements CqlSession { public CompletionStage executeAsync(Statement statement) { String query = getQuery(statement); CassandraRequest request = CassandraRequest.create(session, query); - Context context = instrumenter().start(Context.current(), request); - try (Scope ignored = context.makeCurrent()) { - CompletionStage stage = session.executeAsync(statement); - return stage.whenComplete( - (asyncResultSet, throwable) -> - instrumenter() - .end(context, request, getExecutionInfo(asyncResultSet, throwable), throwable)); - } + return executeAsync(request, () -> session.executeAsync(statement)); } @Override public CompletionStage executeAsync(String query) { CassandraRequest request = CassandraRequest.create(session, query); - Context context = instrumenter().start(Context.current(), request); + return executeAsync(request, () -> session.executeAsync(query)); + } + + private static CompletionStage executeAsync( + CassandraRequest request, Supplier> query) { + Context parentContext = Context.current(); + Context context = instrumenter().start(parentContext, request); try (Scope ignored = context.makeCurrent()) { - CompletionStage stage = session.executeAsync(query); - return stage.whenComplete( - (asyncResultSet, throwable) -> - instrumenter() - .end(context, request, getExecutionInfo(asyncResultSet, throwable), throwable)); + CompletionStage stage = query.get(); + return wrap( + stage.whenComplete( + (asyncResultSet, throwable) -> + instrumenter() + .end( + context, + request, + getExecutionInfo(asyncResultSet, throwable), + throwable)), + parentContext); } } + static CompletableFuture wrap(CompletionStage future, Context context) { + CompletableFuture result = new CompletableFuture<>(); + future.whenComplete( + (T value, Throwable throwable) -> { + try (Scope ignored = context.makeCurrent()) { + if (throwable != null) { + result.completeExceptionally(throwable); + } else { + result.complete(value); + } + } + }); + + return result; + } + private static String getQuery(Statement statement) { String query = null; if (statement instanceof SimpleStatement) { diff --git a/instrumentation/cassandra/cassandra-4.0/javaagent/src/test/groovy/CassandraClientTest.groovy b/instrumentation/cassandra/cassandra-4.0/javaagent/src/test/groovy/CassandraClientTest.groovy index 338d9dd2b4..ecda5fee56 100644 --- a/instrumentation/cassandra/cassandra-4.0/javaagent/src/test/groovy/CassandraClientTest.groovy +++ b/instrumentation/cassandra/cassandra-4.0/javaagent/src/test/groovy/CassandraClientTest.groovy @@ -74,14 +74,17 @@ class CassandraClientTest extends AgentInstrumentationSpecification { CqlSession session = getSession(keyspace) runUnderTrace("parent") { - session.executeAsync(statement).toCompletableFuture().get() + session.executeAsync(statement).toCompletableFuture().whenComplete({result, throwable -> + runUnderTrace("child") {} + }) .get() } expect: assertTraces(1) { - trace(0, 2) { + trace(0, 3) { basicSpan(it, 0, "parent") cassandraSpan(it, 1, spanName, expectedStatement, operation, keyspace, table, span(0)) + basicSpan(it, 2, "child", span(0)) } }