Handle async requests in srping mvc library instrumentation (#10868)

This commit is contained in:
Lauri Tulmin 2024-03-20 14:04:53 +02:00 committed by GitHub
parent e4e224beff
commit 4c45e94098
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 282 additions and 4 deletions

View File

@ -13,9 +13,16 @@ import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.api.instrumenter.Instrumenter;
import io.opentelemetry.instrumentation.api.semconv.http.HttpServerRoute;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.servlet.AsyncContext;
import javax.servlet.AsyncEvent;
import javax.servlet.AsyncListener;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.springframework.core.Ordered;
import org.springframework.web.filter.OncePerRequestFilter;
@ -53,9 +60,11 @@ final class WebMvcTelemetryProducingFilter extends OncePerRequestFilter implemen
}
Context context = instrumenter.start(parentContext, request);
AsyncAwareHttpServletRequest asyncAwareRequest =
new AsyncAwareHttpServletRequest(request, response, context);
Throwable error = null;
try (Scope ignored = context.makeCurrent()) {
filterChain.doFilter(request, response);
filterChain.doFilter(asyncAwareRequest, response);
} catch (Throwable t) {
error = t;
throw t;
@ -63,7 +72,9 @@ final class WebMvcTelemetryProducingFilter extends OncePerRequestFilter implemen
if (httpRouteSupport.hasMappings()) {
HttpServerRoute.update(context, CONTROLLER, httpRouteSupport::getHttpRoute, request);
}
instrumenter.end(context, request, response, error);
if (error != null || asyncAwareRequest.isNotAsync()) {
instrumenter.end(context, request, response, error);
}
}
}
@ -75,4 +86,88 @@ final class WebMvcTelemetryProducingFilter extends OncePerRequestFilter implemen
// Run after all HIGHEST_PRECEDENCE items
return Ordered.HIGHEST_PRECEDENCE + 1;
}
private class AsyncAwareHttpServletRequest extends HttpServletRequestWrapper {
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Context context;
private final AtomicBoolean listenerAttached = new AtomicBoolean();
AsyncAwareHttpServletRequest(
HttpServletRequest request, HttpServletResponse response, Context context) {
super(request);
this.request = request;
this.response = response;
this.context = context;
}
@Override
public AsyncContext startAsync() {
AsyncContext asyncContext = super.startAsync();
attachListener(asyncContext);
return asyncContext;
}
@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) {
AsyncContext asyncContext = super.startAsync(servletRequest, servletResponse);
attachListener(asyncContext);
return asyncContext;
}
private void attachListener(AsyncContext asyncContext) {
if (!listenerAttached.compareAndSet(false, true)) {
return;
}
asyncContext.addListener(
new AsyncRequestCompletionListener(request, response, context), request, response);
}
boolean isNotAsync() {
return !listenerAttached.get();
}
}
private class AsyncRequestCompletionListener implements AsyncListener {
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Context context;
private final AtomicBoolean responseHandled = new AtomicBoolean();
AsyncRequestCompletionListener(
HttpServletRequest request, HttpServletResponse response, Context context) {
this.request = request;
this.response = response;
this.context = context;
}
@Override
public void onComplete(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, null);
}
}
@Override
public void onTimeout(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, null);
}
}
@Override
public void onError(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, asyncEvent.getThrowable());
}
}
@Override
public void onStartAsync(AsyncEvent asyncEvent) {
asyncEvent
.getAsyncContext()
.addListener(this, asyncEvent.getSuppliedRequest(), asyncEvent.getSuppliedResponse());
}
}
}

View File

