Add context propagation for rector schedulers (#10311)

This commit is contained in:
Lauri Tulmin 2024-01-24 09:11:47 +02:00 committed by GitHub
parent 0949ae27c9
commit 9259ce828a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 87 additions and 2 deletions

View File

@ -31,6 +31,8 @@ import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodHandles;
import java.util.function.BiFunction; import java.util.function.BiFunction;
import java.util.function.Function; import java.util.function.Function;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import reactor.core.CoreSubscriber; import reactor.core.CoreSubscriber;
@ -40,9 +42,11 @@ import reactor.core.publisher.Flux;
import reactor.core.publisher.Hooks; import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators; import reactor.core.publisher.Operators;
import reactor.core.scheduler.Schedulers;
/** Based on Spring Sleuth's Reactor instrumentation. */ /** Based on Spring Sleuth's Reactor instrumentation. */
public final class ContextPropagationOperator { public final class ContextPropagationOperator {
private static final Logger logger = Logger.getLogger(ContextPropagationOperator.class.getName());
private static final Object VALUE = new Object(); private static final Object VALUE = new Object();
@ -52,6 +56,8 @@ public final class ContextPropagationOperator {
@Nullable @Nullable
private static final MethodHandle FLUX_CONTEXT_WRITE_METHOD = getContextWriteMethod(Flux.class); private static final MethodHandle FLUX_CONTEXT_WRITE_METHOD = getContextWriteMethod(Flux.class);
@Nullable private static final MethodHandle SCHEDULERS_HOOK_METHOD = getSchedulersHookMethod();
@Nullable @Nullable
private static MethodHandle getContextWriteMethod(Class<?> type) { private static MethodHandle getContextWriteMethod(Class<?> type) {
MethodHandles.Lookup lookup = MethodHandles.publicLookup(); MethodHandles.Lookup lookup = MethodHandles.publicLookup();
@ -68,6 +74,18 @@ public final class ContextPropagationOperator {
return null; return null;
} }
@Nullable
private static MethodHandle getSchedulersHookMethod() {
MethodHandles.Lookup lookup = MethodHandles.publicLookup();
try {
return lookup.findStatic(
Schedulers.class, "onScheduleHook", methodType(void.class, String.class, Function.class));
} catch (NoSuchMethodException | IllegalAccessException e) {
// ignore
}
return null;
}
public static ContextPropagationOperator create() { public static ContextPropagationOperator create() {
return builder().build(); return builder().build();
} }
@ -137,10 +155,22 @@ public final class ContextPropagationOperator {
Hooks.onEachOperator( Hooks.onEachOperator(
TracingSubscriber.class.getName(), tracingLift(asyncOperationEndStrategy)); TracingSubscriber.class.getName(), tracingLift(asyncOperationEndStrategy));
AsyncOperationEndStrategies.instance().registerStrategy(asyncOperationEndStrategy); AsyncOperationEndStrategies.instance().registerStrategy(asyncOperationEndStrategy);
registerScheduleHook(RunnableWrapper.class.getName(), RunnableWrapper::new);
enabled = true; enabled = true;
} }
} }
private static void registerScheduleHook(String key, Function<Runnable, Runnable> function) {
if (SCHEDULERS_HOOK_METHOD == null) {
return;
}
try {
SCHEDULERS_HOOK_METHOD.invoke(key, function);
} catch (Throwable throwable) {
logger.log(Level.WARNING, "Failed to install scheduler hook", throwable);
}
}
/** Unregisters the hook registered by {@link #registerOnEachOperator()}. */ /** Unregisters the hook registered by {@link #registerOnEachOperator()}. */
public void resetOnEachOperator() { public void resetOnEachOperator() {
synchronized (lock) { synchronized (lock) {
@ -312,4 +342,21 @@ public final class ContextPropagationOperator {
return null; return null;
} }
} }
private static class RunnableWrapper implements Runnable {
private final Runnable delegate;
private final Context context;
RunnableWrapper(Runnable delegate) {
this.delegate = delegate;
context = Context.current();
}
@Override
public void run() {
try (Scope ignore = context.makeCurrent()) {
delegate.run();
}
}
}
} }

View File

@ -206,7 +206,15 @@ public class SpringWebfluxTest {
"/foo-delayed", "/foo-delayed",
"/foo-delayed", "/foo-delayed",
"getFooDelayed", "getFooDelayed",
new FooModel(3L, "delayed").toString())))); new FooModel(3L, "delayed").toString()))),
Arguments.of(
named(
"annotation API without parameters no mono",
new Parameter(
"/foo-no-mono",
"/foo-no-mono",
"getFooModelNoMono",
new FooModel(0L, "DEFAULT").toString()))));
} }
@ParameterizedTest(name = "{index}: {0}") @ParameterizedTest(name = "{index}: {0}")

