opentelemetry-java-instrume.../instrumentation/servlet/servlet-5.0/javaagent/src/test/groovy/TomcatServlet5Test.groovy

471 lines
15 KiB
Groovy

/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/
import io.opentelemetry.instrumentation.test.asserts.TraceAssert
import io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint
import io.opentelemetry.testing.internal.armeria.common.AggregatedHttpResponse
import jakarta.servlet.Servlet
import jakarta.servlet.ServletException
import org.apache.catalina.AccessLog
import org.apache.catalina.Context
import org.apache.catalina.connector.Request
import org.apache.catalina.connector.Response
import org.apache.catalina.core.StandardHost
import org.apache.catalina.startup.Tomcat
import org.apache.catalina.valves.ErrorReportValve
import org.apache.catalina.valves.ValveBase
import org.apache.tomcat.JarScanFilter
import org.apache.tomcat.JarScanType
import spock.lang.Shared
import spock.lang.Unroll
import java.nio.file.Files
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.AUTH_REQUIRED
import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.CAPTURE_HEADERS
import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.CAPTURE_PARAMETERS
import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.ERROR
import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.EXCEPTION
import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.INDEXED_CHILD
import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.NOT_FOUND
import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.QUERY_PARAM
import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.REDIRECT
import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.SUCCESS
import static org.junit.jupiter.api.Assumptions.assumeTrue
@Unroll
abstract class TomcatServlet5Test extends AbstractServlet5Test<Tomcat, Context> {
// Endpoints used by the access-log features below. They reuse the status and body of
// SUCCESS / ERROR and add the "access-log=true" query parameter, which TestAccessLogValve.log()
// checks for before recording ids.
// NOTE(review): the meaning of the trailing `false` constructor argument is not visible in this
// file — confirm against ServerEndpoint.
static final ServerEndpoint ACCESS_LOG_SUCCESS = new ServerEndpoint("ACCESS_LOG_SUCCESS",
"success?access-log=true", SUCCESS.status, SUCCESS.body, false)
static final ServerEndpoint ACCESS_LOG_ERROR = new ServerEndpoint("ACCESS_LOG_ERROR",
"error-status?access-log=true", ERROR.status, ERROR.body, false)
// Exception the EXCEPTION endpoint is expected to surface: a ServletException whose message is
// the endpoint body.
@Override
Throwable expectedException() {
  def expected = new ServletException(EXCEPTION.body)
  return expected
}
// NOT_FOUND always has a response span here; every other endpoint defers to the parent class.
@Override
boolean hasResponseSpan(ServerEndpoint endpoint) {
  if (endpoint == NOT_FOUND) {
    return true
  }
  return super.hasResponseSpan(endpoint)
}
// Asserts a sendError span for NOT_FOUND, then always lets the parent class run its own
// response-span assertions for the endpoint.
@Override
void responseSpan(TraceAssert trace, int index, Object parent, String method, ServerEndpoint endpoint) {
  // Operand order mirrors the original `case NOT_FOUND:` match (NOT_FOUND is the receiver).
  if (NOT_FOUND == endpoint) {
    sendErrorSpan(trace, index, parent)
  }
  super.responseSpan(trace, index, parent, method, endpoint)
}
// Access-log valve instance shared across feature methods (@Shared). startServer() adds it to the
// host pipeline and setup() clears its recorded ids before each feature.
@Shared
def accessLogValue = new TestAccessLogValve()
// Boots an embedded Tomcat on the given port with a single webapp mounted at contextPath,
// the test servlets registered, the custom error-report valve class configured and the shared
// access-log valve added to the host pipeline. Returns the started server.
@Override
Tomcat startServer(int port) {
  Tomcat tomcat = new Tomcat()

  // Throwaway base directory for this Tomcat instance.
  File workDir = Files.createTempDirectory("tomcat").toFile()
  workDir.deleteOnExit()
  tomcat.setBaseDir(workDir.getAbsolutePath())
  tomcat.setPort(port)
  tomcat.getConnector().setEnableLookups(true) // get localhost instead of 127.0.0.1

  // Document root of the webapp; created on demand.
  File rootAppDir = new File(workDir, "/webapps/ROOT")
  if (!rootAppDir.exists()) {
    rootAppDir.mkdirs()
    rootAppDir.deleteOnExit()
  }

  Context webapp = tomcat.addWebapp(contextPath, rootAppDir.getAbsolutePath())

  // Speed up startup by disabling jar scanning:
  def rejectEveryJar = new JarScanFilter() {
    @Override
    boolean check(JarScanType jarScanType, String jarName) {
      return false
    }
  }
  webapp.getJarScanner().setJarScanFilter(rejectEveryJar)

  // setupAuthentication(tomcat, webapp)
  setupServlets(webapp)

  StandardHost host = tomcat.host as StandardHost
  host.setErrorReportValveClass(ErrorHandlerValve.name)
  host.getPipeline().addValve(accessLogValue)

  tomcat.start()
  return tomcat
}
// Spock fixture method: drops any ids recorded by the shared access-log valve so each feature
// starts with an empty list.
def setup() {
  def recordedIds = accessLogValue.getLoggedIds()
  recordedIds.clear()
}
// Stops and destroys the embedded Tomcat started by startServer().
// destroy() runs in a finally block so the server is still torn down when stop() throws;
// previously a failing stop() skipped destroy() entirely.
@Override
void stopServer(Tomcat server) {
  try {
    server.stop()
  } finally {
    server.destroy()
  }
}
// Context path under which startServer() mounts the test webapp.
@Override
String getContextPath() {
  "/tomcat-context"
}
// Registers a fresh instance of the given servlet class under a random unique name and maps it
// to the given path in the context.
@Override
void addServlet(Context servletContext, String path, Class<Servlet> servlet) {
  String servletName = UUID.randomUUID().toString()
  def servletInstance = servlet.newInstance()
  Tomcat.addServlet(servletContext, servletName, servletInstance)
  servletContext.addServletMappingDecoded(path, servletName)
}
// Sends `count` requests to ACCESS_LOG_SUCCESS and verifies that, for every resulting trace,
// the trace id and span id of the server span were recorded by the access-log valve.
def "access log has ids for #count requests"() {
  given:
  def request = request(ACCESS_LOG_SUCCESS, method)

  when:
  List<AggregatedHttpResponse> responses = (1..count).collect {
    return client.execute(request).aggregate().join()
  }

  then:
  responses.each { response ->
    assert response.status().code() == ACCESS_LOG_SUCCESS.status
    assert response.contentUtf8() == ACCESS_LOG_SUCCESS.body
  }

  and:
  assertTraces(count) {
    // The valve's log() call may arrive after the response, so block until all ids are recorded.
    accessLogValue.waitForLoggedIds(count)
    assert accessLogValue.loggedIds.size() == count
    def loggedTraces = accessLogValue.loggedIds*.first
    def loggedSpans = accessLogValue.loggedIds*.second

    // Exclusive range: unlike (0..count - 1) it never degenerates into a reverse range.
    (0..<count).each {
      trace(it, 2) {
        // Use the data variable the request was built with instead of a hard-coded "GET".
        serverSpan(it, 0, null, null, method, ACCESS_LOG_SUCCESS)
        controllerSpan(it, 1, span(0))
      }

      assert loggedTraces.contains(traces[it][0].traceId)
      assert loggedSpans.contains(traces[it][0].spanId)
    }
  }

  where:
  method = "GET"
  count << [1, 4] // make multiple requests.
}
// Sends one request to ACCESS_LOG_ERROR and verifies that the ids recorded by the access-log
// valve match the trace id and span id of the server span. Skipped when testError() is false.
def "access log has ids for error request"() {
setup:
assumeTrue(testError())
def request = request(ACCESS_LOG_ERROR, method)
def response = client.execute(request).aggregate().join()
expect:
response.status().code() == ACCESS_LOG_ERROR.status
response.contentUtf8() == ACCESS_LOG_ERROR.body
and:
// Server span + controller span, plus one more span when the error endpoint uses sendError.
def spanCount = 2
if (errorEndpointUsesSendError()) {
spanCount++
}
assertTraces(1) {
trace(0, spanCount) {
serverSpan(it, 0, null, null, method, ACCESS_LOG_ERROR)
// spanIndex tracks the next span position; each span's parent is the one just before it.
def spanIndex = 1
controllerSpan(it, spanIndex, span(spanIndex - 1))
spanIndex++
if (errorEndpointUsesSendError()) {
sendErrorSpan(it, spanIndex, span(spanIndex - 1))
spanIndex++
}
}
// Block until the valve has recorded the ids for this request (throws on timeout).
accessLogValue.waitForLoggedIds(1)
def (String traceId, String spanId) = accessLogValue.loggedIds[0]
assert traces[0][0].traceId == traceId
assert traces[0][0].spanId == spanId
}
where:
method = "GET"
}
// FIXME: Add authentication tests back in...
// private setupAuthentication(Tomcat server, Context servletContext) {
// // Login Config
// LoginConfig authConfig = new LoginConfig()
// authConfig.setAuthMethod("BASIC")
//
// // adding constraint with role "role"
// SecurityConstraint constraint = new SecurityConstraint()
// constraint.addAuthRole("role")
//
// // add constraint to a collection with pattern /auth/*
// SecurityCollection collection = new SecurityCollection()
// collection.addPattern("/auth/*")
// constraint.addCollection(collection)
//
// servletContext.setLoginConfig(authConfig)
// // does the context need an auth role too?
// servletContext.addSecurityRole("role")
// servletContext.addConstraint(constraint)
//
// // add tomcat users to realm
// MemoryRealm realm = new MemoryRealm() {
// protected void startInternal() {
// credentialHandler = new MessageDigestCredentialHandler()
// setState(LifecycleState.STARTING)
// }
// }
// realm.addUser(user, pass, "role")
// server.getEngine().setRealm(realm)
//
// servletContext.setLoginConfig(authConfig)
// }
}
/**
 * Error report valve used by the test server: for error responses that have no content written
 * yet, it writes a plain message as the response body.
 */
