karmada/vendor/github.com/agiledragon/gomonkey/v2/patch.go

341 lines
9.3 KiB
Go

package gomonkey
import (
"fmt"
"github.com/agiledragon/gomonkey/v2/creflect"
"reflect"
"syscall"
"unsafe"
)
type Patches struct {
originals map[uintptr][]byte
values map[reflect.Value]reflect.Value
valueHolders map[reflect.Value]reflect.Value
}
type Params []interface{}
type OutputCell struct {
Values Params
Times int
}
func ApplyFunc(target, double interface{}) *Patches {
return create().ApplyFunc(target, double)
}
func ApplyMethod(target reflect.Type, methodName string, double interface{}) *Patches {
return create().ApplyMethod(target, methodName, double)
}
func ApplyMethodFunc(target reflect.Type, methodName string, doubleFunc interface{}) *Patches {
return create().ApplyMethodFunc(target, methodName, doubleFunc)
}
func ApplyPrivateMethod(target reflect.Type, methodName string, double interface{}) *Patches {
return create().ApplyPrivateMethod(target, methodName, double)
}
func ApplyGlobalVar(target, double interface{}) *Patches {
return create().ApplyGlobalVar(target, double)
}
func ApplyFuncVar(target, double interface{}) *Patches {
return create().ApplyFuncVar(target, double)
}
func ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches {
return create().ApplyFuncSeq(target, outputs)
}
func ApplyMethodSeq(target reflect.Type, methodName string, outputs []OutputCell) *Patches {
return create().ApplyMethodSeq(target, methodName, outputs)
}
func ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
return create().ApplyFuncVarSeq(target, outputs)
}
func ApplyFuncReturn(target interface{}, output ...interface{}) *Patches {
return create().ApplyFuncReturn(target, output...)
}
func ApplyMethodReturn(target interface{}, methodName string, output ...interface{}) *Patches {
return create().ApplyMethodReturn(target, methodName, output...)
}
func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches {
return create().ApplyFuncVarReturn(target, output...)
}
func create() *Patches {
return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)}
}
func NewPatches() *Patches {
return create()
}
func (this *Patches) ApplyFunc(target, double interface{}) *Patches {
t := reflect.ValueOf(target)
d := reflect.ValueOf(double)
return this.ApplyCore(t, d)
}
func (this *Patches) ApplyMethod(target reflect.Type, methodName string, double interface{}) *Patches {
m, ok := target.MethodByName(methodName)
if !ok {
panic("retrieve method by name failed")
}
d := reflect.ValueOf(double)
return this.ApplyCore(m.Func, d)
}
func (this *Patches) ApplyMethodFunc(target reflect.Type, methodName string, doubleFunc interface{}) *Patches {
m, ok := target.MethodByName(methodName)
if !ok {
panic("retrieve method by name failed")
}
d := funcToMethod(m.Type, doubleFunc)
return this.ApplyCore(m.Func, d)
}
func (this *Patches) ApplyPrivateMethod(target reflect.Type, methodName string, double interface{}) *Patches {
m, ok := creflect.MethodByName(target, methodName)
if !ok {
panic("retrieve method by name failed")
}
d := reflect.ValueOf(double)
return this.ApplyCoreOnlyForPrivateMethod(m, d)
}
func (this *Patches) ApplyGlobalVar(target, double interface{}) *Patches {
t := reflect.ValueOf(target)
if t.Type().Kind() != reflect.Ptr {
panic("target is not a pointer")
}
this.values[t] = reflect.ValueOf(t.Elem().Interface())
d := reflect.ValueOf(double)
t.Elem().Set(d)
return this
}
func (this *Patches) ApplyFuncVar(target, double interface{}) *Patches {
t := reflect.ValueOf(target)
d := reflect.ValueOf(double)
if t.Type().Kind() != reflect.Ptr {
panic("target is not a pointer")
}
this.check(t.Elem(), d)
return this.ApplyGlobalVar(target, double)
}
func (this *Patches) ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches {
funcType := reflect.TypeOf(target)
t := reflect.ValueOf(target)
d := getDoubleFunc(funcType, outputs)
return this.ApplyCore(t, d)
}
func (this *Patches) ApplyMethodSeq(target reflect.Type, methodName string, outputs []OutputCell) *Patches {
m, ok := target.MethodByName(methodName)
if !ok {
panic("retrieve method by name failed")
}
d := getDoubleFunc(m.Type, outputs)
return this.ApplyCore(m.Func, d)
}
func (this *Patches) ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
t := reflect.ValueOf(target)
if t.Type().Kind() != reflect.Ptr {
panic("target is not a pointer")
}
if t.Elem().Kind() != reflect.Func {
panic("target is not a func")
}
funcType := reflect.TypeOf(target).Elem()
double := getDoubleFunc(funcType, outputs).Interface()
return this.ApplyGlobalVar(target, double)
}
func (this *Patches) ApplyFuncReturn(target interface{}, returns ...interface{}) *Patches {
funcType := reflect.TypeOf(target)
t := reflect.ValueOf(target)
outputs := []OutputCell{{Values: returns, Times: -1}}
d := getDoubleFunc(funcType, outputs)
return this.ApplyCore(t, d)
}
func (this *Patches) ApplyMethodReturn(target interface{}, methodName string, returns ...interface{}) *Patches {
m, ok := reflect.TypeOf(target).MethodByName(methodName)
if !ok {
panic("retrieve method by name failed")
}
outputs := []OutputCell{{Values: returns, Times: -1}}
d := getDoubleFunc(m.Type, outputs)
return this.ApplyCore(m.Func, d)
}
func (this *Patches) ApplyFuncVarReturn(target interface{}, returns ...interface{}) *Patches {
t := reflect.ValueOf(target)
if t.Type().Kind() != reflect.Ptr {
panic("target is not a pointer")
}
if t.Elem().Kind() != reflect.Func {
panic("target is not a func")
}
funcType := reflect.TypeOf(target).Elem()
outputs := []OutputCell{{Values: returns, Times: -1}}
double := getDoubleFunc(funcType, outputs).Interface()
return this.ApplyGlobalVar(target, double)
}
func (this *Patches) Reset() {
for target, bytes := range this.originals {
modifyBinary(target, bytes)
delete(this.originals, target)
}
for target, variable := range this.values {
target.Elem().Set(variable)
}
}
func (this *Patches) ApplyCore(target, double reflect.Value) *Patches {
this.check(target, double)
assTarget := *(*uintptr)(getPointer(target))
if _, ok := this.originals[assTarget]; ok {
panic("patch has been existed")
}
this.valueHolders[double] = double
original := replace(assTarget, uintptr(getPointer(double)))
this.originals[assTarget] = original
return this
}
func (this *Patches) ApplyCoreOnlyForPrivateMethod(target unsafe.Pointer, double reflect.Value) *Patches {
if double.Kind() != reflect.Func {
panic("double is not a func")
}
assTarget := *(*uintptr)(target)
if _, ok := this.originals[assTarget]; ok {
panic("patch has been existed")
}
this.valueHolders[double] = double
original := replace(assTarget, uintptr(getPointer(double)))
this.originals[assTarget] = original
return this
}
func (this *Patches) check(target, double reflect.Value) {
if target.Kind() != reflect.Func {
panic("target is not a func")
}
if double.Kind() != reflect.Func {
panic("double is not a func")
}
if target.Type() != double.Type() {
panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type()))
}
}
func replace(target, double uintptr) []byte {
code := buildJmpDirective(double)
bytes := entryAddress(target, len(code))
original := make([]byte, len(bytes))
copy(original, bytes)
modifyBinary(target, code)
return original
}
func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value {
if funcType.NumOut() != len(outputs[0].Values) {
panic(fmt.Sprintf("func type has %v return values, but only %v values provided as double",
funcType.NumOut(), len(outputs[0].Values)))
}
needReturn := false
slice := make([]Params, 0)
for _, output := range outputs {
if output.Times == -1 {
needReturn = true
slice = []Params{output.Values}
break
}
t := 0
if output.Times <= 1 {
t = 1
} else {
t = output.Times
}
for j := 0; j < t; j++ {
slice = append(slice, output.Values)
}
}
i := 0
lenOutputs := len(slice)
return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value {
if needReturn {
return GetResultValues(funcType, slice[0]...)
}
if i < lenOutputs {
i++
return GetResultValues(funcType, slice[i-1]...)
}
panic("double seq is less than call seq")
})
}
func GetResultValues(funcType reflect.Type, results ...interface{}) []reflect.Value {
var resultValues []reflect.Value
for i, r := range results {
var resultValue reflect.Value
if r == nil {
resultValue = reflect.Zero(funcType.Out(i))
} else {
v := reflect.New(funcType.Out(i))
v.Elem().Set(reflect.ValueOf(r))
resultValue = v.Elem()
}
resultValues = append(resultValues, resultValue)
}
return resultValues
}
type funcValue struct {
_ uintptr
p unsafe.Pointer
}
func getPointer(v reflect.Value) unsafe.Pointer {
return (*funcValue)(unsafe.Pointer(&v)).p
}
func entryAddress(p uintptr, l int) []byte {
return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: p, Len: l, Cap: l}))
}
func pageStart(ptr uintptr) uintptr {
return ptr & ^(uintptr(syscall.Getpagesize() - 1))
}
func funcToMethod(funcType reflect.Type, doubleFunc interface{}) reflect.Value {
rf := reflect.TypeOf(doubleFunc)
if rf.Kind() != reflect.Func {
panic("doubleFunc is not a func")
}
vf := reflect.ValueOf(doubleFunc)
return reflect.MakeFunc(funcType, func(in []reflect.Value) []reflect.Value {
return vf.Call(in[1:])
})
}