Skip to content

Commit

Permalink
fix: wrong parameter value when using To/When on generic functions
Browse files Browse the repository at this point in the history
Change-Id: I8cf1138d5d7964a6eeee8a8964eb8f7461cbce60
  • Loading branch information
Sychorius committed Oct 9, 2023
1 parent 7076ba5 commit 83bfd61
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 67 deletions.
2 changes: 1 addition & 1 deletion internal/monkey/fn/copy_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func Copy(targetPtr, oriFn interface{}) {
targetType := targetVal.Type().Elem()
tool.Assert(targetType.Kind() == reflect.Func, "'%v' is not a function pointer", targetPtr)
oriVal := reflect.ValueOf(oriFn)
tool.Assert(tool.CheckFuncArgs(targetType, oriVal.Type(), 0), "target and ori not match")
tool.Assert(tool.CheckFuncArgs(targetType, oriVal.Type(), 0, 0), "target and ori not match")

oriAddr := oriVal.Pointer()
tool.DebugPrintf("Copy: copy start for %v\n", runtime.FuncForPC(oriAddr).Name())
Expand Down
4 changes: 2 additions & 2 deletions internal/monkey/patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func (p *Patch) Unpatch() {
func PatchValue(target, hook, proxy reflect.Value, unsafe, generic bool) *Patch {
tool.Assert(hook.Kind() == reflect.Func, "'%s' is not a function", hook.Kind())
tool.Assert(proxy.Kind() == reflect.Ptr, "'%v' is not a function pointer", proxy.Kind())
tool.Assert(hook.Type() == target.Type(), "'%v' and '%s' mismatch", hook.Type(), target.Type())
tool.Assert(proxy.Elem().Type() == target.Type(), "'*%v' and '%s' mismatch", proxy.Elem().Type(), target.Type())
// tool.Assert(hook.Type() == target.Type(), "'%v' and '%s' mismatch", hook.Type(), target.Type())
// tool.Assert(proxy.Elem().Type() == target.Type(), "'*%v' and '%s' mismatch", proxy.Elem().Type(), target.Type())

targetAddr := target.Pointer()
if generic {
Expand Down
8 changes: 1 addition & 7 deletions internal/tool/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,8 @@ import (
"reflect"
)

func ReflectCallWithShiftOne(f reflect.Value, args []reflect.Value, shift bool) []reflect.Value {
if shift {
return ReflectCall(f, args[1:])
}
return ReflectCall(f, args)
}

func ReflectCall(f reflect.Value, args []reflect.Value) []reflect.Value {
// fmt.Println(f.Type())
if f.Type().IsVariadic() {
newArgs := make([]reflect.Value, 0)
lastArg := args[len(args)-1]
Expand Down
8 changes: 4 additions & 4 deletions internal/tool/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func CheckReturnType(fn interface{}, results ...interface{}) {
}
}

func CheckFuncArgs(a, b reflect.Type, shift int) bool {
if a.NumIn() == b.NumIn()+shift {
for i := shift; i < a.NumIn(); i++ {
if a.In(i) != b.In(i-shift) {
func CheckFuncArgs(a, b reflect.Type, shiftA, shiftB int) bool {
if a.NumIn()-shiftA == b.NumIn()-shiftB {
for indexA, indexB := shiftA, shiftB; indexA < a.NumIn(); indexA, indexB = indexA+1, indexB+1 {
if a.In(indexA) != b.In(indexB) {
return false
}
}
Expand Down
72 changes: 41 additions & 31 deletions mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ func MockUnsafe(target interface{}) *MockBuilder {
return Mock(target, OptUnsafe)
}

func (builder *MockBuilder) hookType() reflect.Type {
targetType := reflect.TypeOf(builder.target)
if builder.generic {
targetIn := []reflect.Type{genericInfoType}
for i := 0; i < targetType.NumIn(); i++ {
targetIn = append(targetIn, targetType.In(i))
}
targetOut := []reflect.Type{}
for i := 0; i < targetType.NumOut(); i++ {
targetOut = append(targetOut, targetType.Out(i))
}
return reflect.FuncOf(targetIn, targetOut, targetType.IsVariadic())
}
return targetType
}

func (builder *MockBuilder) resetCondition() *MockBuilder {
builder.conditions = []*mockCondition{builder.newCondition()} // at least 1 condition is needed
return builder
Expand Down Expand Up @@ -187,15 +203,15 @@ func (builder *MockBuilder) Build() *Mocker {
return &mocker
}

func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool {
func (mocker *Mocker) missReceiver(target reflect.Type, hook interface{}) bool {
hType := reflect.TypeOf(hook)
tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind())
tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook)
// has receiver
if tool.CheckFuncArgs(target, hType, 0) {
if tool.CheckFuncArgs(target, hType, 0, 0) {
return false
}
if tool.CheckFuncArgs(target, hType, 1) {
if tool.CheckFuncArgs(target, hType, 1, 0) {
return true
}
tool.Assert(false, "target:%v, hook:%v args not match", target, hook)
Expand All @@ -205,40 +221,36 @@ func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool
func (mocker *Mocker) buildHook() {
proxySetter := mocker.buildProxy()

origin := reflect.ValueOf(mocker.proxy).Elem()
originExec := func(args []reflect.Value) []reflect.Value {
return tool.ReflectCall(origin, args)
return tool.ReflectCall(reflect.ValueOf(mocker.proxy).Elem(), args)
}

match := []func(args []reflect.Value) bool{}
exec := []func(args []reflect.Value) []reflect.Value{}

for _, condition := range mocker.builder.conditions {
when := condition.when
hook := condition.hook

if when == nil {
for i := range mocker.builder.conditions {
condition := mocker.builder.conditions[i]
if condition.when == nil {
// when condition is not set, just go into hook exec
match = append(match, func(args []reflect.Value) bool { return true })
} else {
missWhenReceiver := mocker.checkReceiver(mocker.target.Type(), when)
match = append(match, func(args []reflect.Value) bool {
return tool.ReflectCallWithShiftOne(reflect.ValueOf(when), args, missWhenReceiver)[0].Bool()
return tool.ReflectCall(reflect.ValueOf(condition.when), args)[0].Bool()
})
}

if hook == nil {
if condition.hook == nil {
// hook condition is not set, just go into original exec
exec = append(exec, originExec)
} else {
missHookReceiver := mocker.checkReceiver(mocker.target.Type(), hook)
exec = append(exec, func(args []reflect.Value) []reflect.Value {
mocker.mock()
return tool.ReflectCallWithShiftOne(reflect.ValueOf(hook), args, missHookReceiver)
return tool.ReflectCall(reflect.ValueOf(condition.hook), args)
})
}
}

mockerHook := reflect.MakeFunc(mocker.target.Type(), func(args []reflect.Value) []reflect.Value {
mockerHook := reflect.MakeFunc(mocker.builder.hookType(), func(args []reflect.Value) []reflect.Value {
proxySetter(args) // 设置origin调用proxy

mocker.access()
Expand Down Expand Up @@ -267,29 +279,27 @@ func (mocker *Mocker) buildHook() {
mocker.hook = mockerHook
}

// buildProx create a porxyCaller which could call origin directly
func (mocker *Mocker) buildProxy() func(args []reflect.Value) {
proxy := reflect.New(mocker.target.Type())
proxy := reflect.New(mocker.builder.hookType())

proxyCallerSetter := func(args []reflect.Value) {}
missProxyReceiver := false
if mocker.builder.proxyCaller != nil {
pVal := reflect.ValueOf(mocker.builder.proxyCaller)
tool.Assert(pVal.Kind() == reflect.Ptr && pVal.Elem().Kind() == reflect.Func, "origin receiver must be a function pointer")
pElem := pVal.Elem()
missProxyReceiver = mocker.checkReceiver(mocker.target.Type(), pElem.Interface())

if missProxyReceiver {
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), append(args[0:1], innerArgs...))
}))
}
} else {
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), innerArgs)
}))
}
shift := 0
if mocker.builder.generic {
shift += 1
}
if mocker.missReceiver(mocker.target.Type(), pElem.Interface()) {
shift += 1
}
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), append(args[0:shift], innerArgs...))
}))
}
}
mocker.proxy = proxy.Interface()
Expand Down
85 changes: 68 additions & 17 deletions mock_condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,19 @@ func (m *mockCondition) SetWhenForce(when interface{}) {
tool.Assert(wVal.Type().NumOut() == 1, "when func ret value not bool")
out1 := wVal.Type().Out(0)
tool.Assert(out1.Kind() == reflect.Bool, "when func ret value not bool")
checkReceiver(reflect.TypeOf(m.builder.target), when) // inputs must be in same or has an extra self receiver
m.when = when

hookType := m.builder.hookType()
inTypes := []reflect.Type{}
for i := 0; i < hookType.NumIn(); i++ {
inTypes = append(inTypes, hookType.In(i))
}

hasGeneric, hasReceiver := m.checkGenericAndReceiver(wVal.Type())
whenType := reflect.FuncOf(inTypes, []reflect.Type{out1}, hookType.IsVariadic())
m.when = reflect.MakeFunc(whenType, func(args []reflect.Value) (results []reflect.Value) {
results = tool.ReflectCall(wVal, m.adaptArgsForReflectCall(args, hasGeneric, hasReceiver))
return
}).Interface()
}

func (m *mockCondition) SetReturn(results ...interface{}) {
Expand All @@ -61,15 +72,15 @@ func (m *mockCondition) SetReturnForce(results ...interface{}) {
}
}

targetType := reflect.TypeOf(m.builder.target)
m.hook = reflect.MakeFunc(targetType, func(args []reflect.Value) []reflect.Value {
hookType := m.builder.hookType()
m.hook = reflect.MakeFunc(hookType, func(_ []reflect.Value) []reflect.Value {
results := getResult()
tool.CheckReturnType(m.builder.target, results...)
valueResults := make([]reflect.Value, 0)
for i, result := range results {
rValue := reflect.Zero(targetType.Out(i))
rValue := reflect.Zero(hookType.Out(i))
if result != nil {
rValue = reflect.ValueOf(result).Convert(targetType.Out(i))
rValue = reflect.ValueOf(result).Convert(hookType.Out(i))
}
valueResults = append(valueResults, rValue)
}
Expand All @@ -85,20 +96,60 @@ func (m *mockCondition) SetTo(to interface{}) {
func (m *mockCondition) SetToForce(to interface{}) {
hType := reflect.TypeOf(to)
tool.Assert(hType.Kind() == reflect.Func, "to a is not a func")
m.hook = to
hasGeneric, hasReciever := m.checkGenericAndReceiver(hType)
tool.Assert(m.builder.generic || !hasGeneric, "non-generic function should not have 'GenericInfo' as first argument")
m.hook = reflect.MakeFunc(m.builder.hookType(), func(args []reflect.Value) (results []reflect.Value) {
results = tool.ReflectCall(reflect.ValueOf(to), m.adaptArgsForReflectCall(args, hasGeneric, hasReciever))
return
}).Interface()
}

func checkReceiver(target reflect.Type, hook interface{}) bool {
hType := reflect.TypeOf(hook)
tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind())
tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook)
// checkGenericAndReceiver check if typ has GenericsInfo and selfReceiver as argument
//
// The hook function will looks like func(_ GenericInfo, self *struct, arg0 int ...)
// When we use 'When' or 'To', our input hook function will looks like:
// 1. func(arg0 int ...)
// 2. func(info GenericInfo, arg0 int ...)
// 3. func(self *struct, arg0 int ...)
// 4. func(info GenericInfo, self *struct, arg0 int ...)
//
// All above input hooks are legal, but we need to make an adaptation when calling then
func (m *mockCondition) checkGenericAndReceiver(typ reflect.Type) (bool, bool) {
targetType := reflect.TypeOf(m.builder.target)
tool.Assert(typ.Kind() == reflect.Func, "Param(%v) a is not a func", typ.Kind())
tool.Assert(targetType.IsVariadic() == typ.IsVariadic(), "target:%v, hook:%v args not match", targetType, typ)

shiftTyp := 0
if typ.NumIn() > 0 && typ.In(0) == genericInfoType {
shiftTyp = 1
}

// has receiver
if tool.CheckFuncArgs(target, hType, 0) {
return false
if tool.CheckFuncArgs(targetType, typ, 0, shiftTyp) {
return shiftTyp == 1, true
}

if tool.CheckFuncArgs(targetType, typ, 1, shiftTyp) {
return shiftTyp == 1, false
}
tool.Assert(false, "target:%v, hook:%v args not match", targetType, typ)
return false, false
}

// adaptArgsForReflectCall makes an adaption for reflect call
//
// see (*mockCondition).checkGenericAndReceiver for more info
func (m *mockCondition) adaptArgsForReflectCall(args []reflect.Value, hasGeneric, hasReceiver bool) []reflect.Value {
adaption := []reflect.Value{}
if m.builder.generic {
if hasGeneric {
adaption = append(adaption, args[0])
}
args = args[1:]
}
if tool.CheckFuncArgs(target, hType, 1) {
return true
if !hasReceiver {
args = args[1:]
}
tool.Assert(false, "target:%v, hook:%v args not match", target, hook)
return false
adaption = append(adaption, args...)
return adaption
}
15 changes: 15 additions & 0 deletions mock_generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@

package mockey

import (
"reflect"
"unsafe"
)

func MockGeneric(target interface{}) *MockBuilder {
return Mock(target, OptGeneric)
}

type GenericInfo uintptr

var genericInfoType = reflect.TypeOf(GenericInfo(0))

func (g GenericInfo) TypeOfUsedArgN(n uintptr) reflect.Type {
var vt interface{}
*(*uintptr)(unsafe.Pointer(&vt)) = *(*uintptr)(unsafe.Pointer(uintptr(g) + 8*n))
return reflect.TypeOf(vt)
}
Loading

0 comments on commit 83bfd61

Please sign in to comment.