Muzzle field matching

This commit is contained in:
Andrew Kent 2018-07-31 19:27:31 -07:00
parent c4daf007e3
commit f8f45d7f14
7 changed files with 353 additions and 71 deletions

View File

@ -229,29 +229,75 @@ public class MuzzleVisitor implements AsmVisitorWrapper {
false);
}
for (Reference.Field field : references[i].getFields()) {
mv.visitLdcInsn(field.getName());
mv.visitLdcInsn(field.getFlags().size());
mv.visitTypeInsn(
Opcodes.ANEWARRAY, "datadog/trace/agent/tooling/muzzle/Reference$Flag");
{ // sources
mv.visitLdcInsn(field.getSources().size());
mv.visitTypeInsn(
Opcodes.ANEWARRAY, "datadog/trace/agent/tooling/muzzle/Reference$Source");
int j = 0;
for (Reference.Flag flag : field.getFlags()) {
mv.visitInsn(Opcodes.DUP);
mv.visitLdcInsn(j);
mv.visitFieldInsn(
Opcodes.GETSTATIC,
"datadog/trace/agent/tooling/muzzle/Reference$Flag",
flag.name(),
"Ldatadog/trace/agent/tooling/muzzle/Reference$Flag;");
mv.visitInsn(Opcodes.AASTORE);
++j;
int j = 0;
for (Reference.Source source : field.getSources()) {
mv.visitInsn(Opcodes.DUP);
mv.visitLdcInsn(j);
mv.visitTypeInsn(
Opcodes.NEW, "datadog/trace/agent/tooling/muzzle/Reference$Source");
mv.visitInsn(Opcodes.DUP);
mv.visitLdcInsn(source.getName());
mv.visitLdcInsn(source.getLine());
mv.visitMethodInsn(
Opcodes.INVOKESPECIAL,
"datadog/trace/agent/tooling/muzzle/Reference$Source",
"<init>",
"(Ljava/lang/String;I)V",
false);
mv.visitInsn(Opcodes.AASTORE);
++j;
}
}
{ // flags
mv.visitLdcInsn(field.getFlags().size());
mv.visitTypeInsn(
Opcodes.ANEWARRAY, "datadog/trace/agent/tooling/muzzle/Reference$Flag");
int j = 0;
for (Reference.Flag flag : field.getFlags()) {
mv.visitInsn(Opcodes.DUP);
mv.visitLdcInsn(j);
mv.visitFieldInsn(
Opcodes.GETSTATIC,
"datadog/trace/agent/tooling/muzzle/Reference$Flag",
flag.name(),
"Ldatadog/trace/agent/tooling/muzzle/Reference$Flag;");
mv.visitInsn(Opcodes.AASTORE);
++j;
}
}
mv.visitLdcInsn(field.getName());
{ // field type
mv.visitLdcInsn(field.getType().getDescriptor());
mv.visitMethodInsn(
Opcodes.INVOKESTATIC,
Type.getInternalName(Type.class),
"getType",
Type.getMethodDescriptor(Type.class.getMethod("getType", String.class)),
false);
}
mv.visitMethodInsn(
Opcodes.INVOKEVIRTUAL,
"datadog/trace/agent/tooling/muzzle/Reference$Builder",
"withField",
"(Ljava/lang/String;[Ldatadog/trace/agent/tooling/muzzle/Reference$Flag;)Ldatadog/trace/agent/tooling/muzzle/Reference$Builder;",
Type.getMethodDescriptor(
Reference.Builder.class.getMethod(
"withField",
Reference.Source[].class,
Reference.Flag[].class,
String.class,
Type.class)),
false);
}
for (Reference.Method method : references[i].getMethods()) {

View File

@ -67,9 +67,7 @@ public class Reference {
}
/**
* Create a new reference which combines this reference with another reference.
*
* <p>Attempts to merge incompatible references will throw an IllegalStateException.
* Create a new reference which combines this reference with another reference of the same type.
*
* @param anotherReference A reference to the same class
* @return a new Reference which merges the two references
@ -212,6 +210,24 @@ public class Reference {
}
}
public static class MissingField extends Mismatch {
private final String className;
private final String fieldName;
private final String fieldDesc;
public MissingField(Source[] sources, String className, String fieldName, String fieldDesc) {
super(sources);
this.className = className;
this.fieldName = fieldName;
this.fieldDesc = fieldDesc;
}
@Override
String getMismatchDetails() {
return "Missing field " + className + "#" + fieldName + fieldDesc;
}
}
public static class MissingMethod extends Mismatch {
final String className;
final String method;
@ -418,8 +434,7 @@ public class Reference {
public Method merge(Method anotherMethod) {
if (!this.equals(anotherMethod)) {
throw new IllegalStateException(
"Cannot merge incompatible methods " + this + " <> " + anotherMethod);
throw new IllegalStateException("illegal merge " + this + " != " + anotherMethod);
}
final Set<Source> mergedSources = new HashSet<>();
@ -459,25 +474,45 @@ public class Reference {
private final Set<Source> sources;
private final Set<Flag> flags;
private final String name;
private final Type type;
public Field(Set<Source> sources, Set<Flag> flags, String name) {
this.sources = sources;
this.flags = flags;
public Field(Source[] sources, Flag[] flags, String name, Type fieldType) {
this.sources = new HashSet<>(Arrays.asList(sources));
this.flags = new HashSet<>(Arrays.asList(flags));
this.name = name;
this.type = fieldType;
}
public String getName() {
return name;
}
public Set<Source> getSources() {
return sources;
}
public Set<Flag> getFlags() {
return flags;
}
public Type getType() {
return type;
}
public Field merge(Field anotherField) {
// TODO: implement
// also assert same class
return this;
if (!this.equals(anotherField) || (!type.equals(anotherField.type))) {
throw new IllegalStateException("illegal merge " + this + " != " + anotherField);
}
return new Field(
Reference.merge(sources, anotherField.sources).toArray(new Source[0]),
mergeFlags(flags, anotherField.flags).toArray(new Flag[0]),
name,
type);
}
@Override
public String toString() {
return "FieldRef:" + name + type.getInternalName();
}
@Override
@ -528,8 +563,15 @@ public class Reference {
return this;
}
public Builder withField(String fieldName, Flag... fieldFlags) {
// TODO
public Builder withField(
Source[] sources, Flag[] fieldFlags, String fieldName, Type fieldType) {
final Field field = new Field(sources, fieldFlags, fieldName, fieldType);
int existingIndex = fields.indexOf(field);
if (existingIndex == -1) {
fields.add(field);
} else {
fields.set(existingIndex, field.merge(fields.get(existingIndex)));
}
return this;
}

View File

@ -106,6 +106,24 @@ public class ReferenceCreator extends ClassVisitor {
}
}
/**
* Compute the minimum required access for FROM class to access a field on the TO class.
*
* @return A reference flag with the required level of access.
*/
private static Reference.Flag computeMinimumFieldAccess(Type from, Type to) {
if (from.getInternalName().equalsIgnoreCase(to.getInternalName())) {
return Reference.Flag.PRIVATE_OR_HIGHER;
} else if (internalPackageName(from.getInternalName())
.equals(internalPackageName(to.getInternalName()))) {
return Reference.Flag.PACKAGE_OR_HIGHER;
} else {
// Additional references: check the type hierarchy of FROM to distinguish public from
// protected
return Reference.Flag.PROTECTED_OR_HIGHER;
}
}
/**
* Compute the minimum required access for FROM class to access METHODTYPE on the TO class.
*
@ -163,7 +181,10 @@ public class ReferenceCreator extends ClassVisitor {
public FieldVisitor visitField(
int access, String name, String descriptor, String signature, Object value) {
// Additional references we could check
// - type of field + visible from this package
// - annotations on field
// intentionally not creating refs to fields here.
// Will create refs in method instructions to include line numbers.
return super.visitField(access, name, descriptor, signature, value);
}
@ -197,25 +218,46 @@ public class ReferenceCreator extends ClassVisitor {
public void visitFieldInsn(int opcode, String owner, String name, String descriptor) {
// Additional references we could check
// * DONE owner class
// * owner class has a field (name)
// * field is static or non-static
// * field's visibility from this point (NON_PRIVATE?)
// * owner class's visibility from this point (NON_PRIVATE?)
// * DONE owner class has a field (name)
// * DONE field is static or non-static
// * DONE field's visibility from this point (NON_PRIVATE?)
// * DONE owner class's visibility from this point (NON_PRIVATE?)
//
// * DONE field-source class (descriptor)
// * field-source visibility from this point (PRIVATE?)
// * DONE field-source visibility from this point (PRIVATE?)
final Type ownerType = Type.getType("L" + owner + ";");
final Type fieldType = Type.getType(descriptor);
List<Reference.Flag> fieldFlags = new ArrayList<>();
fieldFlags.add(computeMinimumFieldAccess(refSourceType, ownerType));
fieldFlags.add(
opcode == Opcodes.GETSTATIC || opcode == Opcodes.PUTSTATIC
? Reference.Flag.STATIC
: Reference.Flag.NON_STATIC);
// owning class has a field
addReference(
new Reference.Builder(owner).withSource(refSourceClassName, currentLineNumber).build());
Type fieldType = Type.getType(descriptor);
if (fieldType.getSort() == Type.ARRAY) {
fieldType = fieldType.getElementType();
new Reference.Builder(ownerType.getInternalName())
.withSource(refSourceClassName, currentLineNumber)
.withFlag(computeMinimumClassAccess(refSourceType, ownerType))
.withField(
new Reference.Source[] {
new Reference.Source(refSourceClassName, currentLineNumber)
},
fieldFlags.toArray(new Reference.Flag[0]),
name,
fieldType)
.build());
Type underlyingFieldType = fieldType;
while (underlyingFieldType.getSort() == Type.ARRAY) {
underlyingFieldType = underlyingFieldType.getElementType();
}
if (fieldType.getSort() == Type.OBJECT) {
if (underlyingFieldType.getSort() == Type.OBJECT) {
addReference(
new Reference.Builder(fieldType.getInternalName())
new Reference.Builder(underlyingFieldType.getInternalName())
.withSource(refSourceClassName, currentLineNumber)
.withFlag(computeMinimumClassAccess(refSourceType, underlyingFieldType))
.build());
}
super.visitFieldInsn(opcode, owner, name, descriptor);
@ -241,7 +283,7 @@ public class ReferenceCreator extends ClassVisitor {
{ // ref for method return type
Type returnType = methodType.getReturnType();
if (returnType.getSort() == Type.ARRAY) {
while (returnType.getSort() == Type.ARRAY) {
returnType = returnType.getElementType();
}
if (returnType.getSort() == Type.OBJECT) {
@ -254,7 +296,7 @@ public class ReferenceCreator extends ClassVisitor {
}
// refs for method param types
for (Type paramType : methodType.getArgumentTypes()) {
if (paramType.getSort() == Type.ARRAY) {
while (paramType.getSort() == Type.ARRAY) {
paramType = paramType.getElementType();
}
if (paramType.getSort() == Type.OBJECT) {
@ -267,9 +309,6 @@ public class ReferenceCreator extends ClassVisitor {
}
Type ownerType = Type.getType("L" + owner + ";");
if (ownerType.getSort() == Type.ARRAY) {
ownerType = ownerType.getElementType();
}
final List<Reference.Flag> methodFlags = new ArrayList<>();
methodFlags.add(

View File

@ -19,6 +19,7 @@ import java.util.concurrent.ConcurrentHashMap;
import lombok.extern.slf4j.Slf4j;
import net.bytebuddy.jar.asm.ClassReader;
import net.bytebuddy.jar.asm.ClassVisitor;
import net.bytebuddy.jar.asm.FieldVisitor;
import net.bytebuddy.jar.asm.MethodVisitor;
import net.bytebuddy.jar.asm.Opcodes;
import net.bytebuddy.jar.asm.Type;
@ -130,6 +131,7 @@ public class ReferenceMatcher {
private final List<UnloadedType> unloadedInterfaces = new ArrayList<>();
private int flags;
private final List<Method> methods = new ArrayList<>();
private final List<Field> fields = new ArrayList<>();
public static UnloadedType of(String className, ClassLoader classLoader) throws Exception {
className = Utils.getInternalName(className);
@ -189,23 +191,48 @@ public class ReferenceMatcher {
return mismatches;
}
public List<Reference.Mismatch> checkMatch(Reference.Method method) {
public List<Reference.Mismatch> checkMatch(Reference.Field fieldRef) {
final List<Reference.Mismatch> mismatches = new ArrayList<>(0);
// does the method exist?
Method unloadedMethod = findMethod(method, true);
final Field unloadedField = findField(fieldRef, true);
if (unloadedField == null) {
mismatches.add(
new Reference.Mismatch.MissingField(
fieldRef.getSources().toArray(new Reference.Source[0]),
className,
fieldRef.getName(),
fieldRef.getType().getInternalName()));
} else {
for (Reference.Flag flag : fieldRef.getFlags()) {
if (!flag.matches(unloadedField.getFlags())) {
final String desc = this.getClassName() + "#" + unloadedField.signature;
mismatches.add(
new Mismatch.MissingFlag(
fieldRef.getSources().toArray(new Source[0]),
desc,
flag,
unloadedField.getFlags()));
}
}
}
return mismatches;
}
public List<Reference.Mismatch> checkMatch(Reference.Method methodRef) {
final List<Reference.Mismatch> mismatches = new ArrayList<>(0);
final Method unloadedMethod = findMethod(methodRef, true);
if (unloadedMethod == null) {
mismatches.add(
new Reference.Mismatch.MissingMethod(
method.getSources().toArray(new Reference.Source[0]),
methodRef.getSources().toArray(new Reference.Source[0]),
className,
method.toString()));
methodRef.toString()));
} else {
for (Reference.Flag flag : method.getFlags()) {
for (Reference.Flag flag : methodRef.getFlags()) {
if (!flag.matches(unloadedMethod.getFlags())) {
final String desc = this.getClassName() + "#" + unloadedMethod.signature;
mismatches.add(
new Mismatch.MissingFlag(
method.getSources().toArray(new Source[0]),
methodRef.getSources().toArray(new Source[0]),
desc,
flag,
unloadedMethod.getFlags()));
@ -243,10 +270,32 @@ public class ReferenceMatcher {
return null;
}
public boolean hasField(Reference.Field field) {
// TODO does the field exist?
// TODO are the expected field flags present (static, public, etc)
throw new RuntimeException("TODO");
private Field findField(Reference.Field fieldRef, boolean includePrivateFields) {
final Field key = new Field(0, fieldRef.getName(), fieldRef.getType().getDescriptor());
final int index = fields.indexOf(key);
if (index != -1) {
final Field foundField = fields.get(index);
if (foundField.is(Opcodes.ACC_PRIVATE)) {
return includePrivateFields ? foundField : null;
} else {
return foundField;
}
} else {
Field superField = null;
if (unloadedSuper != null) {
superField = unloadedSuper.findField(fieldRef, false);
if (superField != null) {
return superField;
}
}
for (UnloadedType unloadedInterface : unloadedInterfaces) {
superField = unloadedInterface.findField(fieldRef, false);
if (superField != null) {
return superField;
}
}
}
return null;
}
@Override
@ -264,6 +313,13 @@ public class ReferenceMatcher {
super.visit(version, access, name, signature, superName, interfaces);
}
@Override
public FieldVisitor visitField(
int access, String name, String descriptor, String signature, Object value) {
fields.add(new Field(access, name, descriptor));
return super.visitField(access, name, descriptor, signature, value);
}
@Override
public MethodVisitor visitMethod(
final int access,
@ -271,8 +327,6 @@ public class ReferenceMatcher {
final String descriptor,
final String signature,
final String[] exceptions) {
// Additional references we could check
// - Classes in signature (return type, params) and visible from this package
methods.add(new Method(access, name, descriptor));
return super.visitMethod(access, name, descriptor, signature, exceptions);
}
@ -304,7 +358,45 @@ public class ReferenceMatcher {
@Override
public boolean equals(Object o) {
if (o instanceof Method) {
return signature.toString().equals(((Method) o).signature);
return signature.equals(((Method) o).signature);
}
return false;
}
@Override
public int hashCode() {
return signature.hashCode();
}
}
private static class Field {
private final int flags;
// name + typeDesc
private final String signature;
public Field(int flags, String name, String typeDesc) {
this.flags = flags;
this.signature = name + typeDesc;
}
private int getFlags() {
return flags;
}
public boolean is(int flag) {
boolean result = (flags & flag) != 0;
return result;
}
@Override
public String toString() {
return new StringBuilder("Unloaded: ").append(signature).toString();
}
@Override
public boolean equals(Object o) {
if (o instanceof Field) {
return signature.equals(((Field) o).signature);
}
return false;
}

View File

@ -27,6 +27,7 @@ class ReferenceCreatorTest extends AgentTestRunner {
references.get('java.lang.Object').getFlags().contains(Reference.Flag.PUBLIC)
references.get('muzzle.TestClasses$MethodBodyAdvice$B').getFlags().contains(Reference.Flag.PACKAGE_OR_HIGHER)
// method refs
Set<Reference.Method> bMethods = references.get('muzzle.TestClasses$MethodBodyAdvice$B').getMethods()
findMethod(bMethods, "aMethod", "(Ljava/lang/String;)Ljava/lang/String;") != null
findMethod(bMethods, "aMethodWithPrimitives", "(Z)V") != null
@ -35,6 +36,15 @@ class ReferenceCreatorTest extends AgentTestRunner {
findMethod(bMethods, "aMethod", "(Ljava/lang/String;)Ljava/lang/String;").getFlags().contains(Reference.Flag.NON_STATIC)
findMethod(bMethods, "aStaticMethod", "()V").getFlags().contains(Reference.Flag.STATIC)
// field refs
references.get('muzzle.TestClasses$MethodBodyAdvice$B').getFields().isEmpty()
Set<Reference.Field> aFieldRefs = references.get('muzzle.TestClasses$MethodBodyAdvice$A').getFields()
findField(aFieldRefs, "b").getFlags().contains(Reference.Flag.PACKAGE_OR_HIGHER)
findField(aFieldRefs, "b").getFlags().contains(Reference.Flag.NON_STATIC)
findField(aFieldRefs, "staticB").getFlags().contains(Reference.Flag.PACKAGE_OR_HIGHER)
findField(aFieldRefs, "staticB").getFlags().contains(Reference.Flag.STATIC)
aFieldRefs.size() == 2
}
def "protected ref test"() {
@ -47,7 +57,7 @@ class ReferenceCreatorTest extends AgentTestRunner {
findMethod(bMethods, "protectedMethod", "()V").getFlags().contains(Reference.Flag.PROTECTED_OR_HIGHER)
}
private static findMethod(Set<Reference.Method> methods, String methodName, String methodDesc) {
private static Reference.Method findMethod(Set<Reference.Method> methods, String methodName, String methodDesc) {
for (Reference.Method method : methods) {
if (method == new Reference.Method(methodName, methodDesc)) {
return method
@ -55,4 +65,13 @@ class ReferenceCreatorTest extends AgentTestRunner {
}
return null
}
private static Reference.Field findField(Set<Reference.Field> fields, String fieldName) {
for (Reference.Field field : fields) {
if (field.getName().equals(fieldName)) {
return field
}
}
return null
}
}

View File

@ -4,7 +4,6 @@ import datadog.trace.agent.test.AgentTestRunner
import datadog.trace.agent.test.TestUtils
import datadog.trace.agent.tooling.Utils
import datadog.trace.agent.tooling.muzzle.Reference
import datadog.trace.agent.tooling.muzzle.Reference.Method
import datadog.trace.agent.tooling.muzzle.Reference.Source
import datadog.trace.agent.tooling.muzzle.Reference.Flag
import datadog.trace.agent.tooling.muzzle.ReferenceCreator
@ -69,48 +68,87 @@ class ReferenceMatcherTest extends AgentTestRunner {
ReferenceMatcher.UnloadedType unloadedB = ReferenceMatcher.UnloadedType.of(MethodBodyAdvice.B.getName(), MethodBodyAdvice.B.getClassLoader())
ReferenceMatcher.UnloadedType unloadedB2 = ReferenceMatcher.UnloadedType.of(MethodBodyAdvice.B2.getName(), MethodBodyAdvice.B2.getClassLoader())
ReferenceMatcher.UnloadedType unloadedInterface = ReferenceMatcher.UnloadedType.of(MethodBodyAdvice.AnotherInterface.getName(), MethodBodyAdvice.AnotherInterface.getClassLoader())
Method methodRef
Reference.Method methodRef
// match method declared in the class
when:
methodRef = new Method("aMethod", "(Ljava/lang/String;)Ljava/lang/String;")
methodRef = new Reference.Method("aMethod", "(Ljava/lang/String;)Ljava/lang/String;")
then:
unloadedB.checkMatch(methodRef).size() == 0
// match method declared in the supertype
when:
methodRef = new Method("hashCode", "()I")
methodRef = new Reference.Method("hashCode", "()I")
then:
unloadedB.checkMatch(methodRef).size() == 0
// match method declared in interface
when:
methodRef = new Method("someMethod", "()V")
methodRef = new Reference.Method("someMethod", "()V")
then:
unloadedInterface.checkMatch(methodRef).size() == 0
// match private method in the class
when:
methodRef = new Method("privateStuff", "()V")
methodRef = new Reference.Method("privateStuff", "()V")
then:
unloadedB.checkMatch(methodRef).size() == 0
// fail to match private method in superclass
when:
methodRef = new Method("privateStuff", "()V")
methodRef = new Reference.Method("privateStuff", "()V")
then:
unloadedB2.checkMatch(methodRef).size() == 1
// static method flag mismatch
when:
methodRef = new Method(new Source[0], [Flag.NON_STATIC] as Flag[], "aStaticMethod", Type.getType("V"))
methodRef = new Reference.Method(new Source[0], [Flag.NON_STATIC] as Flag[], "aStaticMethod", Type.getType("V"))
then:
unloadedB2.checkMatch(methodRef).size() == 1
// missing method mismatch
when:
methodRef = new Method(new Source[0], new Flag[0], "missingTestMethod", Type.VOID_TYPE, new Type[0])
methodRef = new Reference.Method(new Source[0], new Flag[0], "missingTestMethod", Type.VOID_TYPE, new Type[0])
then:
unloadedB.checkMatch(methodRef).size() == 1
}
def "match fields" () {
ReferenceMatcher.UnloadedType unloadedA = ReferenceMatcher.UnloadedType.of(MethodBodyAdvice.A.getName(), MethodBodyAdvice.A.getClassLoader())
ReferenceMatcher.UnloadedType unloadedA2 = ReferenceMatcher.UnloadedType.of(MethodBodyAdvice.A2.getName(), MethodBodyAdvice.A2.getClassLoader())
Reference.Field fieldRef
when:
fieldRef = new Reference.Field(new Source[0], new Flag[0], "missingField", Type.getType("Ljava/lang/String;"))
then:
unloadedA.checkMatch(fieldRef).size() == 1
when:
// wrong field type sig should create a mismatch
fieldRef = new Reference.Field(new Source[0], new Flag[0], "privateField", Type.getType("Ljava/lang/String;"))
then:
unloadedA.checkMatch(fieldRef).size() == 1
when:
fieldRef = new Reference.Field(new Source[0], new Flag[0], "privateField", Type.getType("Ljava/lang/Object;"))
then:
unloadedA.checkMatch(fieldRef).size() == 0
unloadedA2.checkMatch(fieldRef).size() == 1
when:
fieldRef = new Reference.Field(new Source[0], [Flag.NON_STATIC, Flag.PROTECTED_OR_HIGHER] as Flag[], "protectedField", Type.getType("Ljava/lang/Object;"))
then:
unloadedA.checkMatch(fieldRef).size() == 0
unloadedA2.checkMatch(fieldRef).size() == 0
when:
fieldRef = new Reference.Field(new Source[0], [Flag.STATIC] as Flag[], "protectedField", Type.getType("Ljava/lang/Object;"))
then:
unloadedA.checkMatch(fieldRef).size() == 1
when:
fieldRef = new Reference.Field(new Source[0], [Flag.PROTECTED_OR_HIGHER, Flag.STATIC] as Flag[], "staticB", Type.getType(MethodBodyAdvice.B))
then:
unloadedA.checkMatch(fieldRef).size() == 0
}
}

View File

@ -14,10 +14,14 @@ public class TestClasses {
a.b.aMethodWithPrimitives(false);
a.b.aMethodWithArrays(new String[0]);
B.aStaticMethod();
A.staticB.aMethod("bar");
}
public static class A {
public B b = new B();
protected Object protectedField = null;
private Object privateField = null;
public static B staticB = new B();
}
public static class B {
@ -45,6 +49,8 @@ public class TestClasses {
}
}
public static class A2 extends A {}
public interface SomeInterface {
void someMethod();
}