First iteration on a better context store api

Separate context storage from actual fetching/putting
This commit is contained in:
Nikolay Martynov 2018-11-01 16:44:13 -04:00
parent 3c80c25ea5
commit 98df97af3e
8 changed files with 403 additions and 184 deletions

View File

@ -0,0 +1,59 @@
package datadog.trace.bootstrap;
/**
* Interface to represent context storage for instrumentations.
*
* <p>Context instances are weakly referenced and will be garbage collected when their corresponding
* key instance is collected.
*
* @param <K> key type to do context lookups
* @param <C> context type
*/
public interface ContextStore<K, C> {
/**
* Factory interface to create context instances
*
* @param <C> context type
*/
interface Factory<C> {
/** @return new context instance */
C create();
}
/**
* Get context given the key
*
* @param key the key to looup
* @return context object
*/
C get(K key);
/**
* Put new context instance for given key
*
* @param key key to use
* @param context context instance to save
*/
void put(K key, C context);
/**
* Put new context instance if key is absent
*
* @param key key to use
* @param context new context instance to put
* @return old instance if it was present, or new instance
*/
C putIfAbsent(K key, C context);
/**
* Put new context instance if key is absent. Uses context factory to avoid creating objects if
* not needed.
*
* @param key key to use
* @param contextFactory factory instance to produce new context object
* @return old instance if it was present, or new instance
*/
C putIfAbsent(K key, Factory<C> contextFactory);
}

View File

@ -5,39 +5,22 @@ public class InstrumentationContext {
private InstrumentationContext() {} private InstrumentationContext() {}
/** /**
* Fetch a context instance out of the context store. * Find a {@link ContextStore} instance for given key class and context class.
* *
* <p> * <p>Conceptually this can be thought of as a map lookup to fetch a second level map given
* * keyClass.
* <p>Conceptually, this can be thought of as a two pass map look up.
*
* <p>For example: <em>RunnableState runnableState = get(runnableImpl, Runnable.class,
* RunnableState.class)</em> --> <em>RunnableState runnableState = (RunnableState)
* GlobalContextMap.get(Runnable.class).get(runnableImpl)</em>
*
* <p>
* *
* <p>However, the implementation is actually provided by bytecode transformation for performance * <p>However, the implementation is actually provided by bytecode transformation for performance
* reasons. * reasons.
* *
* <p> * @param keyClass The key class context is attached to.
*
* <p>Context classes are weakly referenced and will be garbage collected when their corresponding
* user instance is collected.
*
* <p>
*
* <p>Instrumenters making this call must define the user-context class relationship in
* datadog.trace.agent.tooling.Instrumenter.Default#contextStore.
*
* @param userInstance The instance to store context on.
* @param userClass The user class context is attached to.
* @param contextClass The context class attached to the user class. * @param contextClass The context class attached to the user class.
* @param <K> user class * @param <K> key class
* @param <V> context class * @param <C> context class
* @return The context instance attached to userInstance. * @return The instance of context store for given arguments.
*/ */
public static <K, V> V get(K userInstance, Class<K> userClass, Class<V> contextClass) { public static <K, C> ContextStore<K, C> get(
throw new RuntimeException("calls to this method will be rewritten"); final Class<K> keyClass, final Class<C> contextClass) {
throw new RuntimeException("Calls to this method will be rewritten by MapBackedProvider");
} }
} }

View File

