diff --git a/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/ContextStore.java b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/ContextStore.java
new file mode 100644
index 0000000000..bc61140f05
--- /dev/null
+++ b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/ContextStore.java
@@ -0,0 +1,59 @@
+package datadog.trace.bootstrap;
+
+/**
+ * Interface to represent context storage for instrumentations.
+ *
+ *
Context instances are weakly referenced and will be garbage collected when their corresponding
+ * key instance is collected.
+ *
+ * @param key type to do context lookups
+ * @param context type
+ */
+public interface ContextStore {
+
+ /**
+ * Factory interface to create context instances
+ *
+ * @param context type
+ */
+ interface Factory {
+
+ /** @return new context instance */
+ C create();
+ }
+
+ /**
+ * Get context given the key
+ *
+ * @param key the key to looup
+ * @return context object
+ */
+ C get(K key);
+
+ /**
+ * Put new context instance for given key
+ *
+ * @param key key to use
+ * @param context context instance to save
+ */
+ void put(K key, C context);
+
+ /**
+ * Put new context instance if key is absent
+ *
+ * @param key key to use
+ * @param context new context instance to put
+ * @return old instance if it was present, or new instance
+ */
+ C putIfAbsent(K key, C context);
+
+ /**
+ * Put new context instance if key is absent. Uses context factory to avoid creating objects if
+ * not needed.
+ *
+ * @param key key to use
+ * @param contextFactory factory instance to produce new context object
+ * @return old instance if it was present, or new instance
+ */
+ C putIfAbsent(K key, Factory contextFactory);
+}
diff --git a/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/InstrumentationContext.java b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/InstrumentationContext.java
index 06e689440c..03332f8dce 100644
--- a/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/InstrumentationContext.java
+++ b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/InstrumentationContext.java
@@ -5,39 +5,22 @@ public class InstrumentationContext {
private InstrumentationContext() {}
/**
- * Fetch a context instance out of the context store.
+ * Find a {@link ContextStore} instance for given key class and context class.
*
- *
- *
- *
Conceptually, this can be thought of as a two pass map look up.
- *
- *
For example: RunnableState runnableState = get(runnableImpl, Runnable.class,
- * RunnableState.class) --> RunnableState runnableState = (RunnableState)
- * GlobalContextMap.get(Runnable.class).get(runnableImpl)
- *
- *
+ *
Conceptually this can be thought of as a map lookup to fetch a second level map given
+ * keyClass.
*
*
However, the implementation is actually provided by bytecode transformation for performance
* reasons.
*
- *
- *
- *
Context classes are weakly referenced and will be garbage collected when their corresponding
- * user instance is collected.
- *
- *
- *
- *
Instrumenters making this call must define the user-context class relationship in
- * datadog.trace.agent.tooling.Instrumenter.Default#contextStore.
- *
- * @param userInstance The instance to store context on.
- * @param userClass The user class context is attached to.
+ * @param keyClass The key class context is attached to.
* @param contextClass The context class attached to the user class.
- * @param user class
- * @param context class
- * @return The context instance attached to userInstance.
+ * @param key class
+ * @param context class
+ * @return The instance of context store for given arguments.
*/
- public static V get(K userInstance, Class userClass, Class contextClass) {
- throw new RuntimeException("calls to this method will be rewritten");
+ public static ContextStore get(
+ final Class keyClass, final Class contextClass) {
+ throw new RuntimeException("Calls to this method will be rewritten by MapBackedProvider");
}
}
diff --git a/dd-java-agent/agent-tooling/src/main/java/datadog/trace/agent/tooling/context/MapBackedProvider.java b/dd-java-agent/agent-tooling/src/main/java/datadog/trace/agent/tooling/context/MapBackedProvider.java
index 4fafaed14a..557e2f0dfd 100644
--- a/dd-java-agent/agent-tooling/src/main/java/datadog/trace/agent/tooling/context/MapBackedProvider.java
+++ b/dd-java-agent/agent-tooling/src/main/java/datadog/trace/agent/tooling/context/MapBackedProvider.java
@@ -5,6 +5,7 @@ import static datadog.trace.agent.tooling.ClassLoaderMatcher.BOOTSTRAP_CLASSLOAD
import datadog.trace.agent.tooling.HelperInjector;
import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.Utils;
+import datadog.trace.bootstrap.ContextStore;
import datadog.trace.bootstrap.InstrumentationContext;
import datadog.trace.bootstrap.WeakMap;
import java.lang.reflect.Method;
@@ -37,8 +38,8 @@ import net.bytebuddy.utility.JavaModule;
* This is accomplished by
*
*
- * Injecting a Dynamic Class to store a static map
- * Rewritting calls to the context-store to access the map on the dynamic class
+ * Injecting a Dynamic Class created from {@link MapHolder} to store a static map
+ * Rewritting calls to the context-store to access the specific dynamic {@link MapHolder}
*
*
* Storing the map on a dynamic class and doing bytecode rewrites allows for a 1-pass lookup.
@@ -47,11 +48,8 @@ import net.bytebuddy.utility.JavaModule;
* Example:
* InstrumentationContext.get(runnableInstance, Runnable.class, RunnableState.class)")
* is rewritten to:
- * RunnableInstrumentation$ContextStore$RunnableState12345.getOrCreate(runnableInstance,
- * Runnable.class, RunnableState.class)
- *
- *
Map lookup implementation defined in template class: {@link MapHolder#getOrCreate(Object,
- * Class, Class)}
+ * RunnableInstrumentation$ContextStore$RunnableState12345.getMapHolder(runnableRunnable.class,
+ * RunnableState.class)
*/
@Slf4j
public class MapBackedProvider implements InstrumentationContextProvider {
@@ -60,11 +58,9 @@ public class MapBackedProvider implements InstrumentationContextProvider {
static {
try {
- contextGetMethod =
- InstrumentationContext.class.getMethod("get", Object.class, Class.class, Class.class);
- mapGetMethod =
- MapHolder.class.getMethod("getOrCreate", Object.class, Class.class, Class.class);
- } catch (Exception e) {
+ contextGetMethod = InstrumentationContext.class.getMethod("get", Class.class, Class.class);
+ mapGetMethod = MapHolder.class.getMethod("getMapHolder", Class.class, Class.class);
+ } catch (final Exception e) {
throw new IllegalStateException(e);
}
}
@@ -72,16 +68,17 @@ public class MapBackedProvider implements InstrumentationContextProvider {
/** dynamic-class-name -> dynamic-class-bytes */
private final AtomicReference> dynamicClasses = new AtomicReference<>(null);
- /** user-class-name -> dynamic-class-name */
+ /** key-class-name -> dynamic-class-name */
private final AtomicReference> dynamicClassNames =
new AtomicReference<>(null);
private final Instrumenter.Default instrumenter;
- public MapBackedProvider(Instrumenter.Default instrumenter) {
+ public MapBackedProvider(final Instrumenter.Default instrumenter) {
this.instrumenter = instrumenter;
}
+ @Override
public AgentBuilder.Identified.Extendable instrumentationTransformer(
AgentBuilder.Identified.Extendable builder) {
if (instrumenter.contextStore().size() > 0) {
@@ -90,10 +87,10 @@ public class MapBackedProvider implements InstrumentationContextProvider {
new AgentBuilder.Transformer() {
@Override
public DynamicType.Builder> transform(
- DynamicType.Builder> builder,
- TypeDescription typeDescription,
- ClassLoader classLoader,
- JavaModule module) {
+ final DynamicType.Builder> builder,
+ final TypeDescription typeDescription,
+ final ClassLoader classLoader,
+ final JavaModule module) {
return builder.visit(getInstrumentationVisitor());
}
});
@@ -104,10 +101,10 @@ public class MapBackedProvider implements InstrumentationContextProvider {
@Override
public DynamicType.Builder> transform(
- DynamicType.Builder> builder,
- TypeDescription typeDescription,
- ClassLoader classLoader,
- JavaModule module) {
+ final DynamicType.Builder> builder,
+ final TypeDescription typeDescription,
+ final ClassLoader classLoader,
+ final JavaModule module) {
return injector.transform(
builder,
typeDescription,
@@ -123,41 +120,45 @@ public class MapBackedProvider implements InstrumentationContextProvider {
private AsmVisitorWrapper getInstrumentationVisitor() {
return new AsmVisitorWrapper() {
@Override
- public int mergeWriter(int flags) {
+ public int mergeWriter(final int flags) {
return flags | ClassWriter.COMPUTE_MAXS;
}
@Override
- public int mergeReader(int flags) {
+ public int mergeReader(final int flags) {
return flags;
}
@Override
public ClassVisitor wrap(
final TypeDescription instrumentedType,
- ClassVisitor classVisitor,
- Implementation.Context implementationContext,
- TypePool typePool,
- FieldList fields,
- MethodList> methods,
- int writerFlags,
- int readerFlags) {
+ final ClassVisitor classVisitor,
+ final Implementation.Context implementationContext,
+ final TypePool typePool,
+ final FieldList fields,
+ final MethodList> methods,
+ final int writerFlags,
+ final int readerFlags) {
generateMapHolderClasses();
return new ClassVisitor(Opcodes.ASM7, classVisitor) {
@Override
public void visit(
- int version,
- int access,
- String name,
- String signature,
- String superName,
- String[] interfaces) {
+ final int version,
+ final int access,
+ final String name,
+ final String signature,
+ final String superName,
+ final String[] interfaces) {
super.visit(version, access, name, signature, superName, interfaces);
}
@Override
public MethodVisitor visitMethod(
- int access, String name, String descriptor, String signature, String[] exceptions) {
+ final int access,
+ final String name,
+ final String descriptor,
+ final String signature,
+ final String[] exceptions) {
final MethodVisitor mv =
super.visitMethod(access, name, descriptor, signature, exceptions);
return new MethodVisitor(Opcodes.ASM7, mv) {
@@ -168,7 +169,11 @@ public class MapBackedProvider implements InstrumentationContextProvider {
@Override
public void visitMethodInsn(
- int opcode, String owner, String name, String descriptor, boolean isInterface) {
+ final int opcode,
+ final String owner,
+ final String name,
+ final String descriptor,
+ final boolean isInterface) {
pushOpcode(opcode);
if (Utils.getInternalName(contextGetMethod.getDeclaringClass().getName())
.equals(owner)
@@ -180,28 +185,38 @@ public class MapBackedProvider implements InstrumentationContextProvider {
&& insnStack[2] == Opcodes.LDC)
&& (stack[0] instanceof Type && stack[1] instanceof Type)) {
final String contextClassName = ((Type) stack[0]).getClassName();
- final String userClassName = ((Type) stack[1]).getClassName();
- final String mapHolderClass = dynamicClassNames.get().get(userClassName);
+ final String keyClassName = ((Type) stack[1]).getClassName();
+ final String mapHolderClass = dynamicClassNames.get().get(keyClassName);
log.debug(
"Rewriting context-store map fetch for instrumenter {}: {} -> {}",
instrumenter.getClass().getName(),
- userClassName,
+ keyClassName,
contextClassName);
- if (mapHolderClass != null
- && contextClassName.equals(
- instrumenter.contextStore().get(userClassName))) {
- // stack: contextClass | userClass | instance
- mv.visitMethodInsn(
- Opcodes.INVOKESTATIC,
- Utils.getInternalName(mapHolderClass),
- mapGetMethod.getName(),
- Type.getMethodDescriptor(mapGetMethod),
- false);
- return;
+ if (mapHolderClass == null) {
+ throw new IllegalStateException(
+ String.format(
+ "Incorrect Context Api Usage detected. Cannot find map holder class for %s. Was that class defined in contextStore for instrumentation %s?",
+ keyClassName, instrumenter.getClass().getName()));
}
+ if (!contextClassName.equals(instrumenter.contextStore().get(keyClassName))) {
+ throw new IllegalStateException(
+ String.format(
+ "Incorrect Context Api Usage detected. Incorrect context class %s, expected %s for instrumentation %s",
+ contextClassName,
+ instrumenter.contextStore().get(keyClassName),
+ instrumenter.getClass().getName()));
+ }
+ // stack: contextClass | keyClass
+ mv.visitMethodInsn(
+ Opcodes.INVOKESTATIC,
+ Utils.getInternalName(mapHolderClass),
+ mapGetMethod.getName(),
+ Type.getMethodDescriptor(mapGetMethod),
+ false);
+ return;
}
throw new IllegalStateException(
- "Incorrect Context Api Usage detected. User and context class must be class-literals. Example of correct usage: InstrumentationContext.get(runnableInstance, Runnable.class, RunnableState.class)");
+ "Incorrect Context Api Usage detected. Key and context class must be class-literals. Example of correct usage: InstrumentationContext.get(Runnable.class, RunnableContext.class)");
} else {
super.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
}
@@ -216,38 +231,38 @@ public class MapBackedProvider implements InstrumentationContextProvider {
/**
* Tracking the most recently pushed objects on the stack to assert proper api usage.
*/
- private void pushStack(Object o) {
+ private void pushStack(final Object o) {
System.arraycopy(stack, 0, stack, 1, stack.length - 1);
stack[0] = o;
}
@Override
- public void visitInsn(int opcode) {
+ public void visitInsn(final int opcode) {
pushOpcode(opcode);
super.visitInsn(opcode);
}
@Override
- public void visitJumpInsn(int opcode, Label label) {
+ public void visitJumpInsn(final int opcode, final Label label) {
pushOpcode(opcode);
super.visitJumpInsn(opcode, label);
}
@Override
- public void visitIntInsn(int opcode, int operand) {
+ public void visitIntInsn(final int opcode, final int operand) {
pushOpcode(opcode);
super.visitIntInsn(opcode, operand);
}
@Override
- public void visitVarInsn(int opcode, int var) {
+ public void visitVarInsn(final int opcode, final int var) {
pushOpcode(opcode);
pushStack(var);
super.visitVarInsn(opcode, var);
}
@Override
- public void visitLdcInsn(Object value) {
+ public void visitLdcInsn(final Object value) {
pushOpcode(Opcodes.LDC);
pushStack(value);
super.visitLdcInsn(value);
@@ -265,7 +280,7 @@ public class MapBackedProvider implements InstrumentationContextProvider {
@Override
public AgentBuilder.Identified.Extendable additionalInstrumentation(
- AgentBuilder.Identified.Extendable builder) {
+ final AgentBuilder.Identified.Extendable builder) {
return builder;
}
@@ -281,14 +296,14 @@ public class MapBackedProvider implements InstrumentationContextProvider {
final Map dynamicClasses = new HashMap<>(instrumenter.contextStore().size());
final Map dynamicClassNames =
new HashMap<>(instrumenter.contextStore().size());
- for (final String userClassName : instrumenter.contextStore().keySet()) {
+ for (final String keyClassName : instrumenter.contextStore().keySet()) {
final String dynamicClassName =
instrumenter.getClass().getName()
+ "$ContextStore"
- + userClassName.replaceAll(".*([^\\.]+)$", "\\1")
+ + keyClassName.replaceAll(".*([^\\.]+)$", "\\1")
+ UUID.randomUUID().toString().replace('-', '_');
- dynamicClassNames.put(userClassName, dynamicClassName);
+ dynamicClassNames.put(keyClassName, dynamicClassName);
dynamicClasses.put(dynamicClassName, makeMapHolderClass(dynamicClassName));
}
this.dynamicClassNames.compareAndSet(null, dynamicClassNames);
@@ -296,7 +311,7 @@ public class MapBackedProvider implements InstrumentationContextProvider {
}
}
- private byte[] makeMapHolderClass(String className) {
+ private byte[] makeMapHolderClass(final String className) {
return new ByteBuddy()
.rebase(MapHolder.class)
.modifiers(Opcodes.ACC_PUBLIC | Opcodes.ACC_FINAL)
@@ -306,41 +321,62 @@ public class MapBackedProvider implements InstrumentationContextProvider {
}
/** Template class used to generate the class holding the global map. */
- private static final class MapHolder {
- public static final WeakMap MAP = WeakMap.Provider.newWeakMap();
+ private static final class MapHolder implements ContextStore {
+ private static final MapHolder INSTANCE = new MapHolder(WeakMap.Provider.newWeakMap());
- /**
- * Fetch a context class out of the backing map. Create and return a new context class if none
- * currently exists.
- *
- * This method is thread safe.
- */
- public static Object getOrCreate(Object instance, Class userClass, Class contextClass) {
- if (!userClass.isAssignableFrom(instance.getClass())) {
- throw new RuntimeException(
- "Illegal context lookup. "
- + instance.getClass().getName()
- + " cannot be cast to "
- + userClass.getName());
- }
- Object contextInstance = MAP.get(instance);
- if (null == contextInstance) {
- synchronized (instance) {
- contextInstance = MAP.get(instance);
- if (null == contextInstance) {
- try {
- contextInstance = contextClass.newInstance();
- MAP.put(instance, contextInstance);
- } catch (Exception e) {
- throw new RuntimeException(
- contextClass.getName() + " must define a public, no-arg constructor.", e);
- }
- }
- }
- }
- return contextInstance;
+ private final WeakMap map;
+
+ private MapHolder(final WeakMap map) {
+ this.map = map;
}
- private MapHolder() {}
+ @Override
+ public Object get(final Object key) {
+ return map.get(key);
+ }
+
+ @Override
+ public Object putIfAbsent(final Object key, final Object context) {
+ Object existingContext = map.get(key);
+ if (null != existingContext) {
+ return existingContext;
+ }
+ synchronized (map) {
+ existingContext = map.get(key);
+ if (null != existingContext) {
+ return existingContext;
+ }
+ map.put(key, context);
+ return context;
+ }
+ }
+
+ @Override
+ public Object putIfAbsent(final Object key, final ContextStore.Factory contextFactory) {
+ Object existingContext = map.get(key);
+ if (null != existingContext) {
+ return existingContext;
+ }
+ synchronized (map) {
+ existingContext = map.get(key);
+ if (null != existingContext) {
+ return existingContext;
+ }
+ final Object context = contextFactory.create();
+ map.put(key, context);
+ return context;
+ }
+ }
+
+ @Override
+ public void put(final Object key, final Object context) {
+ map.put(key, context);
+ }
+
+ public static ContextStore getMapHolder(final Class keyClass, final Class contextClass) {
+ // We do not actually check the keyClass here - but that should be fine since compiler would
+ // check things for us.
+ return INSTANCE;
+ }
}
}
diff --git a/dd-java-agent/instrumentation/http-url-connection/src/main/java/datadog/trace/instrumentation/http_url_connection/HttpUrlConnectionInstrumentation.java b/dd-java-agent/instrumentation/http-url-connection/src/main/java/datadog/trace/instrumentation/http_url_connection/HttpUrlConnectionInstrumentation.java
index 6ed7386198..6a22d6507d 100644
--- a/dd-java-agent/instrumentation/http-url-connection/src/main/java/datadog/trace/instrumentation/http_url_connection/HttpUrlConnectionInstrumentation.java
+++ b/dd-java-agent/instrumentation/http-url-connection/src/main/java/datadog/trace/instrumentation/http_url_connection/HttpUrlConnectionInstrumentation.java
@@ -12,6 +12,7 @@ import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.api.DDSpanTypes;
import datadog.trace.api.DDTags;
import datadog.trace.bootstrap.CallDepthThreadLocalMap;
+import datadog.trace.bootstrap.ContextStore;
import datadog.trace.bootstrap.InstrumentationContext;
import io.opentracing.Scope;
import io.opentracing.Span;
@@ -48,9 +49,13 @@ public class HttpUrlConnectionInstrumentation extends Instrumenter.Default {
@Override
public String[] helperClassNames() {
- return new String[] {HttpUrlConnectionInstrumentation.class.getName() + "$HttpURLState"};
+ return new String[] {
+ HttpUrlConnectionInstrumentation.class.getName() + "$HttpURLState",
+ HttpUrlConnectionInstrumentation.class.getName() + "$HttpURLState$1"
+ };
}
+ @Override
public Map contextStore() {
return Collections.singletonMap(
"java.net.HttpURLConnection", getClass().getName() + "$HttpURLState");
@@ -72,9 +77,9 @@ public class HttpUrlConnectionInstrumentation extends Instrumenter.Default {
@Advice.This final HttpURLConnection thiz,
@Advice.FieldValue("connected") final boolean connected,
@Advice.Origin("#m") final String methodName) {
-
- final HttpURLState state =
- InstrumentationContext.get(thiz, HttpURLConnection.class, HttpURLState.class);
+ final ContextStore contextStore =
+ InstrumentationContext.get(HttpURLConnection.class, HttpURLState.class);
+ final HttpURLState state = contextStore.putIfAbsent(thiz, HttpURLState.FACTORY);
String operationName = "http.request";
@@ -183,6 +188,15 @@ public class HttpUrlConnectionInstrumentation extends Instrumenter.Default {
}
public static class HttpURLState {
+
+ public static final ContextStore.Factory FACTORY =
+ new ContextStore.Factory() {
+ @Override
+ public HttpURLState create() {
+ return new HttpURLState();
+ }
+ };
+
public boolean calledOutputStream = false;
public boolean calledInputStream = false;
}
diff --git a/dd-java-agent/testing/src/test/groovy/context/MapBackedProviderTest.groovy b/dd-java-agent/testing/src/test/groovy/context/MapBackedProviderTest.groovy
index 50261197f9..e351fcea25 100644
--- a/dd-java-agent/testing/src/test/groovy/context/MapBackedProviderTest.groovy
+++ b/dd-java-agent/testing/src/test/groovy/context/MapBackedProviderTest.groovy
@@ -6,10 +6,15 @@ import net.bytebuddy.utility.JavaModule
import java.lang.ref.WeakReference
+import static context.ContextTestInstrumentation.IncorrectCallUsageKeyClass
+import static context.ContextTestInstrumentation.IncorrectContextClassUsageKeyClass
+import static context.ContextTestInstrumentation.IncorrectKeyClassUsageKeyClass
+import static context.ContextTestInstrumentation.KeyClass
+
class MapBackedProviderTest extends AgentTestRunner {
def setupSpec() {
- assert new UserClass1().isInstrumented()
+ assert new KeyClass().isInstrumented()
}
@Override
@@ -19,14 +24,16 @@ class MapBackedProviderTest extends AgentTestRunner {
final JavaModule module,
final boolean loaded,
final Throwable throwable) {
- // UserClass2 asserts on incorrect api usage. Error expected.
- return !(typeName.equals(UserClass2.getName()) && throwable.getMessage().startsWith("Incorrect Context Api Usage detected."))
+ // Incorrect* classes assert on incorrect api usage. Error expected.
+ return !(typeName.startsWith(ContextTestInstrumentation.getName() + "$Incorrect") && throwable.getMessage().startsWith("Incorrect Context Api Usage detected."))
}
def "correct api usage stores state in map"() {
+ setup:
+ KeyClass instance1 = new KeyClass()
+ KeyClass instance2 = new KeyClass()
+
when:
- UserClass1 instance1 = new UserClass1()
- UserClass1 instance2 = new UserClass1()
instance1.incrementContextCount()
then:
@@ -34,11 +41,22 @@ class MapBackedProviderTest extends AgentTestRunner {
instance2.incrementContextCount() == 1
}
- def "backing map should not create strong refs to user instances"() {
+ def "get/put test"() {
+ setup:
+ KeyClass instance1 = new KeyClass()
+
when:
- UserClass1 instance = new UserClass1()
+ instance1.putContextCount(10)
+
+ then:
+ instance1.getContextCount() == 10
+ }
+
+ def "backing map should not create strong refs to key class instances"() {
+ when:
+ KeyClass instance = new KeyClass()
final int count = instance.incrementContextCount()
- WeakReference instanceRef = new WeakReference(instance)
+ WeakReference instanceRef = new WeakReference(instance)
instance = null
TestUtils.awaitGC(instanceRef)
@@ -47,8 +65,18 @@ class MapBackedProviderTest extends AgentTestRunner {
count == 1
}
- def "incorrect api usage fails at class load time"() {
+ def "incorrect key class usage fails at class load time"() {
expect:
- !new UserClass2().isInstrumented()
+ !new IncorrectKeyClassUsageKeyClass().isInstrumented()
+ }
+
+ def "incorrect context class usage fails at class load time"() {
+ expect:
+ !new IncorrectContextClassUsageKeyClass().isInstrumented()
+ }
+
+ def "incorrect call usage fails at class load time"() {
+ expect:
+ !new IncorrectCallUsageKeyClass().isInstrumented()
}
}
diff --git a/dd-java-agent/testing/src/test/java/context/ContextTestInstrumentation.java b/dd-java-agent/testing/src/test/java/context/ContextTestInstrumentation.java
index 9f87c589b8..d038eaad91 100644
--- a/dd-java-agent/testing/src/test/java/context/ContextTestInstrumentation.java
+++ b/dd-java-agent/testing/src/test/java/context/ContextTestInstrumentation.java
@@ -1,9 +1,11 @@
package context;
+import static net.bytebuddy.matcher.ElementMatchers.nameStartsWith;
import static net.bytebuddy.matcher.ElementMatchers.named;
import com.google.auto.service.AutoService;
import datadog.trace.agent.tooling.Instrumenter;
+import datadog.trace.bootstrap.ContextStore;
import datadog.trace.bootstrap.InstrumentationContext;
import java.util.Collections;
import java.util.HashMap;
@@ -20,59 +22,181 @@ public class ContextTestInstrumentation extends Instrumenter.Default {
@Override
public ElementMatcher super TypeDescription> typeMatcher() {
- return named(getClass().getPackage().getName() + ".UserClass1")
- .or(named(getClass().getPackage().getName() + ".UserClass2"));
+ return nameStartsWith(getClass().getName() + "$");
}
@Override
public Map extends ElementMatcher, String> transformers() {
final Map transformers = new HashMap<>(2);
transformers.put(named("isInstrumented"), MarkInstrumentedAdvice.class.getName());
- transformers.put(named("incrementContextCount"), CorrectContextApiUsageAdvice.class.getName());
transformers.put(
- named("incrementContextCountCountBroken"), IncorrectContextApiUsageAdvice.class.getName());
+ named("incrementContextCount"), StoreAndIncrementApiUsageAdvice.class.getName());
+ transformers.put(named("getContextCount"), GetApiUsageAdvice.class.getName());
+ transformers.put(named("putContextCount"), PutApiUsageAdvice.class.getName());
+ transformers.put(
+ named("incorrectKeyClassUsage"), IncorrectKeyClassContextApiUsageAdvice.class.getName());
+ transformers.put(
+ named("incorrectContextClassUsage"),
+ IncorrectContextClassContextApiUsageAdvice.class.getName());
+ transformers.put(
+ named("incorrectCallUsage"), IncorrectCallContextApiUsageAdvice.class.getName());
return transformers;
}
@Override
public String[] helperClassNames() {
- return new String[] {getClass().getName() + "$UserClass1State"};
+ return new String[] {getClass().getName() + "$Context", getClass().getName() + "$Context$1"};
}
@Override
public Map contextStore() {
return Collections.singletonMap(
- getClass().getPackage().getName() + ".UserClass1",
- getClass().getName() + "$UserClass1State");
+ getClass().getName() + "$KeyClass", getClass().getName() + "$Context");
}
public static class MarkInstrumentedAdvice {
@Advice.OnMethodExit
- public static void markInstrumented(@Advice.Return(readOnly = false) boolean isInstrumented) {
+ public static void methodExit(@Advice.Return(readOnly = false) boolean isInstrumented) {
isInstrumented = true;
}
}
- public static class CorrectContextApiUsageAdvice {
+ public static class StoreAndIncrementApiUsageAdvice {
@Advice.OnMethodExit
- public static void storeAndIncrement(
- @Advice.This Object thiz, @Advice.Return(readOnly = false) int contextCount) {
- UserClass1State state =
- InstrumentationContext.get((UserClass1) thiz, UserClass1.class, UserClass1State.class);
- contextCount = ++state.count;
+ public static void methodExit(
+ @Advice.This final KeyClass thiz, @Advice.Return(readOnly = false) int contextCount) {
+ final ContextStore contextStore =
+ InstrumentationContext.get(KeyClass.class, Context.class);
+ final Context context = contextStore.putIfAbsent(thiz, new Context());
+ contextCount = ++context.count;
}
}
- public static class IncorrectContextApiUsageAdvice {
+ public static class StoreAndIncrementWithFactoryApiUsageAdvice {
@Advice.OnMethodExit
- public static void storeAndIncrement(
- @Advice.This Object thiz, @Advice.Return(readOnly = false) int contextCount) {
- UserClass1State state = InstrumentationContext.get(thiz, Object.class, UserClass1State.class);
- contextCount = ++state.count;
+ public static void methodExit(
+ @Advice.This final KeyClass thiz, @Advice.Return(readOnly = false) int contextCount) {
+ final ContextStore contextStore =
+ InstrumentationContext.get(KeyClass.class, Context.class);
+ final Context context = contextStore.putIfAbsent(thiz, Context.FACTORY);
+ contextCount = ++context.count;
}
}
- public static class UserClass1State {
+ public static class GetApiUsageAdvice {
+ @Advice.OnMethodExit
+ public static void methodExit(
+ @Advice.This final KeyClass thiz, @Advice.Return(readOnly = false) int contextCount) {
+ final ContextStore contextStore =
+ InstrumentationContext.get(KeyClass.class, Context.class);
+ contextCount = contextStore.get(thiz).count;
+ }
+ }
+
+ public static class PutApiUsageAdvice {
+ @Advice.OnMethodExit
+ public static void methodExit(
+ @Advice.This final KeyClass thiz, @Advice.Argument(0) final int value) {
+ final ContextStore contextStore =
+ InstrumentationContext.get(KeyClass.class, Context.class);
+ final Context context = new Context();
+ context.count = value;
+ contextStore.put(thiz, context);
+ }
+ }
+
+ public static class IncorrectKeyClassContextApiUsageAdvice {
+ @Advice.OnMethodExit
+ public static void methodExit() {
+ InstrumentationContext.get(Object.class, Context.class);
+ }
+ }
+
+ public static class IncorrectContextClassContextApiUsageAdvice {
+ @Advice.OnMethodExit
+ public static void methodExit() {
+ InstrumentationContext.get(KeyClass.class, Object.class);
+ }
+ }
+
+ public static class IncorrectCallContextApiUsageAdvice {
+ @Advice.OnMethodExit
+ public static void methodExit() {
+ // Our instrumentation doesn't handle variables being passed to InstrumentationContext.get,
+ // so we make sure that this actually fails instrumentation.
+ final Class clazz = null;
+ InstrumentationContext.get(clazz, Object.class);
+ }
+ }
+
+ public static class Context {
+ public static final ContextStore.Factory FACTORY =
+ new ContextStore.Factory() {
+ @Override
+ public Context create() {
+ return new Context();
+ }
+ };
+
int count = 0;
}
+
+ public static class KeyClass {
+ public boolean isInstrumented() {
+ // implementation replaced with test instrumentation
+ return false;
+ }
+
+ public int incrementContextCount() {
+ // implementation replaced with test instrumentation
+ return -1;
+ }
+
+ public int incrementContextCountWithFactory() {
+ // implementation replaced with test instrumentation
+ return -1;
+ }
+
+ public int getContextCount() {
+ // implementation replaced with test instrumentation
+ return -1;
+ }
+
+ public void putContextCount(final int value) {
+ // implementation replaced with test instrumentation
+ }
+ }
+
+ public static class IncorrectKeyClassUsageKeyClass {
+ public boolean isInstrumented() {
+ return false;
+ }
+
+ public int incorrectKeyClassUsage() {
+ // instrumentation will not apply to this class because advice incorrectly uses context api
+ return -1;
+ }
+ }
+
+ public static class IncorrectContextClassUsageKeyClass {
+ public boolean isInstrumented() {
+ return false;
+ }
+
+ public int incorrectContextClassUsage() {
+ // instrumentation will not apply to this class because advice incorrectly uses context api
+ return -1;
+ }
+ }
+
+ public static class IncorrectCallUsageKeyClass {
+ public boolean isInstrumented() {
+ return false;
+ }
+
+ public int incorrectCallUsage() {
+ // instrumentation will not apply to this class because advice incorrectly uses context api
+ return -1;
+ }
+ }
}
diff --git a/dd-java-agent/testing/src/test/java/context/UserClass1.java b/dd-java-agent/testing/src/test/java/context/UserClass1.java
deleted file mode 100644
index b2d5ced2f1..0000000000
--- a/dd-java-agent/testing/src/test/java/context/UserClass1.java
+++ /dev/null
@@ -1,13 +0,0 @@
-package context;
-
-public class UserClass1 {
- public boolean isInstrumented() {
- // implementation replaced with test instrumentation
- return false;
- }
-
- public int incrementContextCount() {
- // implementation replaced with test instrumentation
- return -1;
- }
-}
diff --git a/dd-java-agent/testing/src/test/java/context/UserClass2.java b/dd-java-agent/testing/src/test/java/context/UserClass2.java
deleted file mode 100644
index 9d676db335..0000000000
--- a/dd-java-agent/testing/src/test/java/context/UserClass2.java
+++ /dev/null
@@ -1,12 +0,0 @@
-package context;
-
-public class UserClass2 {
- public boolean isInstrumented() {
- return false;
- }
-
- public int incrementContextCountCountBroken() {
- // instrumentation will not apply to this class because advice incorrectly uses context api
- return -1;
- }
-}