class ErrorHandlerValve extends ErrorReportValve {
  /**
   * Writes the message of the throwable's cause (or the response message when there is no
   * throwable) to the response, but only when the status is at least 400, nothing has been
   * written and the response is flagged as an error.
   */
  @Override
  protected void report(Request request, Response response, Throwable t) {
    boolean reportable = response.getStatus() >= 400 &&
      response.getContentWritten() <= 0 &&
      response.isError()
    if (!reportable) {
      return
    }

    try {
      def out = response.writer
      // NOTE(review): t.cause is dereferenced without a null check — confirm callers always
      // pass a throwable that has a cause.
      def text = t ? t.cause.message : response.message
      out.print(text)
    } catch (IOException ioe) {
      ioe.printStackTrace()
    }
  }
}
/**
 * Valve that also acts as an {@link AccessLog}. For requests carrying an {@code access-log}
 * parameter it records the {@code trace_id} and {@code span_id} request attributes so that
 * tests can compare them with the exported spans.
 */
class TestAccessLogValve extends ValveBase implements AccessLog {
  // (trace id, span id) pairs in arrival order. log() and waitForLoggedIds() synchronize on
  // this list's monitor and use wait/notifyAll on it.
  final List<Tuple2<String, String>> loggedIds = []

  TestAccessLogValve() {
    // Argument of the ValveBase(boolean) constructor.
    super(true)
  }

