From 98df97af3e1ab562b2a01b16aae6d9bc21a6d3da Mon Sep 17 00:00:00 2001 From: Nikolay Martynov Date: Thu, 1 Nov 2018 16:44:13 -0400 Subject: [PATCH] First iteration on a better context store api Separate context storage from actual fetching/putting --- .../datadog/trace/bootstrap/ContextStore.java | 59 +++++ .../bootstrap/InstrumentationContext.java | 37 +-- .../tooling/context/MapBackedProvider.java | 232 ++++++++++-------- .../HttpUrlConnectionInstrumentation.java | 22 +- .../context/MapBackedProviderTest.groovy | 48 +++- .../context/ContextTestInstrumentation.java | 164 +++++++++++-- .../src/test/java/context/UserClass1.java | 13 - .../src/test/java/context/UserClass2.java | 12 - 8 files changed, 403 insertions(+), 184 deletions(-) create mode 100644 dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/ContextStore.java delete mode 100644 dd-java-agent/testing/src/test/java/context/UserClass1.java delete mode 100644 dd-java-agent/testing/src/test/java/context/UserClass2.java diff --git a/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/ContextStore.java b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/ContextStore.java new file mode 100644 index 0000000000..bc61140f05 --- /dev/null +++ b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/ContextStore.java @@ -0,0 +1,59 @@ +package datadog.trace.bootstrap; + +/** + * Interface to represent context storage for instrumentations. + * + *

Context instances are weakly referenced and will be garbage collected when their corresponding + * key instance is collected. + * + * @param key type to do context lookups + * @param context type + */ +public interface ContextStore { + + /** + * Factory interface to create context instances + * + * @param context type + */ + interface Factory { + + /** @return new context instance */ + C create(); + } + + /** + * Get context given the key + * + * @param key the key to looup + * @return context object + */ + C get(K key); + + /** + * Put new context instance for given key + * + * @param key key to use + * @param context context instance to save + */ + void put(K key, C context); + + /** + * Put new context instance if key is absent + * + * @param key key to use + * @param context new context instance to put + * @return old instance if it was present, or new instance + */ + C putIfAbsent(K key, C context); + + /** + * Put new context instance if key is absent. Uses context factory to avoid creating objects if + * not needed. + * + * @param key key to use + * @param contextFactory factory instance to produce new context object + * @return old instance if it was present, or new instance + */ + C putIfAbsent(K key, Factory contextFactory); +} diff --git a/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/InstrumentationContext.java b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/InstrumentationContext.java index 06e689440c..03332f8dce 100644 --- a/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/InstrumentationContext.java +++ b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/InstrumentationContext.java @@ -5,39 +5,22 @@ public class InstrumentationContext { private InstrumentationContext() {} /** - * Fetch a context instance out of the context store. + * Find a {@link ContextStore} instance for given key class and context class. * - *

- * - *

Conceptually, this can be thought of as a two pass map look up. - * - *

For example: RunnableState runnableState = get(runnableImpl, Runnable.class, - * RunnableState.class) --> RunnableState runnableState = (RunnableState) - * GlobalContextMap.get(Runnable.class).get(runnableImpl) - * - *

+ *

Conceptually this can be thought of as a map lookup to fetch a second level map given + * keyClass. * *

However, the implementation is actually provided by bytecode transformation for performance * reasons. * - *

- * - *

Context classes are weakly referenced and will be garbage collected when their corresponding - * user instance is collected. - * - *

- * - *

Instrumenters making this call must define the user-context class relationship in - * datadog.trace.agent.tooling.Instrumenter.Default#contextStore. - * - * @param userInstance The instance to store context on. - * @param userClass The user class context is attached to. + * @param keyClass The key class context is attached to. * @param contextClass The context class attached to the user class. - * @param user class - * @param context class - * @return The context instance attached to userInstance. + * @param key class + * @param context class + * @return The instance of context store for given arguments. */ - public static V get(K userInstance, Class userClass, Class contextClass) { - throw new RuntimeException("calls to this method will be rewritten"); + public static ContextStore get( + final Class keyClass, final Class contextClass) { + throw new RuntimeException("Calls to this method will be rewritten by MapBackedProvider"); } } diff --git a/dd-java-agent/agent-tooling/src/main/java/datadog/trace/agent/tooling/context/MapBackedProvider.java b/dd-java-agent/agent-tooling/src/main/java/datadog/trace/agent/tooling/context/MapBackedProvider.java index 4fafaed14a..557e2f0dfd 100644 --- a/dd-java-agent/agent-tooling/src/main/java/datadog/trace/agent/tooling/context/MapBackedProvider.java +++ b/dd-java-agent/agent-tooling/src/main/java/datadog/trace/agent/tooling/context/MapBackedProvider.java @@ -5,6 +5,7 @@ import static datadog.trace.agent.tooling.ClassLoaderMatcher.BOOTSTRAP_CLASSLOAD import datadog.trace.agent.tooling.HelperInjector; import datadog.trace.agent.tooling.Instrumenter; import datadog.trace.agent.tooling.Utils; +import datadog.trace.bootstrap.ContextStore; import datadog.trace.bootstrap.InstrumentationContext; import datadog.trace.bootstrap.WeakMap; import java.lang.reflect.Method; @@ -37,8 +38,8 @@ import net.bytebuddy.utility.JavaModule; *

This is accomplished by * *

    - *
  1. Injecting a Dynamic Class to store a static map - *
  2. Rewritting calls to the context-store to access the map on the dynamic class + *
  3. Injecting a Dynamic Class created from {@link MapHolder} to store a static map + *
  4. Rewritting calls to the context-store to access the specific dynamic {@link MapHolder} *
* * Storing the map on a dynamic class and doing bytecode rewrites allows for a 1-pass lookup. @@ -47,11 +48,8 @@ import net.bytebuddy.utility.JavaModule; *

Example:
* InstrumentationContext.get(runnableInstance, Runnable.class, RunnableState.class)")
* is rewritten to:
- * RunnableInstrumentation$ContextStore$RunnableState12345.getOrCreate(runnableInstance, - * Runnable.class, RunnableState.class) - * - *

Map lookup implementation defined in template class: {@link MapHolder#getOrCreate(Object, - * Class, Class)} + * RunnableInstrumentation$ContextStore$RunnableState12345.getMapHolder(runnableRunnable.class, + * RunnableState.class) */ @Slf4j public class MapBackedProvider implements InstrumentationContextProvider { @@ -60,11 +58,9 @@ public class MapBackedProvider implements InstrumentationContextProvider { static { try { - contextGetMethod = - InstrumentationContext.class.getMethod("get", Object.class, Class.class, Class.class); - mapGetMethod = - MapHolder.class.getMethod("getOrCreate", Object.class, Class.class, Class.class); - } catch (Exception e) { + contextGetMethod = InstrumentationContext.class.getMethod("get", Class.class, Class.class); + mapGetMethod = MapHolder.class.getMethod("getMapHolder", Class.class, Class.class); + } catch (final Exception e) { throw new IllegalStateException(e); } } @@ -72,16 +68,17 @@ public class MapBackedProvider implements InstrumentationContextProvider { /** dynamic-class-name -> dynamic-class-bytes */ private final AtomicReference> dynamicClasses = new AtomicReference<>(null); - /** user-class-name -> dynamic-class-name */ + /** key-class-name -> dynamic-class-name */ private final AtomicReference> dynamicClassNames = new AtomicReference<>(null); private final Instrumenter.Default instrumenter; - public MapBackedProvider(Instrumenter.Default instrumenter) { + public MapBackedProvider(final Instrumenter.Default instrumenter) { this.instrumenter = instrumenter; } + @Override public AgentBuilder.Identified.Extendable instrumentationTransformer( AgentBuilder.Identified.Extendable builder) { if (instrumenter.contextStore().size() > 0) { @@ -90,10 +87,10 @@ public class MapBackedProvider implements InstrumentationContextProvider { new AgentBuilder.Transformer() { @Override public DynamicType.Builder transform( - DynamicType.Builder builder, - TypeDescription typeDescription, - ClassLoader classLoader, - JavaModule module) { + final DynamicType.Builder builder, + final TypeDescription typeDescription, + final ClassLoader classLoader, + final JavaModule module) { return builder.visit(getInstrumentationVisitor()); } }); @@ -104,10 +101,10 @@ public class MapBackedProvider implements InstrumentationContextProvider { @Override public DynamicType.Builder transform( - DynamicType.Builder builder, - TypeDescription typeDescription, - ClassLoader classLoader, - JavaModule module) { + final DynamicType.Builder builder, + final TypeDescription typeDescription, + final ClassLoader classLoader, + final JavaModule module) { return injector.transform( builder, typeDescription, @@ -123,41 +120,45 @@ public class MapBackedProvider implements InstrumentationContextProvider { private AsmVisitorWrapper getInstrumentationVisitor() { return new AsmVisitorWrapper() { @Override - public int mergeWriter(int flags) { + public int mergeWriter(final int flags) { return flags | ClassWriter.COMPUTE_MAXS; } @Override - public int mergeReader(int flags) { + public int mergeReader(final int flags) { return flags; } @Override public ClassVisitor wrap( final TypeDescription instrumentedType, - ClassVisitor classVisitor, - Implementation.Context implementationContext, - TypePool typePool, - FieldList fields, - MethodList methods, - int writerFlags, - int readerFlags) { + final ClassVisitor classVisitor, + final Implementation.Context implementationContext, + final TypePool typePool, + final FieldList fields, + final MethodList methods, + final int writerFlags, + final int readerFlags) { generateMapHolderClasses(); return new ClassVisitor(Opcodes.ASM7, classVisitor) { @Override public void visit( - int version, - int access, - String name, - String signature, - String superName, - String[] interfaces) { + final int version, + final int access, + final String name, + final String signature, + final String superName, + final String[] interfaces) { super.visit(version, access, name, signature, superName, interfaces); } @Override public MethodVisitor visitMethod( - int access, String name, String descriptor, String signature, String[] exceptions) { + final int access, + final String name, + final String descriptor, + final String signature, + final String[] exceptions) { final MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions); return new MethodVisitor(Opcodes.ASM7, mv) { @@ -168,7 +169,11 @@ public class MapBackedProvider implements InstrumentationContextProvider { @Override public void visitMethodInsn( - int opcode, String owner, String name, String descriptor, boolean isInterface) { + final int opcode, + final String owner, + final String name, + final String descriptor, + final boolean isInterface) { pushOpcode(opcode); if (Utils.getInternalName(contextGetMethod.getDeclaringClass().getName()) .equals(owner) @@ -180,28 +185,38 @@ public class MapBackedProvider implements InstrumentationContextProvider { && insnStack[2] == Opcodes.LDC) && (stack[0] instanceof Type && stack[1] instanceof Type)) { final String contextClassName = ((Type) stack[0]).getClassName(); - final String userClassName = ((Type) stack[1]).getClassName(); - final String mapHolderClass = dynamicClassNames.get().get(userClassName); + final String keyClassName = ((Type) stack[1]).getClassName(); + final String mapHolderClass = dynamicClassNames.get().get(keyClassName); log.debug( "Rewriting context-store map fetch for instrumenter {}: {} -> {}", instrumenter.getClass().getName(), - userClassName, + keyClassName, contextClassName); - if (mapHolderClass != null - && contextClassName.equals( - instrumenter.contextStore().get(userClassName))) { - // stack: contextClass | userClass | instance - mv.visitMethodInsn( - Opcodes.INVOKESTATIC, - Utils.getInternalName(mapHolderClass), - mapGetMethod.getName(), - Type.getMethodDescriptor(mapGetMethod), - false); - return; + if (mapHolderClass == null) { + throw new IllegalStateException( + String.format( + "Incorrect Context Api Usage detected. Cannot find map holder class for %s. Was that class defined in contextStore for instrumentation %s?", + keyClassName, instrumenter.getClass().getName())); } + if (!contextClassName.equals(instrumenter.contextStore().get(keyClassName))) { + throw new IllegalStateException( + String.format( + "Incorrect Context Api Usage detected. Incorrect context class %s, expected %s for instrumentation %s", + contextClassName, + instrumenter.contextStore().get(keyClassName), + instrumenter.getClass().getName())); + } + // stack: contextClass | keyClass + mv.visitMethodInsn( + Opcodes.INVOKESTATIC, + Utils.getInternalName(mapHolderClass), + mapGetMethod.getName(), + Type.getMethodDescriptor(mapGetMethod), + false); + return; } throw new IllegalStateException( - "Incorrect Context Api Usage detected. User and context class must be class-literals. Example of correct usage: InstrumentationContext.get(runnableInstance, Runnable.class, RunnableState.class)"); + "Incorrect Context Api Usage detected. Key and context class must be class-literals. Example of correct usage: InstrumentationContext.get(Runnable.class, RunnableContext.class)"); } else { super.visitMethodInsn(opcode, owner, name, descriptor, isInterface); } @@ -216,38 +231,38 @@ public class MapBackedProvider implements InstrumentationContextProvider { /** * Tracking the most recently pushed objects on the stack to assert proper api usage. */ - private void pushStack(Object o) { + private void pushStack(final Object o) { System.arraycopy(stack, 0, stack, 1, stack.length - 1); stack[0] = o; } @Override - public void visitInsn(int opcode) { + public void visitInsn(final int opcode) { pushOpcode(opcode); super.visitInsn(opcode); } @Override - public void visitJumpInsn(int opcode, Label label) { + public void visitJumpInsn(final int opcode, final Label label) { pushOpcode(opcode); super.visitJumpInsn(opcode, label); } @Override - public void visitIntInsn(int opcode, int operand) { + public void visitIntInsn(final int opcode, final int operand) { pushOpcode(opcode); super.visitIntInsn(opcode, operand); } @Override - public void visitVarInsn(int opcode, int var) { + public void visitVarInsn(final int opcode, final int var) { pushOpcode(opcode); pushStack(var); super.visitVarInsn(opcode, var); } @Override - public void visitLdcInsn(Object value) { + public void visitLdcInsn(final Object value) { pushOpcode(Opcodes.LDC); pushStack(value); super.visitLdcInsn(value); @@ -265,7 +280,7 @@ public class MapBackedProvider implements InstrumentationContextProvider { @Override public AgentBuilder.Identified.Extendable additionalInstrumentation( - AgentBuilder.Identified.Extendable builder) { + final AgentBuilder.Identified.Extendable builder) { return builder; } @@ -281,14 +296,14 @@ public class MapBackedProvider implements InstrumentationContextProvider { final Map dynamicClasses = new HashMap<>(instrumenter.contextStore().size()); final Map dynamicClassNames = new HashMap<>(instrumenter.contextStore().size()); - for (final String userClassName : instrumenter.contextStore().keySet()) { + for (final String keyClassName : instrumenter.contextStore().keySet()) { final String dynamicClassName = instrumenter.getClass().getName() + "$ContextStore" - + userClassName.replaceAll(".*([^\\.]+)$", "\\1") + + keyClassName.replaceAll(".*([^\\.]+)$", "\\1") + UUID.randomUUID().toString().replace('-', '_'); - dynamicClassNames.put(userClassName, dynamicClassName); + dynamicClassNames.put(keyClassName, dynamicClassName); dynamicClasses.put(dynamicClassName, makeMapHolderClass(dynamicClassName)); } this.dynamicClassNames.compareAndSet(null, dynamicClassNames); @@ -296,7 +311,7 @@ public class MapBackedProvider implements InstrumentationContextProvider { } } - private byte[] makeMapHolderClass(String className) { + private byte[] makeMapHolderClass(final String className) { return new ByteBuddy() .rebase(MapHolder.class) .modifiers(Opcodes.ACC_PUBLIC | Opcodes.ACC_FINAL) @@ -306,41 +321,62 @@ public class MapBackedProvider implements InstrumentationContextProvider { } /** Template class used to generate the class holding the global map. */ - private static final class MapHolder { - public static final WeakMap MAP = WeakMap.Provider.newWeakMap(); + private static final class MapHolder implements ContextStore { + private static final MapHolder INSTANCE = new MapHolder(WeakMap.Provider.newWeakMap()); - /** - * Fetch a context class out of the backing map. Create and return a new context class if none - * currently exists. - * - *

This method is thread safe. - */ - public static Object getOrCreate(Object instance, Class userClass, Class contextClass) { - if (!userClass.isAssignableFrom(instance.getClass())) { - throw new RuntimeException( - "Illegal context lookup. " - + instance.getClass().getName() - + " cannot be cast to " - + userClass.getName()); - } - Object contextInstance = MAP.get(instance); - if (null == contextInstance) { - synchronized (instance) { - contextInstance = MAP.get(instance); - if (null == contextInstance) { - try { - contextInstance = contextClass.newInstance(); - MAP.put(instance, contextInstance); - } catch (Exception e) { - throw new RuntimeException( - contextClass.getName() + " must define a public, no-arg constructor.", e); - } - } - } - } - return contextInstance; + private final WeakMap map; + + private MapHolder(final WeakMap map) { + this.map = map; } - private MapHolder() {} + @Override + public Object get(final Object key) { + return map.get(key); + } + + @Override + public Object putIfAbsent(final Object key, final Object context) { + Object existingContext = map.get(key); + if (null != existingContext) { + return existingContext; + } + synchronized (map) { + existingContext = map.get(key); + if (null != existingContext) { + return existingContext; + } + map.put(key, context); + return context; + } + } + + @Override + public Object putIfAbsent(final Object key, final ContextStore.Factory contextFactory) { + Object existingContext = map.get(key); + if (null != existingContext) { + return existingContext; + } + synchronized (map) { + existingContext = map.get(key); + if (null != existingContext) { + return existingContext; + } + final Object context = contextFactory.create(); + map.put(key, context); + return context; + } + } + + @Override + public void put(final Object key, final Object context) { + map.put(key, context); + } + + public static ContextStore getMapHolder(final Class keyClass, final Class contextClass) { + // We do not actually check the keyClass here - but that should be fine since compiler would + // check things for us. + return INSTANCE; + } } } diff --git a/dd-java-agent/instrumentation/http-url-connection/src/main/java/datadog/trace/instrumentation/http_url_connection/HttpUrlConnectionInstrumentation.java b/dd-java-agent/instrumentation/http-url-connection/src/main/java/datadog/trace/instrumentation/http_url_connection/HttpUrlConnectionInstrumentation.java index 6ed7386198..6a22d6507d 100644 --- a/dd-java-agent/instrumentation/http-url-connection/src/main/java/datadog/trace/instrumentation/http_url_connection/HttpUrlConnectionInstrumentation.java +++ b/dd-java-agent/instrumentation/http-url-connection/src/main/java/datadog/trace/instrumentation/http_url_connection/HttpUrlConnectionInstrumentation.java @@ -12,6 +12,7 @@ import datadog.trace.agent.tooling.Instrumenter; import datadog.trace.api.DDSpanTypes; import datadog.trace.api.DDTags; import datadog.trace.bootstrap.CallDepthThreadLocalMap; +import datadog.trace.bootstrap.ContextStore; import datadog.trace.bootstrap.InstrumentationContext; import io.opentracing.Scope; import io.opentracing.Span; @@ -48,9 +49,13 @@ public class HttpUrlConnectionInstrumentation extends Instrumenter.Default { @Override public String[] helperClassNames() { - return new String[] {HttpUrlConnectionInstrumentation.class.getName() + "$HttpURLState"}; + return new String[] { + HttpUrlConnectionInstrumentation.class.getName() + "$HttpURLState", + HttpUrlConnectionInstrumentation.class.getName() + "$HttpURLState$1" + }; } + @Override public Map contextStore() { return Collections.singletonMap( "java.net.HttpURLConnection", getClass().getName() + "$HttpURLState"); @@ -72,9 +77,9 @@ public class HttpUrlConnectionInstrumentation extends Instrumenter.Default { @Advice.This final HttpURLConnection thiz, @Advice.FieldValue("connected") final boolean connected, @Advice.Origin("#m") final String methodName) { - - final HttpURLState state = - InstrumentationContext.get(thiz, HttpURLConnection.class, HttpURLState.class); + final ContextStore contextStore = + InstrumentationContext.get(HttpURLConnection.class, HttpURLState.class); + final HttpURLState state = contextStore.putIfAbsent(thiz, HttpURLState.FACTORY); String operationName = "http.request"; @@ -183,6 +188,15 @@ public class HttpUrlConnectionInstrumentation extends Instrumenter.Default { } public static class HttpURLState { + + public static final ContextStore.Factory FACTORY = + new ContextStore.Factory() { + @Override + public HttpURLState create() { + return new HttpURLState(); + } + }; + public boolean calledOutputStream = false; public boolean calledInputStream = false; } diff --git a/dd-java-agent/testing/src/test/groovy/context/MapBackedProviderTest.groovy b/dd-java-agent/testing/src/test/groovy/context/MapBackedProviderTest.groovy index 50261197f9..e351fcea25 100644 --- a/dd-java-agent/testing/src/test/groovy/context/MapBackedProviderTest.groovy +++ b/dd-java-agent/testing/src/test/groovy/context/MapBackedProviderTest.groovy @@ -6,10 +6,15 @@ import net.bytebuddy.utility.JavaModule import java.lang.ref.WeakReference +import static context.ContextTestInstrumentation.IncorrectCallUsageKeyClass +import static context.ContextTestInstrumentation.IncorrectContextClassUsageKeyClass +import static context.ContextTestInstrumentation.IncorrectKeyClassUsageKeyClass +import static context.ContextTestInstrumentation.KeyClass + class MapBackedProviderTest extends AgentTestRunner { def setupSpec() { - assert new UserClass1().isInstrumented() + assert new KeyClass().isInstrumented() } @Override @@ -19,14 +24,16 @@ class MapBackedProviderTest extends AgentTestRunner { final JavaModule module, final boolean loaded, final Throwable throwable) { - // UserClass2 asserts on incorrect api usage. Error expected. - return !(typeName.equals(UserClass2.getName()) && throwable.getMessage().startsWith("Incorrect Context Api Usage detected.")) + // Incorrect* classes assert on incorrect api usage. Error expected. + return !(typeName.startsWith(ContextTestInstrumentation.getName() + "$Incorrect") && throwable.getMessage().startsWith("Incorrect Context Api Usage detected.")) } def "correct api usage stores state in map"() { + setup: + KeyClass instance1 = new KeyClass() + KeyClass instance2 = new KeyClass() + when: - UserClass1 instance1 = new UserClass1() - UserClass1 instance2 = new UserClass1() instance1.incrementContextCount() then: @@ -34,11 +41,22 @@ class MapBackedProviderTest extends AgentTestRunner { instance2.incrementContextCount() == 1 } - def "backing map should not create strong refs to user instances"() { + def "get/put test"() { + setup: + KeyClass instance1 = new KeyClass() + when: - UserClass1 instance = new UserClass1() + instance1.putContextCount(10) + + then: + instance1.getContextCount() == 10 + } + + def "backing map should not create strong refs to key class instances"() { + when: + KeyClass instance = new KeyClass() final int count = instance.incrementContextCount() - WeakReference instanceRef = new WeakReference(instance) + WeakReference instanceRef = new WeakReference(instance) instance = null TestUtils.awaitGC(instanceRef) @@ -47,8 +65,18 @@ class MapBackedProviderTest extends AgentTestRunner { count == 1 } - def "incorrect api usage fails at class load time"() { + def "incorrect key class usage fails at class load time"() { expect: - !new UserClass2().isInstrumented() + !new IncorrectKeyClassUsageKeyClass().isInstrumented() + } + + def "incorrect context class usage fails at class load time"() { + expect: + !new IncorrectContextClassUsageKeyClass().isInstrumented() + } + + def "incorrect call usage fails at class load time"() { + expect: + !new IncorrectCallUsageKeyClass().isInstrumented() } } diff --git a/dd-java-agent/testing/src/test/java/context/ContextTestInstrumentation.java b/dd-java-agent/testing/src/test/java/context/ContextTestInstrumentation.java index 9f87c589b8..d038eaad91 100644 --- a/dd-java-agent/testing/src/test/java/context/ContextTestInstrumentation.java +++ b/dd-java-agent/testing/src/test/java/context/ContextTestInstrumentation.java @@ -1,9 +1,11 @@ package context; +import static net.bytebuddy.matcher.ElementMatchers.nameStartsWith; import static net.bytebuddy.matcher.ElementMatchers.named; import com.google.auto.service.AutoService; import datadog.trace.agent.tooling.Instrumenter; +import datadog.trace.bootstrap.ContextStore; import datadog.trace.bootstrap.InstrumentationContext; import java.util.Collections; import java.util.HashMap; @@ -20,59 +22,181 @@ public class ContextTestInstrumentation extends Instrumenter.Default { @Override public ElementMatcher typeMatcher() { - return named(getClass().getPackage().getName() + ".UserClass1") - .or(named(getClass().getPackage().getName() + ".UserClass2")); + return nameStartsWith(getClass().getName() + "$"); } @Override public Map transformers() { final Map transformers = new HashMap<>(2); transformers.put(named("isInstrumented"), MarkInstrumentedAdvice.class.getName()); - transformers.put(named("incrementContextCount"), CorrectContextApiUsageAdvice.class.getName()); transformers.put( - named("incrementContextCountCountBroken"), IncorrectContextApiUsageAdvice.class.getName()); + named("incrementContextCount"), StoreAndIncrementApiUsageAdvice.class.getName()); + transformers.put(named("getContextCount"), GetApiUsageAdvice.class.getName()); + transformers.put(named("putContextCount"), PutApiUsageAdvice.class.getName()); + transformers.put( + named("incorrectKeyClassUsage"), IncorrectKeyClassContextApiUsageAdvice.class.getName()); + transformers.put( + named("incorrectContextClassUsage"), + IncorrectContextClassContextApiUsageAdvice.class.getName()); + transformers.put( + named("incorrectCallUsage"), IncorrectCallContextApiUsageAdvice.class.getName()); return transformers; } @Override public String[] helperClassNames() { - return new String[] {getClass().getName() + "$UserClass1State"}; + return new String[] {getClass().getName() + "$Context", getClass().getName() + "$Context$1"}; } @Override public Map contextStore() { return Collections.singletonMap( - getClass().getPackage().getName() + ".UserClass1", - getClass().getName() + "$UserClass1State"); + getClass().getName() + "$KeyClass", getClass().getName() + "$Context"); } public static class MarkInstrumentedAdvice { @Advice.OnMethodExit - public static void markInstrumented(@Advice.Return(readOnly = false) boolean isInstrumented) { + public static void methodExit(@Advice.Return(readOnly = false) boolean isInstrumented) { isInstrumented = true; } } - public static class CorrectContextApiUsageAdvice { + public static class StoreAndIncrementApiUsageAdvice { @Advice.OnMethodExit - public static void storeAndIncrement( - @Advice.This Object thiz, @Advice.Return(readOnly = false) int contextCount) { - UserClass1State state = - InstrumentationContext.get((UserClass1) thiz, UserClass1.class, UserClass1State.class); - contextCount = ++state.count; + public static void methodExit( + @Advice.This final KeyClass thiz, @Advice.Return(readOnly = false) int contextCount) { + final ContextStore contextStore = + InstrumentationContext.get(KeyClass.class, Context.class); + final Context context = contextStore.putIfAbsent(thiz, new Context()); + contextCount = ++context.count; } } - public static class IncorrectContextApiUsageAdvice { + public static class StoreAndIncrementWithFactoryApiUsageAdvice { @Advice.OnMethodExit - public static void storeAndIncrement( - @Advice.This Object thiz, @Advice.Return(readOnly = false) int contextCount) { - UserClass1State state = InstrumentationContext.get(thiz, Object.class, UserClass1State.class); - contextCount = ++state.count; + public static void methodExit( + @Advice.This final KeyClass thiz, @Advice.Return(readOnly = false) int contextCount) { + final ContextStore contextStore = + InstrumentationContext.get(KeyClass.class, Context.class); + final Context context = contextStore.putIfAbsent(thiz, Context.FACTORY); + contextCount = ++context.count; } } - public static class UserClass1State { + public static class GetApiUsageAdvice { + @Advice.OnMethodExit + public static void methodExit( + @Advice.This final KeyClass thiz, @Advice.Return(readOnly = false) int contextCount) { + final ContextStore contextStore = + InstrumentationContext.get(KeyClass.class, Context.class); + contextCount = contextStore.get(thiz).count; + } + } + + public static class PutApiUsageAdvice { + @Advice.OnMethodExit + public static void methodExit( + @Advice.This final KeyClass thiz, @Advice.Argument(0) final int value) { + final ContextStore contextStore = + InstrumentationContext.get(KeyClass.class, Context.class); + final Context context = new Context(); + context.count = value; + contextStore.put(thiz, context); + } + } + + public static class IncorrectKeyClassContextApiUsageAdvice { + @Advice.OnMethodExit + public static void methodExit() { + InstrumentationContext.get(Object.class, Context.class); + } + } + + public static class IncorrectContextClassContextApiUsageAdvice { + @Advice.OnMethodExit + public static void methodExit() { + InstrumentationContext.get(KeyClass.class, Object.class); + } + } + + public static class IncorrectCallContextApiUsageAdvice { + @Advice.OnMethodExit + public static void methodExit() { + // Our instrumentation doesn't handle variables being passed to InstrumentationContext.get, + // so we make sure that this actually fails instrumentation. + final Class clazz = null; + InstrumentationContext.get(clazz, Object.class); + } + } + + public static class Context { + public static final ContextStore.Factory FACTORY = + new ContextStore.Factory() { + @Override + public Context create() { + return new Context(); + } + }; + int count = 0; } + + public static class KeyClass { + public boolean isInstrumented() { + // implementation replaced with test instrumentation + return false; + } + + public int incrementContextCount() { + // implementation replaced with test instrumentation + return -1; + } + + public int incrementContextCountWithFactory() { + // implementation replaced with test instrumentation + return -1; + } + + public int getContextCount() { + // implementation replaced with test instrumentation + return -1; + } + + public void putContextCount(final int value) { + // implementation replaced with test instrumentation + } + } + + public static class IncorrectKeyClassUsageKeyClass { + public boolean isInstrumented() { + return false; + } + + public int incorrectKeyClassUsage() { + // instrumentation will not apply to this class because advice incorrectly uses context api + return -1; + } + } + + public static class IncorrectContextClassUsageKeyClass { + public boolean isInstrumented() { + return false; + } + + public int incorrectContextClassUsage() { + // instrumentation will not apply to this class because advice incorrectly uses context api + return -1; + } + } + + public static class IncorrectCallUsageKeyClass { + public boolean isInstrumented() { + return false; + } + + public int incorrectCallUsage() { + // instrumentation will not apply to this class because advice incorrectly uses context api + return -1; + } + } } diff --git a/dd-java-agent/testing/src/test/java/context/UserClass1.java b/dd-java-agent/testing/src/test/java/context/UserClass1.java deleted file mode 100644 index b2d5ced2f1..0000000000 --- a/dd-java-agent/testing/src/test/java/context/UserClass1.java +++ /dev/null @@ -1,13 +0,0 @@ -package context; - -public class UserClass1 { - public boolean isInstrumented() { - // implementation replaced with test instrumentation - return false; - } - - public int incrementContextCount() { - // implementation replaced with test instrumentation - return -1; - } -} diff --git a/dd-java-agent/testing/src/test/java/context/UserClass2.java b/dd-java-agent/testing/src/test/java/context/UserClass2.java deleted file mode 100644 index 9d676db335..0000000000 --- a/dd-java-agent/testing/src/test/java/context/UserClass2.java +++ /dev/null @@ -1,12 +0,0 @@ -package context; - -public class UserClass2 { - public boolean isInstrumented() { - return false; - } - - public int incrementContextCountCountBroken() { - // instrumentation will not apply to this class because advice incorrectly uses context api - return -1; - } -}