View File

@ -63,6 +63,11 @@ public class TestController {
return Mono.just(id).delayElement(Duration.ofMillis(100)).map(TestController::tracedMethod); return Mono.just(id).delayElement(Duration.ofMillis(100)).map(TestController::tracedMethod);
} }
@GetMapping("/foo-no-mono")
FooModel getFooModelNoMono() {
return new FooModel(0L, "DEFAULT");
}
private static FooModel tracedMethod(long id) { private static FooModel tracedMethod(long id) {
tracer.spanBuilder("tracedMethod").startSpan().end(); tracer.spanBuilder("tracedMethod").startSpan().end();
return new FooModel(id, "tracedMethod"); return new FooModel(id, "tracedMethod");

View File

@ -5,11 +5,17 @@
package io.opentelemetry.instrumentation.spring.webflux.v5_3; package io.opentelemetry.instrumentation.spring.webflux.v5_3;
import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.SUCCESS;
import static org.assertj.core.api.Assertions.assertThat;
import io.opentelemetry.instrumentation.testing.junit.InstrumentationExtension; import io.opentelemetry.instrumentation.testing.junit.InstrumentationExtension;
import io.opentelemetry.instrumentation.testing.junit.http.AbstractHttpServerTest; import io.opentelemetry.instrumentation.testing.junit.http.AbstractHttpServerTest;
import io.opentelemetry.instrumentation.testing.junit.http.HttpServerInstrumentationExtension; import io.opentelemetry.instrumentation.testing.junit.http.HttpServerInstrumentationExtension;
import io.opentelemetry.instrumentation.testing.junit.http.HttpServerTestOptions; import io.opentelemetry.instrumentation.testing.junit.http.HttpServerTestOptions;
import io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint; import io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint;
import io.opentelemetry.testing.internal.armeria.common.AggregatedHttpRequest;
import io.opentelemetry.testing.internal.armeria.common.AggregatedHttpResponse;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.ConfigurableApplicationContext;
@ -47,4 +53,17 @@ public final class SpringWebfluxServerInstrumentationTest
options.disableTestNonStandardHttpMethod(); options.disableTestNonStandardHttpMethod();
} }
@Test
void noMono() {
ServerEndpoint endpoint = new ServerEndpoint("NO_MONO", "no-mono", 200, "success");
String method = "GET";
AggregatedHttpRequest request = request(endpoint, method);
AggregatedHttpResponse response = client.execute(request).aggregate().join();
assertThat(response.status().code()).isEqualTo(SUCCESS.getStatus());
assertThat(response.contentUtf8()).isEqualTo(SUCCESS.getBody());
assertTheTraces(1, null, null, null, method, endpoint);
}
} }

View File

@ -57,7 +57,7 @@ class TestWebfluxSpringBootApp {
.setCapturedServerResponseHeaders( .setCapturedServerResponseHeaders(
singletonList(AbstractHttpServerTest.TEST_RESPONSE_HEADER)) singletonList(AbstractHttpServerTest.TEST_RESPONSE_HEADER))
.build() .build()
.createWebFilter(); .createWebFilterAndRegisterReactorHook();
} }
@Controller @Controller
@ -69,6 +69,12 @@ class TestWebfluxSpringBootApp {
return Flux.defer(() -> Flux.just(controller(SUCCESS, SUCCESS::getBody))); return Flux.defer(() -> Flux.just(controller(SUCCESS, SUCCESS::getBody)));
} }
@RequestMapping("/no-mono")
@ResponseBody
String noMono() {
return controller(SUCCESS, SUCCESS::getBody);
}
@RequestMapping("/query") @RequestMapping("/query")
@ResponseBody @ResponseBody
Mono<String> query_param(@RequestParam("some") String param) { Mono<String> query_param(@RequestParam("some") String param) {