  /**
   * Records the ids of a request. Requests without an {@code access-log} parameter are ignored.
   */
  void log(Request request, Response response, long time) {
    if (request.getParameter("access-log") == null) {
      return
    }

    synchronized (loggedIds) {
      loggedIds.add(new Tuple2(request.getAttribute("trace_id"),
        request.getAttribute("span_id")))
      // Wake up any thread blocked in waitForLoggedIds().
      loggedIds.notifyAll()
    }
  }

  /**
   * Blocks until at least {@code expected} id pairs have been recorded, for up to 20 seconds.
   *
   * @param expected minimum number of recorded id pairs to wait for
   * @throws TimeoutException if fewer than {@code expected} pairs were recorded in time
   */
  void waitForLoggedIds(int expected) {
    long timeoutMillis = TimeUnit.SECONDS.toMillis(20)
    long deadline = System.currentTimeMillis() + timeoutMillis

    synchronized (loggedIds) {
      long remaining = timeoutMillis
      // Loop to tolerate early wake-ups; recompute the remaining time after each one.
      while (loggedIds.size() < expected && remaining > 0) {
        loggedIds.wait(remaining)
        remaining = deadline - System.currentTimeMillis()
      }
      // Decide on the actual condition rather than on the clock: ids that arrive on the very
      // last wake-up must not be reported as a timeout.
      if (loggedIds.size() < expected) {
        throw new TimeoutException("Timeout waiting for " + expected + " access log ids, got " + loggedIds.size())
      }
    }
  }

