New AgentTestRunner to auto-instrument spock tests

This commit is contained in:
Andrew Kent 2017-12-27 16:28:06 -08:00
parent 9d79531f59
commit ca574071ff
12 changed files with 290 additions and 205 deletions

View File

@ -11,7 +11,7 @@ dependencies {
testCompile deps.opentracingMock
testCompile project(':dd-java-agent:tooling').sourceSets.test.output
testCompile project(':dd-java-agent:testing')
testCompile group: 'org.mongodb', name: 'mongo-java-driver', version: '3.4.2'
testCompile group: 'org.mongodb', name: 'mongodb-driver-async', version: '3.4.2'

View File

@ -22,8 +22,8 @@ dependencies {
compile project(':dd-trace')
compile project(':dd-java-agent:tooling')
testCompile project(':dd-java-agent:tooling').sourceSets.test.output
testCompile project(':dd-java-agent:testing')
testCompile group: 'org.apache.activemq.tooling', name: 'activemq-junit', version: '5.14.5'
testCompile group: 'org.apache.activemq', name: 'activemq-pool', version: '5.14.5'
testCompile group: 'org.apache.activemq', name: 'activemq-broker', version: '5.14.5'

View File

@ -1,12 +1,9 @@
import com.datadoghq.trace.DDTracer
import com.datadoghq.trace.writer.ListWriter
import dd.test.TestUtils
import dd.test.AgentTestRunner
import org.apache.activemq.ActiveMQConnectionFactory
import org.apache.activemq.ActiveMQMessageConsumer
import org.apache.activemq.ActiveMQMessageProducer
import org.apache.activemq.junit.EmbeddedActiveMQBroker
import spock.lang.Shared
import spock.lang.Specification
import spock.lang.Unroll
import javax.jms.Connection
@ -15,20 +12,11 @@ import javax.jms.TextMessage
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicReference
class JMS1Test extends Specification {
@Shared
static ListWriter writer = new ListWriter()
@Shared
static DDTracer tracer = new DDTracer(writer)
class JMS1Test extends AgentTestRunner {
@Shared
static Session session
def setupSpec() {
TestUtils.addByteBuddyAgent()
TestUtils.registerOrReplaceGlobalTracer(tracer)
EmbeddedActiveMQBroker broker = new EmbeddedActiveMQBroker()
broker.start()
final ActiveMQConnectionFactory connectionFactory = broker.createConnectionFactory()
@ -38,10 +26,6 @@ class JMS1Test extends Specification {
session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE)
}
def setup() {
writer.start()
}
@Unroll
def "sending a message to #resourceName generates spans"() {
setup:
@ -55,10 +39,10 @@ class JMS1Test extends Specification {
expect:
receivedMessage.text == "a message"
writer.size() == 2
TEST_WRITER.size() == 2
and: // producer trace
def trace = writer.firstTrace()
def trace = TEST_WRITER.firstTrace()
trace.size() == 3
and: // span 0
@ -125,7 +109,7 @@ class JMS1Test extends Specification {
tags2.size() == 5
and: // consumer trace
def consumerTrace = writer.get(1)
def consumerTrace = TEST_WRITER.get(1)
consumerTrace.size() == 1
def consumerSpan = consumerTrace[0]
@ -175,14 +159,14 @@ class JMS1Test extends Specification {
def message = session.createTextMessage("a message")
producer.send(message)
lock.countDown()
writer.waitForTraces(2)
TEST_WRITER.waitForTraces(2)
expect:
messageRef.get().text == "a message"
writer.size() == 2
TEST_WRITER.size() == 2
and: // producer trace
def trace = writer.firstTrace()
def trace = TEST_WRITER.firstTrace()
trace.size() == 3
and: // span 0
@ -249,7 +233,7 @@ class JMS1Test extends Specification {
tags2.size() == 5
and: // consumer trace
def consumerTrace = writer.get(1)
def consumerTrace = TEST_WRITER.get(1)
consumerTrace.size() == 1
def consumerSpan = consumerTrace[0]

View File

@ -26,8 +26,8 @@ dependencies {
compile project(':dd-trace')
compile project(':dd-java-agent:tooling')
testCompile project(':dd-java-agent:tooling').sourceSets.test.output
testCompile project(':dd-java-agent:testing')
testCompile group: 'org.hornetq', name: 'hornetq-jms-client', version: '2.4.7.Final'
testCompile group: 'org.hornetq', name: 'hornetq-jms-server', version: '2.4.7.Final'

View File

@ -1,7 +1,5 @@
import com.datadoghq.trace.DDTracer
import com.datadoghq.trace.writer.ListWriter
import dd.test.AgentTestRunner
import com.google.common.io.Files
import dd.test.TestUtils
import org.hornetq.api.core.TransportConfiguration
import org.hornetq.api.core.client.HornetQClient
import org.hornetq.api.jms.HornetQJMSClient
@ -16,7 +14,6 @@ import org.hornetq.core.server.HornetQServers
import org.hornetq.jms.client.HornetQMessageConsumer
import org.hornetq.jms.client.HornetQMessageProducer
import spock.lang.Shared
import spock.lang.Specification
import spock.lang.Unroll
import javax.jms.Session
@ -24,19 +21,11 @@ import javax.jms.TextMessage
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicReference
class JMS2Test extends Specification {
@Shared
static ListWriter writer = new ListWriter()
@Shared
static DDTracer tracer = new DDTracer(writer)
class JMS2Test extends AgentTestRunner {
@Shared
static Session session
def setupSpec() {
TestUtils.addByteBuddyAgent()
TestUtils.registerOrReplaceGlobalTracer(tracer)
def tempDir = Files.createTempDir()
tempDir.deleteOnExit()
@ -71,10 +60,6 @@ class JMS2Test extends Specification {
session.run()
}
def setup() {
writer.start()
}
@Unroll
def "sending a message to #resourceName generates spans"() {
setup:
@ -88,10 +73,10 @@ class JMS2Test extends Specification {
expect:
receivedMessage.text == "a message"
writer.size() == 2
TEST_WRITER.size() == 2
and: // producer trace
def trace = writer.firstTrace()
def trace = TEST_WRITER.firstTrace()
trace.size() == 1
def producerSpan = trace[0]
@ -115,7 +100,7 @@ class JMS2Test extends Specification {
producerTags.size() == 5
and: // consumer trace
def consumerTrace = writer.get(1)
def consumerTrace = TEST_WRITER.get(1)
consumerTrace.size() == 1
def consumerSpan = consumerTrace[0]
@ -165,14 +150,14 @@ class JMS2Test extends Specification {
def message = session.createTextMessage("a message")
producer.send(message)
lock.countDown()
writer.waitForTraces(2)
TEST_WRITER.waitForTraces(2)
expect:
messageRef.get().text == "a message"
writer.size() == 2
TEST_WRITER.size() == 2
and: // producer trace
def trace = writer.firstTrace()
def trace = TEST_WRITER.firstTrace()
trace.size() == 1
def producerSpan = trace[0]
@ -196,7 +181,7 @@ class JMS2Test extends Specification {
producerTags.size() == 5
and: // consumer trace
def consumerTrace = writer.get(1)
def consumerTrace = TEST_WRITER.get(1)
consumerTrace.size() == 1
def consumerSpan = consumerTrace[0]

View File

@ -26,6 +26,6 @@ dependencies {
compile deps.bytebuddy
compile deps.opentracing
testCompile project(':dd-java-agent:testing')
testCompile group: 'com.squareup.okhttp3', name: 'okhttp', version: '3.0.0'
testCompile project(':dd-java-agent:tooling').sourceSets.test.output
}

View File

@ -1,33 +1,15 @@
import com.datadoghq.trace.DDTags
import com.datadoghq.trace.DDTracer
import com.datadoghq.trace.writer.ListWriter
import dd.test.TestUtils
import dd.test.AgentTestRunner
import io.opentracing.tag.Tags
import okhttp3.OkHttpClient
import okhttp3.Request
import ratpack.http.Headers
import spock.lang.Shared
import spock.lang.Specification
import java.util.concurrent.atomic.AtomicReference
import static ratpack.groovy.test.embed.GroovyEmbeddedApp.ratpack
class OkHttp3Test extends Specification {
@Shared
def writer = new ListWriter()
@Shared
def tracer = new DDTracer(writer)
def setupSpec() {
TestUtils.addByteBuddyAgent()
TestUtils.registerOrReplaceGlobalTracer(tracer)
}
def setup() {
writer.start()
}
class OkHttp3Test extends AgentTestRunner {
def "sending a request creates spans and sends headers"() {
setup:
@ -49,9 +31,9 @@ class OkHttp3Test extends Specification {
expect:
response.body.string() == "pong"
writer.size() == 1
TEST_WRITER.size() == 1
def trace = writer.firstTrace()
def trace = TEST_WRITER.firstTrace()
trace.size() == 2
and: // span 0

View File

@ -0,0 +1,135 @@
package dd.test;
import com.datadoghq.agent.AgentInstaller;
import com.datadoghq.trace.DDTracer;
import com.datadoghq.trace.writer.ListWriter;
import io.opentracing.Tracer;
import java.io.IOException;
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.Instrumentation;
import java.lang.reflect.Field;
import net.bytebuddy.agent.ByteBuddyAgent;
import net.bytebuddy.dynamic.ClassFileLocator;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.runner.RunWith;
import org.junit.runner.notification.RunNotifier;
import org.junit.runners.model.InitializationError;
import org.spockframework.runtime.Sputnik;
import org.spockframework.runtime.model.SpecMetadata;
import spock.lang.Specification;
@RunWith(AgentTestRunner.SpockRunner.class)
@SpecMetadata(filename = "AgentTestRunner.java", line = 0)
public abstract class AgentTestRunner extends Specification {
/**
* For test runs, agent's global tracer will report to this list writer.
*
* <p>Before the start of each test the reported traces will be reset.
*/
public static final ListWriter TEST_WRITER;
private static final Tracer TEST_TRACER;
private static final Instrumentation instrumentation;
private static ClassFileTransformer activeTransformer = null;
static {
TEST_WRITER = new ListWriter();
TEST_TRACER = new DDTracer(TEST_WRITER);
ByteBuddyAgent.install();
instrumentation = ByteBuddyAgent.getInstrumentation();
}
@BeforeClass
public static synchronized void agentSetup() {
if (null != activeTransformer) {
throw new IllegalStateException("transformer already in place: " + activeTransformer);
}
activeTransformer = AgentInstaller.installBytebuddyAgent(instrumentation);
TestUtils.registerOrReplaceGlobalTracer(TEST_TRACER);
}
@Before
public void beforeTest() {
TEST_WRITER.start();
}
@AfterClass
public static synchronized void agentClenup() {
instrumentation.removeTransformer(activeTransformer);
activeTransformer = null;
}
// FIXME: Remove SpockRunner and custom classload logic
public static class SpockRunner extends Sputnik {
private final InstrumentationClassLoader customLoader;
public SpockRunner(Class<?> clazz)
throws InitializationError, NoSuchFieldException, SecurityException,
IllegalArgumentException, IllegalAccessException {
super(shadowTestClass(clazz));
// access the classloader created in shadowTestClass above
Field clazzField = Sputnik.class.getDeclaredField("clazz");
try {
clazzField.setAccessible(true);
customLoader =
(InstrumentationClassLoader) ((Class<?>) clazzField.get(this)).getClassLoader();
} finally {
clazzField.setAccessible(false);
}
}
// Shadow the test class with bytes loaded by InstrumentationClassLoader
private static Class<?> shadowTestClass(final Class<?> clazz) {
try {
InstrumentationClassLoader customLoader =
new InstrumentationClassLoader(SpockRunner.class.getClassLoader());
return customLoader.shadow(clazz);
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
// Replace the context class loader for each test with InstrumentationClassLoader
@Override
public void run(final RunNotifier notifier) {
final ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
try {
Thread.currentThread().setContextClassLoader(customLoader);
super.run(notifier);
} finally {
Thread.currentThread().setContextClassLoader(contextLoader);
}
}
}
/**
* A ClassLoader which retransforms classes unseen by the installed agent. With the exception of
* shadowed classes, this class
*/
private static class InstrumentationClassLoader extends java.lang.ClassLoader {
final ClassLoader parent;
public InstrumentationClassLoader(ClassLoader parent) {
super(parent);
this.parent = parent;
}
/** Forcefully inject the bytes of clazz into this classloader. */
public Class<?> shadow(Class<?> clazz) throws IOException {
final ClassFileLocator locator = ClassFileLocator.ForClassLoader.of(clazz.getClassLoader());
final byte[] classBytes = locator.locate(clazz.getName()).resolve();
Class<?> shadowed = this.defineClass(clazz.getName(), classBytes, 0, classBytes.length);
return shadowed;
}
@Override
protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
// TODO: If already loaded and not seen by agent: do a retransform.
return parent.loadClass(name);
}
}
}

View File

@ -0,0 +1,128 @@
package dd.test;
import com.datadoghq.agent.Utils;
import io.opentracing.ActiveSpan;
import io.opentracing.Tracer;
import io.opentracing.util.GlobalTracer;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.jar.JarEntry;
import java.util.jar.JarOutputStream;
import java.util.jar.Manifest;
public class TestUtils {
private static Method findLoadedClassMethod = null;
static {
try {
findLoadedClassMethod = ClassLoader.class.getDeclaredMethod("findLoadedClass", String.class);
} catch (NoSuchMethodException | SecurityException e) {
throw new IllegalStateException(e);
}
}
public static void registerOrReplaceGlobalTracer(Tracer tracer) {
try {
GlobalTracer.register(tracer);
} catch (final Exception e) {
// Force it anyway using reflection
Field field = null;
try {
field = GlobalTracer.class.getDeclaredField("tracer");
field.setAccessible(true);
field.set(null, tracer);
} catch (Exception e2) {
throw new IllegalStateException(e2);
} finally {
if (null != field) {
field.setAccessible(false);
}
}
}
if (!GlobalTracer.isRegistered()) {
throw new RuntimeException("Unable to register the global tracer.");
}
}
public static <T extends Object> Object runUnderTrace(
final String rootOperationName, Callable<T> r) {
ActiveSpan rootSpan = GlobalTracer.get().buildSpan(rootOperationName).startActive();
try {
try {
return r.call();
} catch (Exception e) {
throw new IllegalStateException(e);
}
} finally {
rootSpan.deactivate();
}
}
public static boolean isClassLoaded(String className, ClassLoader classLoader) {
try {
findLoadedClassMethod.setAccessible(true);
Class<?> loadedClass = (Class<?>) findLoadedClassMethod.invoke(classLoader, className);
return null != loadedClass && loadedClass.getClassLoader() == classLoader;
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
throw new IllegalStateException(e);
} finally {
findLoadedClassMethod.setAccessible(false);
}
}
/**
* Create a temporary jar on the filesystem with the bytes of the given classes.
*
* <p>The jar file will be removed when the jvm exits.
*
* @param classes classes to package into the jar.
* @return the location of the newly created jar.
* @throws IOException
*/
public static URL createJarWithClasses(Class<?>... classes) throws IOException {
final File tmpJar = File.createTempFile(UUID.randomUUID().toString() + "", ".jar");
tmpJar.deleteOnExit();
final Manifest manifest = new Manifest();
JarOutputStream target = new JarOutputStream(new FileOutputStream(tmpJar), manifest);
for (Class<?> clazz : classes) {
addToJar(clazz, target);
}
target.close();
return tmpJar.toURI().toURL();
}
private static void addToJar(Class<?> clazz, JarOutputStream jarOutputStream) throws IOException {
InputStream inputStream = null;
try {
JarEntry entry = new JarEntry(Utils.getResourceName(clazz.getName()));
jarOutputStream.putNextEntry(entry);
inputStream =
clazz.getClassLoader().getResourceAsStream(Utils.getResourceName(clazz.getName()));
byte[] buffer = new byte[1024];
while (true) {
int count = inputStream.read(buffer);
if (count == -1) {
break;
}
jarOutputStream.write(buffer, 0, count);
}
jarOutputStream.closeEntry();
} finally {
if (inputStream != null) {
inputStream.close();
}
}
}
}

View File

@ -1,10 +0,0 @@
package dd.test;
import com.datadoghq.agent.TracingAgent;
public class TestUtils2 {
public static void canSeeAgent() {
System.out.println("Can I see the agent? " + TracingAgent.getAgentClassLoader());
}
}

View File

@ -2,6 +2,7 @@ apply from: "${rootDir}/gradle/java.gradle"
dependencies {
compile deps.bytebuddy
compile deps.bytebuddyagent
compile deps.slf4j
compile deps.opentracing
compile deps.spock

View File

@ -1,120 +0,0 @@
package dd.test
import dd.trace.Instrumenter
import io.opentracing.ActiveSpan
import io.opentracing.Tracer
import io.opentracing.util.GlobalTracer
import java.lang.reflect.Method
import java.util.jar.JarEntry
import java.util.jar.JarOutputStream
import java.util.jar.Manifest
import net.bytebuddy.agent.ByteBuddyAgent
import net.bytebuddy.agent.builder.AgentBuilder
import java.lang.reflect.Field
import java.util.concurrent.Callable
import static net.bytebuddy.matcher.ElementMatchers.nameStartsWith
import static org.assertj.core.api.Assertions.assertThat
class TestUtils {
static addByteBuddyAgent() {
AgentBuilder builder =
new AgentBuilder.Default()
.disableClassFormatChanges()
.with(AgentBuilder.RedefinitionStrategy.RETRANSFORMATION)
// .with(AgentBuilder.Listener.StreamWriting.toSystemError())
.ignore(nameStartsWith("dd.inst"))
def instrumenters = ServiceLoader.load(Instrumenter)
for (final Instrumenter instrumenter : instrumenters) {
System.err.println("Instrumenting with " + instrumenter)
builder = instrumenter.instrument(builder)
}
builder.installOn(ByteBuddyAgent.install())
}
static registerOrReplaceGlobalTracer(Tracer tracer) {
try {
GlobalTracer.register(tracer)
} catch (final Exception e) {
// Force it anyway using reflection
final Field field = GlobalTracer.getDeclaredField("tracer")
field.setAccessible(true)
field.set(null, tracer)
}
assertThat(GlobalTracer.isRegistered()).isTrue()
}
static runUnderTrace(final String rootOperationName, Callable r) {
ActiveSpan rootSpan = GlobalTracer.get().buildSpan(rootOperationName).startActive()
try {
return r.call()
} finally {
rootSpan.deactivate()
}
}
private static Method findLoadedClassMethod = ClassLoader.getDeclaredMethod("findLoadedClass", String)
static boolean isClassLoaded(String className, ClassLoader classLoader) {
try {
findLoadedClassMethod.setAccessible(true)
return null != findLoadedClassMethod.invoke(classLoader, className)
} finally {
findLoadedClassMethod.setAccessible(false)
}
}
/** com.foo.Bar -> com/foo/Bar.class */
static String getResourceName(Class<?> clazz) {
return clazz.getName().replace('.', '/') + ".class"
}
/**
* Create a temporary jar on the filesystem with the bytes of the given classes.
*
* <p>The jar file will be removed when the jvm exits.
*
* @param classes classes to package into the jar.
* @return the location of the newly created jar.
* @throws IOException
*/
static URL createJarWithClasses(Class<?>... classes) throws IOException {
final File tmpJar = File.createTempFile(UUID.randomUUID().toString() + "", ".jar")
tmpJar.deleteOnExit()
final Manifest manifest = new Manifest()
JarOutputStream target = new JarOutputStream(new FileOutputStream(tmpJar), manifest)
for (Class<?> clazz : classes) {
addToJar(clazz, target)
}
target.close()
return tmpJar.toURI().toURL()
}
static void addToJar(Class<?> clazz, JarOutputStream jarOutputStream) throws IOException {
InputStream inputStream = null
try {
JarEntry entry = new JarEntry(getResourceName(clazz))
jarOutputStream.putNextEntry(entry)
inputStream = clazz.getClassLoader().getResourceAsStream(getResourceName(clazz))
byte[] buffer = new byte[1024]
while (true) {
int count = inputStream.read(buffer)
if (count == -1) {
break
}
jarOutputStream.write(buffer, 0, count)
}
jarOutputStream.closeEntry()
} finally {
if (inputStream != null) {
inputStream.close()
}
}
}
}