@ -17,8 +17,12 @@ import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint
import static java.util.Collections.singletonList;
import io.opentelemetry.api.GlobalOpenTelemetry;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.testing.junit.http.AbstractHttpServerTest;
import io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import javax.servlet.Filter;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@ -37,6 +41,7 @@ import org.springframework.web.servlet.view.RedirectView;
@SpringBootApplication
class TestWebSpringBootApp {
static final ServerEndpoint ASYNC_ENDPOINT = new ServerEndpoint("ASYNC", "async", 200, "success");
static ConfigurableApplicationContext start(int port, String contextPath) {
Properties props = new Properties();
@ -122,6 +127,26 @@ class TestWebSpringBootApp {
});
}
@RequestMapping("/async")
@ResponseBody
CompletableFuture<String> async() {
Context context = Context.current();
return CompletableFuture.supplyAsync(
() -> {
// Sleep a bit so that the future completes after the controller method. This helps to
// verify whether request ends after the future has completed not after when the
// controller method has completed.
try {
Thread.sleep(100);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
try (Scope ignored = context.makeCurrent()) {
return controller(ASYNC_ENDPOINT, ASYNC_ENDPOINT::getBody);
}
});
}
@ExceptionHandler
ResponseEntity<String> handleException(Throwable throwable) {
return new ResponseEntity<>(throwable.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR);

View File

@ -5,12 +5,17 @@
package io.opentelemetry.instrumentation.spring.webmvc.v5_3;
import static org.assertj.core.api.Assertions.assertThat;
import io.opentelemetry.instrumentation.api.internal.HttpConstants;
import io.opentelemetry.instrumentation.testing.junit.InstrumentationExtension;
import io.opentelemetry.instrumentation.testing.junit.http.AbstractHttpServerTest;
import io.opentelemetry.instrumentation.testing.junit.http.HttpServerInstrumentationExtension;
import io.opentelemetry.instrumentation.testing.junit.http.HttpServerTestOptions;
import io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint;
import io.opentelemetry.testing.internal.armeria.common.AggregatedHttpRequest;
import io.opentelemetry.testing.internal.armeria.common.AggregatedHttpResponse;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.springframework.context.ConfigurableApplicationContext;
@ -49,4 +54,18 @@ class WebMvcHttpServerTest extends AbstractHttpServerTest<ConfigurableApplicatio
return expectedHttpRoute(endpoint, method);
});
}
@Test
void asyncRequest() {
ServerEndpoint endpoint = TestWebSpringBootApp.ASYNC_ENDPOINT;
String method = "GET";
AggregatedHttpRequest request = request(endpoint, method);
AggregatedHttpResponse response = client.execute(request).aggregate().join();
assertThat(response.status().code()).isEqualTo(endpoint.getStatus());
assertThat(response.contentUtf8()).isEqualTo(endpoint.getBody());
String spanId = assertResponseHasCustomizedHeaders(response, endpoint, null);
assertTheTraces(1, null, null, spanId, method, endpoint);
}
}

View File

@ -12,11 +12,18 @@ import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.api.instrumenter.Instrumenter;
import io.opentelemetry.instrumentation.api.semconv.http.HttpServerRoute;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.springframework.core.Ordered;
import org.springframework.web.filter.OncePerRequestFilter;
@ -53,9 +60,11 @@ final class WebMvcTelemetryProducingFilter extends OncePerRequestFilter implemen
}
Context context = instrumenter.start(parentContext, request);
AsyncAwareHttpServletRequest asyncAwareRequest =
new AsyncAwareHttpServletRequest(request, response, context);
Throwable error = null;
try (Scope ignored = context.makeCurrent()) {
filterChain.doFilter(request, response);
filterChain.doFilter(asyncAwareRequest, response);
} catch (Throwable t) {
error = t;
throw t;
@ -63,7 +72,9 @@ final class WebMvcTelemetryProducingFilter extends OncePerRequestFilter implemen
if (httpRouteSupport.hasMappings()) {
HttpServerRoute.update(context, CONTROLLER, httpRouteSupport::getHttpRoute, request);
}
instrumenter.end(context, request, response, error);
if (error != null || asyncAwareRequest.isNotAsync()) {
instrumenter.end(context, request, response, error);
}
}
}
@ -75,4 +86,88 @@ final class WebMvcTelemetryProducingFilter extends OncePerRequestFilter implemen
// Run after all HIGHEST_PRECEDENCE items
return Ordered.HIGHEST_PRECEDENCE + 1;
}
private class AsyncAwareHttpServletRequest extends HttpServletRequestWrapper {
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Context context;
private final AtomicBoolean listenerAttached = new AtomicBoolean();
AsyncAwareHttpServletRequest(
HttpServletRequest request, HttpServletResponse response, Context context) {
super(request);
this.request = request;
this.response = response;
this.context = context;
}
@Override
public AsyncContext startAsync() {
AsyncContext asyncContext = super.startAsync();
attachListener(asyncContext);
return asyncContext;
}
@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) {
AsyncContext asyncContext = super.startAsync(servletRequest, servletResponse);
attachListener(asyncContext);
return asyncContext;
}
private void attachListener(AsyncContext asyncContext) {
if (!listenerAttached.compareAndSet(false, true)) {
return;
}
asyncContext.addListener(
new AsyncRequestCompletionListener(request, response, context), request, response);
}
boolean isNotAsync() {
return !listenerAttached.get();
}
}
private class AsyncRequestCompletionListener implements AsyncListener {
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Context context;
private final AtomicBoolean responseHandled = new AtomicBoolean();
AsyncRequestCompletionListener(
HttpServletRequest request, HttpServletResponse response, Context context) {
this.request = request;
this.response = response;
this.context = context;
}
@Override
public void onComplete(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, null);
}
}
@Override
public void onTimeout(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, null);
}
}
@Override
public void onError(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, asyncEvent.getThrowable());
}
}
@Override
public void onStartAsync(AsyncEvent asyncEvent) {
asyncEvent
.getAsyncContext()
.addListener(this, asyncEvent.getSuppliedRequest(), asyncEvent.getSuppliedResponse());
}
}
}