  // Intentionally a no-op: this valve keeps no such setting.
  @Override
  void setRequestAttributesEnabled(boolean requestAttributesEnabled) {
  }

  @Override
  boolean getRequestAttributesEnabled() {
    return false
  }

  // Pure pass-through: the request is handed to the next valve unchanged.
  @Override
  void invoke(Request request, Response response) throws IOException, ServletException {
    getNext().invoke(request, response)
  }
}
/** Runs the Tomcat servlet tests against {@code TestServlet5.Sync}. */
class TomcatServlet5TestSync extends TomcatServlet5Test {
  @Override
  Class<Servlet> servlet() {
    return TestServlet5.Sync
  }
}
/**
 * Runs the Tomcat servlet tests against {@code TestServlet5.Async}; the error endpoint is
 * declared as not using sendError.
 */
class TomcatServlet5TestAsync extends TomcatServlet5Test {
  @Override
  Class<Servlet> servlet() {
    return TestServlet5.Async
  }

  @Override
  boolean errorEndpointUsesSendError() {
    return false
  }
}
/** Runs the Tomcat servlet tests against {@code TestServlet5.FakeAsync}. */
class TomcatServlet5TestFakeAsync extends TomcatServlet5Test {
  @Override
  Class<Servlet> servlet() {
    return TestServlet5.FakeAsync
  }
}
/**
 * Dispatch variant in which every "/dispatch" + endpoint path is served by
 * {@code RequestDispatcherServlet.Forward}. The not-found test is disabled.
 */
class TomcatServlet5TestForward extends TomcatDispatchTest {
  @Override
  Class<Servlet> servlet() {
    return TestServlet5.Sync // dispatch to sync servlet
  }

  @Override
  boolean testNotFound() {
    return false
  }

  @Override
  protected void setupServlets(Context context) {
    super.setupServlets(context)

    // Registration order is the same as in the original one-call-per-endpoint listing.
    def dispatchedEndpoints = [
      SUCCESS,
      QUERY_PARAM,
      REDIRECT,
      ERROR,
      EXCEPTION,
      AUTH_REQUIRED,
      CAPTURE_HEADERS,
      CAPTURE_PARAMETERS,
      INDEXED_CHILD,
      HTML_PRINT_WRITER,
      HTML_SERVLET_OUTPUT_STREAM
    ]
    for (endpoint in dispatchedEndpoints) {
      addServlet(context, "/dispatch" + endpoint.path, RequestDispatcherServlet.Forward)
    }
  }
}
/**
 * Dispatch variant in which every "/dispatch" + endpoint path is served by
 * {@code RequestDispatcherServlet.Include}. The not-found, redirect, captured-headers and error
 * tests are disabled.
 */
class TomcatServlet5TestInclude extends TomcatDispatchTest {
  @Override
  Class<Servlet> servlet() {
    return TestServlet5.Sync // dispatch to sync servlet
  }