@ -5,6 +5,7 @@ import static datadog.trace.agent.tooling.ClassLoaderMatcher.BOOTSTRAP_CLASSLOAD
import datadog.trace.agent.tooling.HelperInjector; import datadog.trace.agent.tooling.HelperInjector;
import datadog.trace.agent.tooling.Instrumenter; import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.Utils; import datadog.trace.agent.tooling.Utils;
import datadog.trace.bootstrap.ContextStore;
import datadog.trace.bootstrap.InstrumentationContext; import datadog.trace.bootstrap.InstrumentationContext;
import datadog.trace.bootstrap.WeakMap; import datadog.trace.bootstrap.WeakMap;
import java.lang.reflect.Method; import java.lang.reflect.Method;
@ -37,8 +38,8 @@ import net.bytebuddy.utility.JavaModule;
* <p>This is accomplished by * <p>This is accomplished by
* *
* <ol> * <ol>
* <li>Injecting a Dynamic Class to store a static map * <li>Injecting a Dynamic Class created from {@link MapHolder} to store a static map
* <li>Rewritting calls to the context-store to access the map on the dynamic class * <li>Rewritting calls to the context-store to access the specific dynamic {@link MapHolder}
* </ol> * </ol>
* *
* Storing the map on a dynamic class and doing bytecode rewrites allows for a 1-pass lookup. * Storing the map on a dynamic class and doing bytecode rewrites allows for a 1-pass lookup.
@ -47,11 +48,8 @@ import net.bytebuddy.utility.JavaModule;
* <p>Example:<br> * <p>Example:<br>
* <em>InstrumentationContext.get(runnableInstance, Runnable.class, RunnableState.class)")</em><br> * <em>InstrumentationContext.get(runnableInstance, Runnable.class, RunnableState.class)")</em><br>
* is rewritten to:<br> * is rewritten to:<br>
* <em>RunnableInstrumentation$ContextStore$RunnableState12345.getOrCreate(runnableInstance, * <em>RunnableInstrumentation$ContextStore$RunnableState12345.getMapHolder(runnableRunnable.class,
* Runnable.class, RunnableState.class)</em> * RunnableState.class)</em>
*
* <p>Map lookup implementation defined in template class: {@link MapHolder#getOrCreate(Object,
* Class, Class)}
*/ */
@Slf4j @Slf4j
public class MapBackedProvider implements InstrumentationContextProvider { public class MapBackedProvider implements InstrumentationContextProvider {
@ -60,11 +58,9 @@ public class MapBackedProvider implements InstrumentationContextProvider {
static { static {
try { try {
contextGetMethod = contextGetMethod = InstrumentationContext.class.getMethod("get", Class.class, Class.class);
InstrumentationContext.class.getMethod("get", Object.class, Class.class, Class.class); mapGetMethod = MapHolder.class.getMethod("getMapHolder", Class.class, Class.class);
mapGetMethod = } catch (final Exception e) {
MapHolder.class.getMethod("getOrCreate", Object.class, Class.class, Class.class);
} catch (Exception e) {
throw new IllegalStateException(e); throw new IllegalStateException(e);
} }
} }
@ -72,16 +68,17 @@ public class MapBackedProvider implements InstrumentationContextProvider {
/** dynamic-class-name -> dynamic-class-bytes */ /** dynamic-class-name -> dynamic-class-bytes */
private final AtomicReference<Map<String, byte[]>> dynamicClasses = new AtomicReference<>(null); private final AtomicReference<Map<String, byte[]>> dynamicClasses = new AtomicReference<>(null);
/** user-class-name -> dynamic-class-name */ /** key-class-name -> dynamic-class-name */
private final AtomicReference<Map<String, String>> dynamicClassNames = private final AtomicReference<Map<String, String>> dynamicClassNames =
new AtomicReference<>(null); new AtomicReference<>(null);
private final Instrumenter.Default instrumenter; private final Instrumenter.Default instrumenter;
public MapBackedProvider(Instrumenter.Default instrumenter) { public MapBackedProvider(final Instrumenter.Default instrumenter) {
this.instrumenter = instrumenter; this.instrumenter = instrumenter;
} }
@Override
public AgentBuilder.Identified.Extendable instrumentationTransformer( public AgentBuilder.Identified.Extendable instrumentationTransformer(
AgentBuilder.Identified.Extendable builder) { AgentBuilder.Identified.Extendable builder) {
if (instrumenter.contextStore().size() > 0) { if (instrumenter.contextStore().size() > 0) {
@ -90,10 +87,10 @@ public class MapBackedProvider implements InstrumentationContextProvider {
new AgentBuilder.Transformer() { new AgentBuilder.Transformer() {
@Override @Override
public DynamicType.Builder<?> transform( public DynamicType.Builder<?> transform(
DynamicType.Builder<?> builder, final DynamicType.Builder<?> builder,
TypeDescription typeDescription, final TypeDescription typeDescription,
ClassLoader classLoader, final ClassLoader classLoader,
JavaModule module) { final JavaModule module) {
return builder.visit(getInstrumentationVisitor()); return builder.visit(getInstrumentationVisitor());
} }
}); });
@ -104,10 +101,10 @@ public class MapBackedProvider implements InstrumentationContextProvider {
@Override @Override
public DynamicType.Builder<?> transform( public DynamicType.Builder<?> transform(
DynamicType.Builder<?> builder, final DynamicType.Builder<?> builder,
TypeDescription typeDescription, final TypeDescription typeDescription,
ClassLoader classLoader, final ClassLoader classLoader,
JavaModule module) { final JavaModule module) {
return injector.transform( return injector.transform(
builder, builder,
typeDescription, typeDescription,
@ -123,41 +120,45 @@ public class MapBackedProvider implements InstrumentationContextProvider {
private AsmVisitorWrapper getInstrumentationVisitor() { private AsmVisitorWrapper getInstrumentationVisitor() {
return new AsmVisitorWrapper() { return new AsmVisitorWrapper() {
@Override @Override
public int mergeWriter(int flags) { public int mergeWriter(final int flags) {
return flags | ClassWriter.COMPUTE_MAXS; return flags | ClassWriter.COMPUTE_MAXS;
} }
@Override @Override
public int mergeReader(int flags) { public int mergeReader(final int flags) {
return flags; return flags;
} }
@Override @Override
public ClassVisitor wrap( public ClassVisitor wrap(
final TypeDescription instrumentedType, final TypeDescription instrumentedType,
ClassVisitor classVisitor, final ClassVisitor classVisitor,
Implementation.Context implementationContext, final Implementation.Context implementationContext,
TypePool typePool, final TypePool typePool,
FieldList<FieldDescription.InDefinedShape> fields, final FieldList<FieldDescription.InDefinedShape> fields,
MethodList<?> methods, final MethodList<?> methods,
int writerFlags, final int writerFlags,
int readerFlags) { final int readerFlags) {
generateMapHolderClasses(); generateMapHolderClasses();
return new ClassVisitor(Opcodes.ASM7, classVisitor) { return new ClassVisitor(Opcodes.ASM7, classVisitor) {
@Override @Override
public void visit( public void visit(
int version, final int version,
int access, final int access,
String name, final String name,
String signature, final String signature,
String superName, final String superName,
String[] interfaces) { final String[] interfaces) {
super.visit(version, access, name, signature, superName, interfaces); super.visit(version, access, name, signature, superName, interfaces);
} }
@Override @Override
public MethodVisitor visitMethod( public MethodVisitor visitMethod(
int access, String name, String descriptor, String signature, String[] exceptions) { final int access,
final String name,
final String descriptor,
final String signature,
final String[] exceptions) {
final MethodVisitor mv = final MethodVisitor mv =
super.visitMethod(access, name, descriptor, signature, exceptions); super.visitMethod(access, name, descriptor, signature, exceptions);
return new MethodVisitor(Opcodes.ASM7, mv) { return new MethodVisitor(Opcodes.ASM7, mv) {
@ -168,7 +169,11 @@ public class MapBackedProvider implements InstrumentationContextProvider {
@Override @Override
public void visitMethodInsn( public void visitMethodInsn(
int opcode, String owner, String name, String descriptor, boolean isInterface) { final int opcode,
final String owner,
final String name,
final String descriptor,
final boolean isInterface) {
pushOpcode(opcode); pushOpcode(opcode);
if (Utils.getInternalName(contextGetMethod.getDeclaringClass().getName()) if (Utils.getInternalName(contextGetMethod.getDeclaringClass().getName())
.equals(owner) .equals(owner)
@ -180,17 +185,28 @@ public class MapBackedProvider implements InstrumentationContextProvider {
&& insnStack[2] == Opcodes.LDC) && insnStack[2] == Opcodes.LDC)
&& (stack[0] instanceof Type && stack[1] instanceof Type)) { && (stack[0] instanceof Type && stack[1] instanceof Type)) {
final String contextClassName = ((Type) stack[0]).getClassName(); final String contextClassName = ((Type) stack[0]).getClassName();
final String userClassName = ((Type) stack[1]).getClassName(); final String keyClassName = ((Type) stack[1]).getClassName();
final String mapHolderClass = dynamicClassNames.get().get(userClassName); final String mapHolderClass = dynamicClassNames.get().get(keyClassName);
log.debug( log.debug(
"Rewriting context-store map fetch for instrumenter {}: {} -> {}", "Rewriting context-store map fetch for instrumenter {}: {} -> {}",
instrumenter.getClass().getName(), instrumenter.getClass().getName(),
userClassName, keyClassName,
contextClassName); contextClassName);
if (mapHolderClass != null if (mapHolderClass == null) {
&& contextClassName.equals( throw new IllegalStateException(
instrumenter.contextStore().get(userClassName))) { String.format(
// stack: contextClass | userClass | instance "Incorrect Context Api Usage detected. Cannot find map holder class for %s. Was that class defined in contextStore for instrumentation %s?",
keyClassName, instrumenter.getClass().getName()));
}
if (!contextClassName.equals(instrumenter.contextStore().get(keyClassName))) {
throw new IllegalStateException(
String.format(
"Incorrect Context Api Usage detected. Incorrect context class %s, expected %s for instrumentation %s",
contextClassName,
instrumenter.contextStore().get(keyClassName),
instrumenter.getClass().getName()));
}
// stack: contextClass | keyClass
mv.visitMethodInsn( mv.visitMethodInsn(
Opcodes.INVOKESTATIC, Opcodes.INVOKESTATIC,
Utils.getInternalName(mapHolderClass), Utils.getInternalName(mapHolderClass),
@ -199,9 +215,8 @@ public class MapBackedProvider implements InstrumentationContextProvider {
false); false);
return; return;
} }
}
throw new IllegalStateException( throw new IllegalStateException(
"Incorrect Context Api Usage detected. User and context class must be class-literals. Example of correct usage: InstrumentationContext.get(runnableInstance, Runnable.class, RunnableState.class)"); "Incorrect Context Api Usage detected. Key and context class must be class-literals. Example of correct usage: InstrumentationContext.get(Runnable.class, RunnableContext.class)");
} else { } else {
super.visitMethodInsn(opcode, owner, name, descriptor, isInterface); super.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
} }
@ -216,38 +231,38 @@ public class MapBackedProvider implements InstrumentationContextProvider {
/** /**
* Tracking the most recently pushed objects on the stack to assert proper api usage. * Tracking the most recently pushed objects on the stack to assert proper api usage.
*/ */
private void pushStack(Object o) { private void pushStack(final Object o) {
System.arraycopy(stack, 0, stack, 1, stack.length - 1); System.arraycopy(stack, 0, stack, 1, stack.length - 1);
stack[0] = o; stack[0] = o;
} }
@Override @Override
public void visitInsn(int opcode) { public void visitInsn(final int opcode) {
pushOpcode(opcode); pushOpcode(opcode);
super.visitInsn(opcode); super.visitInsn(opcode);
} }
@Override @Override
public void visitJumpInsn(int opcode, Label label) { public void visitJumpInsn(final int opcode, final Label label) {
pushOpcode(opcode); pushOpcode(opcode);
super.visitJumpInsn(opcode, label); super.visitJumpInsn(opcode, label);
} }
@Override @Override
public void visitIntInsn(int opcode, int operand) { public void visitIntInsn(final int opcode, final int operand) {
pushOpcode(opcode); pushOpcode(opcode);
super.visitIntInsn(opcode, operand); super.visitIntInsn(opcode, operand);
} }
@Override @Override
public void visitVarInsn(int opcode, int var) { public void visitVarInsn(final int opcode, final int var) {
pushOpcode(opcode); pushOpcode(opcode);
pushStack(var); pushStack(var);
super.visitVarInsn(opcode, var); super.visitVarInsn(opcode, var);
} }
@Override @Override
public void visitLdcInsn(Object value) { public void visitLdcInsn(final Object value) {
pushOpcode(Opcodes.LDC); pushOpcode(Opcodes.LDC);
pushStack(value); pushStack(value);
super.visitLdcInsn(value); super.visitLdcInsn(value);
@ -265,7 +280,7 @@ public class MapBackedProvider implements InstrumentationContextProvider {
@Override @Override
public AgentBuilder.Identified.Extendable additionalInstrumentation( public AgentBuilder.Identified.Extendable additionalInstrumentation(
AgentBuilder.Identified.Extendable builder) { final AgentBuilder.Identified.Extendable builder) {
return builder; return builder;
} }
@ -281,14 +296,14 @@ public class MapBackedProvider implements InstrumentationContextProvider {
final Map<String, byte[]> dynamicClasses = new HashMap<>(instrumenter.contextStore().size()); final Map<String, byte[]> dynamicClasses = new HashMap<>(instrumenter.contextStore().size());
final Map<String, String> dynamicClassNames = final Map<String, String> dynamicClassNames =
new HashMap<>(instrumenter.contextStore().size()); new HashMap<>(instrumenter.contextStore().size());
for (final String userClassName : instrumenter.contextStore().keySet()) { for (final String keyClassName : instrumenter.contextStore().keySet()) {
final String dynamicClassName = final String dynamicClassName =
instrumenter.getClass().getName() instrumenter.getClass().getName()
+ "$ContextStore" + "$ContextStore"
+ userClassName.replaceAll(".*([^\\.]+)$", "\\1") + keyClassName.replaceAll(".*([^\\.]+)$", "\\1")
+ UUID.randomUUID().toString().replace('-', '_'); + UUID.randomUUID().toString().replace('-', '_');
dynamicClassNames.put(userClassName, dynamicClassName); dynamicClassNames.put(keyClassName, dynamicClassName);
dynamicClasses.put(dynamicClassName, makeMapHolderClass(dynamicClassName)); dynamicClasses.put(dynamicClassName, makeMapHolderClass(dynamicClassName));
} }
this.dynamicClassNames.compareAndSet(null, dynamicClassNames); this.dynamicClassNames.compareAndSet(null, dynamicClassNames);
@ -296,7 +311,7 @@ public class MapBackedProvider implements InstrumentationContextProvider {
} }
} }
private byte[] makeMapHolderClass(String className) { private byte[] makeMapHolderClass(final String className) {
return new ByteBuddy() return new ByteBuddy()
.rebase(MapHolder.class) .rebase(MapHolder.class)
.modifiers(Opcodes.ACC_PUBLIC | Opcodes.ACC_FINAL) .modifiers(Opcodes.ACC_PUBLIC | Opcodes.ACC_FINAL)
@ -306,41 +321,62 @@ public class MapBackedProvider implements InstrumentationContextProvider {
} }
/** Template class used to generate the class holding the global map. */ /** Template class used to generate the class holding the global map. */
private static final class MapHolder { private static final class MapHolder implements ContextStore<Object, Object> {
public static final WeakMap MAP = WeakMap.Provider.newWeakMap(); private static final MapHolder INSTANCE = new MapHolder(WeakMap.Provider.newWeakMap());
/** private final WeakMap map;
* Fetch a context class out of the backing map. Create and return a new context class if none
* currently exists. private MapHolder(final WeakMap map) {
* this.map = map;
* <p>This method is thread safe.
*/
public static Object getOrCreate(Object instance, Class userClass, Class contextClass) {
if (!userClass.isAssignableFrom(instance.getClass())) {
throw new RuntimeException(
"Illegal context lookup. "
+ instance.getClass().getName()
+ " cannot be cast to "
+ userClass.getName());
}
Object contextInstance = MAP.get(instance);
if (null == contextInstance) {
synchronized (instance) {
contextInstance = MAP.get(instance);
if (null == contextInstance) {
try {
contextInstance = contextClass.newInstance();
MAP.put(instance, contextInstance);
} catch (Exception e) {
throw new RuntimeException(
contextClass.getName() + " must define a public, no-arg constructor.", e);
}
}
}
}
return contextInstance;
} }
private MapHolder() {} @Override
public Object get(final Object key) {
return map.get(key);
}
@Override
public Object putIfAbsent(final Object key, final Object context) {
Object existingContext = map.get(key);
if (null != existingContext) {
return existingContext;
}
synchronized (map) {
existingContext = map.get(key);
if (null != existingContext) {
return existingContext;
}
map.put(key, context);
return context;
}
}
@Override
public Object putIfAbsent(final Object key, final ContextStore.Factory<Object> contextFactory) {
Object existingContext = map.get(key);
if (null != existingContext) {
return existingContext;
}
synchronized (map) {
existingContext = map.get(key);
if (null != existingContext) {
return existingContext;
}
final Object context = contextFactory.create();
map.put(key, context);
return context;
}
}
@Override
public void put(final Object key, final Object context) {
map.put(key, context);
}
public static ContextStore getMapHolder(final Class keyClass, final Class contextClass) {
// We do not actually check the keyClass here - but that should be fine since compiler would
// check things for us.
return INSTANCE;
}
} }
} }

View File

@ -12,6 +12,7 @@ import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.api.DDSpanTypes; import datadog.trace.api.DDSpanTypes;
import datadog.trace.api.DDTags; import datadog.trace.api.DDTags;
import datadog.trace.bootstrap.CallDepthThreadLocalMap; import datadog.trace.bootstrap.CallDepthThreadLocalMap;
import datadog.trace.bootstrap.ContextStore;
import datadog.trace.bootstrap.InstrumentationContext; import datadog.trace.bootstrap.InstrumentationContext;
import io.opentracing.Scope; import io.opentracing.Scope;
import io.opentracing.Span; import io.opentracing.Span;
@ -48,9 +49,13 @@ public class HttpUrlConnectionInstrumentation extends Instrumenter.Default {
@Override @Override
public String[] helperClassNames() { public String[] helperClassNames() {
return new String[] {HttpUrlConnectionInstrumentation.class.getName() + "$HttpURLState"}; return new String[] {
HttpUrlConnectionInstrumentation.class.getName() + "$HttpURLState",
HttpUrlConnectionInstrumentation.class.getName() + "$HttpURLState$1"
};
} }
@Override
public Map<String, String> contextStore() { public Map<String, String> contextStore() {
return Collections.singletonMap( return Collections.singletonMap(
"java.net.HttpURLConnection", getClass().getName() + "$HttpURLState"); "java.net.HttpURLConnection", getClass().getName() + "$HttpURLState");
@ -72,9 +77,9 @@ public class HttpUrlConnectionInstrumentation extends Instrumenter.Default {
@Advice.This final HttpURLConnection thiz, @Advice.This final HttpURLConnection thiz,
@Advice.FieldValue("connected") final boolean connected, @Advice.FieldValue("connected") final boolean connected,
@Advice.Origin("#m") final String methodName) { @Advice.Origin("#m") final String methodName) {
final ContextStore<HttpURLConnection, HttpURLState> contextStore =
final HttpURLState state = InstrumentationContext.get(HttpURLConnection.class, HttpURLState.class);
InstrumentationContext.get(thiz, HttpURLConnection.class, HttpURLState.class); final HttpURLState state = contextStore.putIfAbsent(thiz, HttpURLState.FACTORY);
String operationName = "http.request"; String operationName = "http.request";
@ -183,6 +188,15 @@ public class HttpUrlConnectionInstrumentation extends Instrumenter.Default {
} }
public static class HttpURLState { public static class HttpURLState {
public static final ContextStore.Factory<HttpURLState> FACTORY =
new ContextStore.Factory<HttpURLState>() {
@Override
public HttpURLState create() {
return new HttpURLState();
}
};
public boolean calledOutputStream = false; public boolean calledOutputStream = false;
public boolean calledInputStream = false; public boolean calledInputStream = false;
} }

View File

@ -6,10 +6,15 @@ import net.bytebuddy.utility.JavaModule
import java.lang.ref.WeakReference import java.lang.ref.WeakReference
import static context.ContextTestInstrumentation.IncorrectCallUsageKeyClass
import static context.ContextTestInstrumentation.IncorrectContextClassUsageKeyClass
import static context.ContextTestInstrumentation.IncorrectKeyClassUsageKeyClass
import static context.ContextTestInstrumentation.KeyClass
class MapBackedProviderTest extends AgentTestRunner { class MapBackedProviderTest extends AgentTestRunner {
def setupSpec() { def setupSpec() {
assert new UserClass1().isInstrumented() assert new KeyClass().isInstrumented()
} }
@Override @Override
@ -19,14 +24,16 @@ class MapBackedProviderTest extends AgentTestRunner {
final JavaModule module, final JavaModule module,
final boolean loaded, final boolean loaded,
final Throwable throwable) { final Throwable throwable) {
// UserClass2 asserts on incorrect api usage. Error expected. // Incorrect* classes assert on incorrect api usage. Error expected.
return !(typeName.equals(UserClass2.getName()) && throwable.getMessage().startsWith("Incorrect Context Api Usage detected.")) return !(typeName.startsWith(ContextTestInstrumentation.getName() + "$Incorrect") && throwable.getMessage().startsWith("Incorrect Context Api Usage detected."))
} }
def "correct api usage stores state in map"() { def "correct api usage stores state in map"() {
setup:
KeyClass instance1 = new KeyClass()
KeyClass instance2 = new KeyClass()
when: when:
UserClass1 instance1 = new UserClass1()
UserClass1 instance2 = new UserClass1()
instance1.incrementContextCount() instance1.incrementContextCount()
then: then:
@ -34,11 +41,22 @@ class MapBackedProviderTest extends AgentTestRunner {
instance2.incrementContextCount() == 1 instance2.incrementContextCount() == 1
} }
def "backing map should not create strong refs to user instances"() { def "get/put test"() {
setup:
KeyClass instance1 = new KeyClass()
when: when:
UserClass1 instance = new UserClass1() instance1.putContextCount(10)
then:
instance1.getContextCount() == 10
}
def "backing map should not create strong refs to key class instances"() {
when:
KeyClass instance = new KeyClass()
final int count = instance.incrementContextCount() final int count = instance.incrementContextCount()
WeakReference<UserClass1> instanceRef = new WeakReference(instance) WeakReference<KeyClass> instanceRef = new WeakReference(instance)
instance = null instance = null
TestUtils.awaitGC(instanceRef) TestUtils.awaitGC(instanceRef)
@ -47,8 +65,18 @@ class MapBackedProviderTest extends AgentTestRunner {
count == 1 count == 1
} }
def "incorrect api usage fails at class load time"() { def "incorrect key class usage fails at class load time"() {
expect: expect:
!new UserClass2().isInstrumented() !new IncorrectKeyClassUsageKeyClass().isInstrumented()
}
def "incorrect context class usage fails at class load time"() {
expect:
!new IncorrectContextClassUsageKeyClass().isInstrumented()
}
def "incorrect call usage fails at class load time"() {
expect:
!new IncorrectCallUsageKeyClass().isInstrumented()
} }
} }

View File

@ -1,9 +1,11 @@
package context; package context;
import static net.bytebuddy.matcher.ElementMatchers.nameStartsWith;
import static net.bytebuddy.matcher.ElementMatchers.named; import static net.bytebuddy.matcher.ElementMatchers.named;
import com.google.auto.service.AutoService; import com.google.auto.service.AutoService;
import datadog.trace.agent.tooling.Instrumenter; import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.bootstrap.ContextStore;
import datadog.trace.bootstrap.InstrumentationContext; import datadog.trace.bootstrap.InstrumentationContext;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -20,59 +22,181 @@ public class ContextTestInstrumentation extends Instrumenter.Default {
@Override @Override
public ElementMatcher<? super TypeDescription> typeMatcher() { public ElementMatcher<? super TypeDescription> typeMatcher() {
return named(getClass().getPackage().getName() + ".UserClass1") return nameStartsWith(getClass().getName() + "$");
.or(named(getClass().getPackage().getName() + ".UserClass2"));
} }
@Override @Override
public Map<? extends ElementMatcher, String> transformers() { public Map<? extends ElementMatcher, String> transformers() {
final Map<ElementMatcher, String> transformers = new HashMap<>(2); final Map<ElementMatcher, String> transformers = new HashMap<>(2);
transformers.put(named("isInstrumented"), MarkInstrumentedAdvice.class.getName()); transformers.put(named("isInstrumented"), MarkInstrumentedAdvice.class.getName());
transformers.put(named("incrementContextCount"), CorrectContextApiUsageAdvice.class.getName());
transformers.put( transformers.put(
named("incrementContextCountCountBroken"), IncorrectContextApiUsageAdvice.class.getName()); named("incrementContextCount"), StoreAndIncrementApiUsageAdvice.class.getName());
transformers.put(named("getContextCount"), GetApiUsageAdvice.class.getName());
transformers.put(named("putContextCount"), PutApiUsageAdvice.class.getName());
transformers.put(
named("incorrectKeyClassUsage"), IncorrectKeyClassContextApiUsageAdvice.class.getName());
transformers.put(
named("incorrectContextClassUsage"),
IncorrectContextClassContextApiUsageAdvice.class.getName());
transformers.put(
named("incorrectCallUsage"), IncorrectCallContextApiUsageAdvice.class.getName());
return transformers; return transformers;
} }
@Override @Override
public String[] helperClassNames() { public String[] helperClassNames() {
return new String[] {getClass().getName() + "$UserClass1State"}; return new String[] {getClass().getName() + "$Context", getClass().getName() + "$Context$1"};
} }
@Override @Override
public Map<String, String> contextStore() { public Map<String, String> contextStore() {
return Collections.singletonMap( return Collections.singletonMap(
getClass().getPackage().getName() + ".UserClass1", getClass().getName() + "$KeyClass", getClass().getName() + "$Context");
getClass().getName() + "$UserClass1State");
} }
public static class MarkInstrumentedAdvice { public static class MarkInstrumentedAdvice {
@Advice.OnMethodExit @Advice.OnMethodExit
public static void markInstrumented(@Advice.Return(readOnly = false) boolean isInstrumented) { public static void methodExit(@Advice.Return(readOnly = false) boolean isInstrumented) {
isInstrumented = true; isInstrumented = true;
} }
} }
public static class CorrectContextApiUsageAdvice { public static class StoreAndIncrementApiUsageAdvice {
@Advice.OnMethodExit @Advice.OnMethodExit
public static void storeAndIncrement( public static void methodExit(
@Advice.This Object thiz, @Advice.Return(readOnly = false) int contextCount) { @Advice.This final KeyClass thiz, @Advice.Return(readOnly = false) int contextCount) {
UserClass1State state = final ContextStore<KeyClass, Context> contextStore =
InstrumentationContext.get((UserClass1) thiz, UserClass1.class, UserClass1State.class); InstrumentationContext.get(KeyClass.class, Context.class);
contextCount = ++state.count; final Context context = contextStore.putIfAbsent(thiz, new Context());
contextCount = ++context.count;
} }
} }
public static class IncorrectContextApiUsageAdvice { public static class StoreAndIncrementWithFactoryApiUsageAdvice {
@Advice.OnMethodExit @Advice.OnMethodExit
public static void storeAndIncrement( public static void methodExit(
@Advice.This Object thiz, @Advice.Return(readOnly = false) int contextCount) { @Advice.This final KeyClass thiz, @Advice.Return(readOnly = false) int contextCount) {
UserClass1State state = InstrumentationContext.get(thiz, Object.class, UserClass1State.class); final ContextStore<KeyClass, Context> contextStore =
contextCount = ++state.count; InstrumentationContext.get(KeyClass.class, Context.class);
final Context context = contextStore.putIfAbsent(thiz, Context.FACTORY);
contextCount = ++context.count;
} }
} }
public static class UserClass1State { public static class GetApiUsageAdvice {
@Advice.OnMethodExit
public static void methodExit(
@Advice.This final KeyClass thiz, @Advice.Return(readOnly = false) int contextCount) {
final ContextStore<KeyClass, Context> contextStore =
InstrumentationContext.get(KeyClass.class, Context.class);
contextCount = contextStore.get(thiz).count;
}
}
public static class PutApiUsageAdvice {
@Advice.OnMethodExit
public static void methodExit(
@Advice.This final KeyClass thiz, @Advice.Argument(0) final int value) {
final ContextStore<KeyClass, Context> contextStore =
InstrumentationContext.get(KeyClass.class, Context.class);
final Context context = new Context();
context.count = value;
contextStore.put(thiz, context);
}
}
public static class IncorrectKeyClassContextApiUsageAdvice {
@Advice.OnMethodExit
public static void methodExit() {
InstrumentationContext.get(Object.class, Context.class);
}
}
public static class IncorrectContextClassContextApiUsageAdvice {
@Advice.OnMethodExit
public static void methodExit() {
InstrumentationContext.get(KeyClass.class, Object.class);
}
}
public static class IncorrectCallContextApiUsageAdvice {
@Advice.OnMethodExit
public static void methodExit() {
// Our instrumentation doesn't handle variables being passed to InstrumentationContext.get,
// so we make sure that this actually fails instrumentation.
final Class clazz = null;
InstrumentationContext.get(clazz, Object.class);
}
}
public static class Context {
public static final ContextStore.Factory<Context> FACTORY =
new ContextStore.Factory<Context>() {
@Override
public Context create() {
return new Context();
}
};
int count = 0; int count = 0;
} }
public static class KeyClass {
public boolean isInstrumented() {
// implementation replaced with test instrumentation
return false;
}
public int incrementContextCount() {
// implementation replaced with test instrumentation
return -1;
}
public int incrementContextCountWithFactory() {
// implementation replaced with test instrumentation
return -1;
}
public int getContextCount() {
// implementation replaced with test instrumentation
return -1;
}
public void putContextCount(final int value) {
// implementation replaced with test instrumentation
}
}
public static class IncorrectKeyClassUsageKeyClass {
public boolean isInstrumented() {
return false;
}
public int incorrectKeyClassUsage() {
// instrumentation will not apply to this class because advice incorrectly uses context api
return -1;
}
}
public static class IncorrectContextClassUsageKeyClass {
public boolean isInstrumented() {
return false;
}
public int incorrectContextClassUsage() {
// instrumentation will not apply to this class because advice incorrectly uses context api
return -1;
}
}
public static class IncorrectCallUsageKeyClass {
public boolean isInstrumented() {
return false;
}
public int incorrectCallUsage() {
// instrumentation will not apply to this class because advice incorrectly uses context api
return -1;
}
}
} }

View File

@ -1,13 +0,0 @@
package context;
public class UserClass1 {
public boolean isInstrumented() {
// implementation replaced with test instrumentation
return false;
}
public int incrementContextCount() {
// implementation replaced with test instrumentation
return -1;
}
}

View File

@ -1,12 +0,0 @@
package context;
public class UserClass2 {
public boolean isInstrumented() {
return false;
}
public int incrementContextCountCountBroken() {
// instrumentation will not apply to this class because advice incorrectly uses context api
return -1;
}
}