View File

@ -17,9 +17,13 @@ import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint
import static java.util.Collections.singletonList;
import io.opentelemetry.api.GlobalOpenTelemetry;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.testing.junit.http.AbstractHttpServerTest;
import io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint;
import jakarta.servlet.Filter;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.ConfigurableApplicationContext;
@ -37,6 +41,7 @@ import org.springframework.web.servlet.view.RedirectView;
@SpringBootApplication
class TestWebSpringBootApp {
static final ServerEndpoint ASYNC_ENDPOINT = new ServerEndpoint("ASYNC", "async", 200, "success");
static ConfigurableApplicationContext start(int port, String contextPath) {
Properties props = new Properties();
@ -122,6 +127,26 @@ class TestWebSpringBootApp {
});
}
@RequestMapping("/async")
@ResponseBody
CompletableFuture<String> async() {
Context context = Context.current();
return CompletableFuture.supplyAsync(
() -> {
// Sleep a bit so that the future completes after the controller method. This helps to
// verify whether request ends after the future has completed not after when the
// controller method has completed.
try {
Thread.sleep(100);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
try (Scope ignored = context.makeCurrent()) {
return controller(ASYNC_ENDPOINT, ASYNC_ENDPOINT::getBody);
}
});
}
@ExceptionHandler
ResponseEntity<String> handleException(Throwable throwable) {
return new ResponseEntity<>(throwable.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR);

View File

@ -5,12 +5,17 @@
package io.opentelemetry.instrumentation.spring.webmvc.v6_0;
import static org.assertj.core.api.Assertions.assertThat;
import io.opentelemetry.instrumentation.api.internal.HttpConstants;
import io.opentelemetry.instrumentation.testing.junit.InstrumentationExtension;
import io.opentelemetry.instrumentation.testing.junit.http.AbstractHttpServerTest;
import io.opentelemetry.instrumentation.testing.junit.http.HttpServerInstrumentationExtension;
import io.opentelemetry.instrumentation.testing.junit.http.HttpServerTestOptions;
import io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint;
import io.opentelemetry.testing.internal.armeria.common.AggregatedHttpRequest;
import io.opentelemetry.testing.internal.armeria.common.AggregatedHttpResponse;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.springframework.context.ConfigurableApplicationContext;
@ -51,4 +56,18 @@ class WebMvcHttpServerTest extends AbstractHttpServerTest<ConfigurableApplicatio
options.setResponseCodeOnNonStandardHttpMethod(501);
}
@Test
void asyncRequest() {
ServerEndpoint endpoint = TestWebSpringBootApp.ASYNC_ENDPOINT;
String method = "GET";
AggregatedHttpRequest request = request(endpoint, method);
AggregatedHttpResponse response = client.execute(request).aggregate().join();
assertThat(response.status().code()).isEqualTo(endpoint.getStatus());
assertThat(response.contentUtf8()).isEqualTo(endpoint.getBody());
String spanId = assertResponseHasCustomizedHeaders(response, endpoint, null);
assertTheTraces(1, null, null, spanId, method, endpoint);
}
}