  @Override
  boolean testNotFound() {
    return false
  }

  @Override
  boolean testRedirect() {
    return false
  }

  @Override
  boolean testCapturedHttpHeaders() {
    return false
  }

  @Override
  boolean testError() {
    return false
  }

  @Override
  protected void setupServlets(Context context) {
    super.setupServlets(context)

    // Registration order is the same as in the original one-call-per-endpoint listing.
    def dispatchedEndpoints = [
      SUCCESS,
      QUERY_PARAM,
      REDIRECT,
      ERROR,
      EXCEPTION,
      AUTH_REQUIRED,
      CAPTURE_HEADERS,
      CAPTURE_PARAMETERS,
      INDEXED_CHILD,
      HTML_PRINT_WRITER,
      HTML_SERVLET_OUTPUT_STREAM
    ]
    for (endpoint in dispatchedEndpoints) {
      addServlet(context, "/dispatch" + endpoint.path, RequestDispatcherServlet.Include)
    }
  }
}
/**
 * Dispatch variant in which every "/dispatch" + endpoint path is served by
 * {@code TestServlet5.DispatchImmediate}, plus a "/dispatch/recursive" servlet. The not-found
 * test is disabled.
 */
class TomcatServlet5TestDispatchImmediate extends TomcatDispatchTest {
  @Override
  Class<Servlet> servlet() {
    return TestServlet5.Sync
  }

  @Override
  boolean testNotFound() {
    return false
  }

  @Override
  protected void setupServlets(Context context) {
    super.setupServlets(context)

    // Registration order is the same as in the original one-call-per-endpoint listing.
    def dispatchedEndpoints = [
      SUCCESS,
      QUERY_PARAM,
      ERROR,
      EXCEPTION,
      REDIRECT,
      AUTH_REQUIRED,
      CAPTURE_HEADERS,
      CAPTURE_PARAMETERS,
      INDEXED_CHILD,
      HTML_PRINT_WRITER,
      HTML_SERVLET_OUTPUT_STREAM
    ]
    for (endpoint in dispatchedEndpoints) {
      addServlet(context, "/dispatch" + endpoint.path, TestServlet5.DispatchImmediate)
    }
    addServlet(context, "/dispatch/recursive", TestServlet5.DispatchRecursive)
  }
}
/**
 * Dispatch variant in which every "/dispatch" + endpoint path is served by
 * {@code TestServlet5.DispatchAsync}, plus a "/dispatch/recursive" servlet. The error endpoint
 * is declared as not using sendError.
 */
class TomcatServlet5TestDispatchAsync extends TomcatDispatchTest {
  @Override
  Class<Servlet> servlet() {
    return TestServlet5.Async
  }

  @Override
  protected void setupServlets(Context context) {
    super.setupServlets(context)

    // Registration order is the same as in the original one-call-per-endpoint listing.
    def dispatchedEndpoints = [
      SUCCESS,
      QUERY_PARAM,
      ERROR,
      EXCEPTION,
      REDIRECT,
      AUTH_REQUIRED,
      CAPTURE_HEADERS,
      CAPTURE_PARAMETERS,
      INDEXED_CHILD,
      HTML_PRINT_WRITER,
      HTML_SERVLET_OUTPUT_STREAM
    ]
    for (endpoint in dispatchedEndpoints) {
      addServlet(context, "/dispatch" + endpoint.path, TestServlet5.DispatchAsync)
    }
    addServlet(context, "/dispatch/recursive", TestServlet5.DispatchRecursive)
  }

  @Override
  boolean errorEndpointUsesSendError() {
    return false
  }
}
/**
 * Base class for the dispatch variants: requests are addressed to the "/dispatch/" prefix
 * under the webapp's context path.
 */
abstract class TomcatDispatchTest extends TomcatServlet5Test {
  @Override
  URI buildAddress() {
    String dispatchRoot = "http://localhost:${port}${contextPath}/dispatch/"
    return new URI(dispatchRoot)
  }
}