From 47be19db8d65d5bc18f8092db7610fa34353c8cd Mon Sep 17 00:00:00 2001 From: SYC_ Date: Sat, 30 Sep 2023 23:44:48 +0800 Subject: [PATCH] fix: wrong parameter value when using To/When on generic functions Change-Id: I8cf1138d5d7964a6eeee8a8964eb8f7461cbce60 --- internal/monkey/fn/copy_darwin.go | 2 +- internal/monkey/patch.go | 2 - internal/tool/call.go | 7 -- internal/tool/check.go | 8 +-- mock.go | 105 ++++++++++++++++++------------ mock_condition.go | 85 +++++++++++++++++++----- mock_generics.go | 33 ++++++++++ mock_generics_test.go | 86 ++++++++++++++++++++++-- 8 files changed, 252 insertions(+), 76 deletions(-) diff --git a/internal/monkey/fn/copy_darwin.go b/internal/monkey/fn/copy_darwin.go index dca8d0b..2157164 100644 --- a/internal/monkey/fn/copy_darwin.go +++ b/internal/monkey/fn/copy_darwin.go @@ -31,7 +31,7 @@ func Copy(targetPtr, oriFn interface{}) { targetType := targetVal.Type().Elem() tool.Assert(targetType.Kind() == reflect.Func, "'%v' is not a function pointer", targetPtr) oriVal := reflect.ValueOf(oriFn) - tool.Assert(tool.CheckFuncArgs(targetType, oriVal.Type(), 0), "target and ori not match") + tool.Assert(tool.CheckFuncArgs(targetType, oriVal.Type(), 0, 0), "target and ori not match") oriAddr := oriVal.Pointer() tool.DebugPrintf("Copy: copy start for %v\n", runtime.FuncForPC(oriAddr).Name()) diff --git a/internal/monkey/patch.go b/internal/monkey/patch.go index c3d156f..8f4ee24 100644 --- a/internal/monkey/patch.go +++ b/internal/monkey/patch.go @@ -44,8 +44,6 @@ func (p *Patch) Unpatch() { func PatchValue(target, hook, proxy reflect.Value, unsafe, generic bool) *Patch { tool.Assert(hook.Kind() == reflect.Func, "'%s' is not a function", hook.Kind()) tool.Assert(proxy.Kind() == reflect.Ptr, "'%v' is not a function pointer", proxy.Kind()) - tool.Assert(hook.Type() == target.Type(), "'%v' and '%s' mismatch", hook.Type(), target.Type()) - tool.Assert(proxy.Elem().Type() == target.Type(), "'*%v' and '%s' mismatch", proxy.Elem().Type(), target.Type()) targetAddr := target.Pointer() if generic { diff --git a/internal/tool/call.go b/internal/tool/call.go index 79263de..40dd333 100644 --- a/internal/tool/call.go +++ b/internal/tool/call.go @@ -20,13 +20,6 @@ import ( "reflect" ) -func ReflectCallWithShiftOne(f reflect.Value, args []reflect.Value, shift bool) []reflect.Value { - if shift { - return ReflectCall(f, args[1:]) - } - return ReflectCall(f, args) -} - func ReflectCall(f reflect.Value, args []reflect.Value) []reflect.Value { if f.Type().IsVariadic() { newArgs := make([]reflect.Value, 0) diff --git a/internal/tool/check.go b/internal/tool/check.go index a2ee1fd..42e5871 100644 --- a/internal/tool/check.go +++ b/internal/tool/check.go @@ -32,10 +32,10 @@ func CheckReturnType(fn interface{}, results ...interface{}) { } } -func CheckFuncArgs(a, b reflect.Type, shift int) bool { - if a.NumIn() == b.NumIn()+shift { - for i := shift; i < a.NumIn(); i++ { - if a.In(i) != b.In(i-shift) { +func CheckFuncArgs(a, b reflect.Type, shiftA, shiftB int) bool { + if a.NumIn()-shiftA == b.NumIn()-shiftB { + for indexA, indexB := shiftA, shiftB; indexA < a.NumIn(); indexA, indexB = indexA+1, indexB+1 { + if a.In(indexA) != b.In(indexB) { return false } } diff --git a/mock.go b/mock.go index ab9e3a3..1391d0d 100644 --- a/mock.go +++ b/mock.go @@ -34,9 +34,9 @@ const ( ) type Mocker struct { - target reflect.Value // 目标函数 - hook reflect.Value // mock函数 - proxy interface{} // mock之后,原函数地址 + target reflect.Value // mock target value + hook reflect.Value // mock hook + proxy interface{} // proxy function to origin times int64 mockTimes int64 patch *monkey.Patch @@ -44,21 +44,25 @@ type Mocker struct { isPatched bool builder *MockBuilder - outerCaller tool.CallerInfo // Mocker 的外部调用位置 + outerCaller tool.CallerInfo } type MockBuilder struct { - target interface{} // 目标函数 - // hook interface{} // mock函数 - proxyCaller interface{} // mock之后,原函数地址 - // when interface{} // 条件函数 - conditions []*mockCondition // 条件转移 + target interface{} // mock target + proxyCaller interface{} // origin function caller hook + conditions []*mockCondition // mock conditions filterGoroutine FilterGoroutineType gId int64 unsafe bool generic bool } +// Mock mocks target function +// +// If target is a generic method or method of generic types, you need add a genericOpt, like this: +// +// func f[int, float64](x int, y T1) T2 +// Mock(f[int, float64], OptGeneric) func Mock(target interface{}, opt ...optionFn) *MockBuilder { tool.AssertFunc(target) @@ -79,11 +83,38 @@ func MockUnsafe(target interface{}) *MockBuilder { return Mock(target, OptUnsafe) } +func (builder *MockBuilder) hookType() reflect.Type { + targetType := reflect.TypeOf(builder.target) + if builder.generic { + targetIn := []reflect.Type{genericInfoType} + for i := 0; i < targetType.NumIn(); i++ { + targetIn = append(targetIn, targetType.In(i)) + } + targetOut := []reflect.Type{} + for i := 0; i < targetType.NumOut(); i++ { + targetOut = append(targetOut, targetType.Out(i)) + } + return reflect.FuncOf(targetIn, targetOut, targetType.IsVariadic()) + } + return targetType +} + func (builder *MockBuilder) resetCondition() *MockBuilder { builder.conditions = []*mockCondition{builder.newCondition()} // at least 1 condition is needed return builder } +// Origin add an origin hook which can be used to call un-mocked origin function +// +// For example: +// +// origin := Fun // only need the same type +// mock := func(p string) string { +// return origin(p + "mocked") +// } +// mock2 := Mock(Fun).To(mock).Origin(&origin).Build() +// +// Origin only works when call origin hook directly, target will still be mocked in recursive call func (builder *MockBuilder) Origin(funcPtr interface{}) *MockBuilder { tool.Assert(builder.proxyCaller == nil, "re-set builder origin") return builder.origin(funcPtr) @@ -187,15 +218,15 @@ func (builder *MockBuilder) Build() *Mocker { return &mocker } -func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool { +func (mocker *Mocker) missReceiver(target reflect.Type, hook interface{}) bool { hType := reflect.TypeOf(hook) tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind()) tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook) // has receiver - if tool.CheckFuncArgs(target, hType, 0) { + if tool.CheckFuncArgs(target, hType, 0, 0) { return false } - if tool.CheckFuncArgs(target, hType, 1) { + if tool.CheckFuncArgs(target, hType, 1, 0) { return true } tool.Assert(false, "target:%v, hook:%v args not match", target, hook) @@ -205,40 +236,36 @@ func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool func (mocker *Mocker) buildHook() { proxySetter := mocker.buildProxy() - origin := reflect.ValueOf(mocker.proxy).Elem() originExec := func(args []reflect.Value) []reflect.Value { - return tool.ReflectCall(origin, args) + return tool.ReflectCall(reflect.ValueOf(mocker.proxy).Elem(), args) } match := []func(args []reflect.Value) bool{} exec := []func(args []reflect.Value) []reflect.Value{} - for _, condition := range mocker.builder.conditions { - when := condition.when - hook := condition.hook - - if when == nil { + for i := range mocker.builder.conditions { + condition := mocker.builder.conditions[i] + if condition.when == nil { // when condition is not set, just go into hook exec match = append(match, func(args []reflect.Value) bool { return true }) } else { - missWhenReceiver := mocker.checkReceiver(mocker.target.Type(), when) match = append(match, func(args []reflect.Value) bool { - return tool.ReflectCallWithShiftOne(reflect.ValueOf(when), args, missWhenReceiver)[0].Bool() + return tool.ReflectCall(reflect.ValueOf(condition.when), args)[0].Bool() }) } - if hook == nil { + if condition.hook == nil { + // hook condition is not set, just go into original exec exec = append(exec, originExec) } else { - missHookReceiver := mocker.checkReceiver(mocker.target.Type(), hook) exec = append(exec, func(args []reflect.Value) []reflect.Value { mocker.mock() - return tool.ReflectCallWithShiftOne(reflect.ValueOf(hook), args, missHookReceiver) + return tool.ReflectCall(reflect.ValueOf(condition.hook), args) }) } } - mockerHook := reflect.MakeFunc(mocker.target.Type(), func(args []reflect.Value) []reflect.Value { + mockerHook := reflect.MakeFunc(mocker.builder.hookType(), func(args []reflect.Value) []reflect.Value { proxySetter(args) // 设置origin调用proxy mocker.access() @@ -267,29 +294,27 @@ func (mocker *Mocker) buildHook() { mocker.hook = mockerHook } +// buildProx create a proxyCaller which could call origin directly func (mocker *Mocker) buildProxy() func(args []reflect.Value) { - proxy := reflect.New(mocker.target.Type()) + proxy := reflect.New(mocker.builder.hookType()) proxyCallerSetter := func(args []reflect.Value) {} - missProxyReceiver := false if mocker.builder.proxyCaller != nil { pVal := reflect.ValueOf(mocker.builder.proxyCaller) tool.Assert(pVal.Kind() == reflect.Ptr && pVal.Elem().Kind() == reflect.Func, "origin receiver must be a function pointer") pElem := pVal.Elem() - missProxyReceiver = mocker.checkReceiver(mocker.target.Type(), pElem.Interface()) - if missProxyReceiver { - proxyCallerSetter = func(args []reflect.Value) { - pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) { - return tool.ReflectCall(proxy.Elem(), append(args[0:1], innerArgs...)) - })) - } - } else { - proxyCallerSetter = func(args []reflect.Value) { - pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) { - return tool.ReflectCall(proxy.Elem(), innerArgs) - })) - } + shift := 0 + if mocker.builder.generic { + shift += 1 + } + if mocker.missReceiver(mocker.target.Type(), pElem.Interface()) { + shift += 1 + } + proxyCallerSetter = func(args []reflect.Value) { + pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) { + return tool.ReflectCall(proxy.Elem(), append(args[0:shift], innerArgs...)) + })) } } mocker.proxy = proxy.Interface() diff --git a/mock_condition.go b/mock_condition.go index 4719e3a..4105790 100644 --- a/mock_condition.go +++ b/mock_condition.go @@ -43,8 +43,19 @@ func (m *mockCondition) SetWhenForce(when interface{}) { tool.Assert(wVal.Type().NumOut() == 1, "when func ret value not bool") out1 := wVal.Type().Out(0) tool.Assert(out1.Kind() == reflect.Bool, "when func ret value not bool") - checkReceiver(reflect.TypeOf(m.builder.target), when) // inputs must be in same or has an extra self receiver - m.when = when + + hookType := m.builder.hookType() + inTypes := []reflect.Type{} + for i := 0; i < hookType.NumIn(); i++ { + inTypes = append(inTypes, hookType.In(i)) + } + + hasGeneric, hasReceiver := m.checkGenericAndReceiver(wVal.Type()) + whenType := reflect.FuncOf(inTypes, []reflect.Type{out1}, hookType.IsVariadic()) + m.when = reflect.MakeFunc(whenType, func(args []reflect.Value) (results []reflect.Value) { + results = tool.ReflectCall(wVal, m.adaptArgsForReflectCall(args, hasGeneric, hasReceiver)) + return + }).Interface() } func (m *mockCondition) SetReturn(results ...interface{}) { @@ -61,15 +72,15 @@ func (m *mockCondition) SetReturnForce(results ...interface{}) { } } - targetType := reflect.TypeOf(m.builder.target) - m.hook = reflect.MakeFunc(targetType, func(args []reflect.Value) []reflect.Value { + hookType := m.builder.hookType() + m.hook = reflect.MakeFunc(hookType, func(_ []reflect.Value) []reflect.Value { results := getResult() tool.CheckReturnType(m.builder.target, results...) valueResults := make([]reflect.Value, 0) for i, result := range results { - rValue := reflect.Zero(targetType.Out(i)) + rValue := reflect.Zero(hookType.Out(i)) if result != nil { - rValue = reflect.ValueOf(result).Convert(targetType.Out(i)) + rValue = reflect.ValueOf(result).Convert(hookType.Out(i)) } valueResults = append(valueResults, rValue) } @@ -85,20 +96,60 @@ func (m *mockCondition) SetTo(to interface{}) { func (m *mockCondition) SetToForce(to interface{}) { hType := reflect.TypeOf(to) tool.Assert(hType.Kind() == reflect.Func, "to a is not a func") - m.hook = to + hasGeneric, hasReceiver := m.checkGenericAndReceiver(hType) + tool.Assert(m.builder.generic || !hasGeneric, "non-generic function should not have 'GenericInfo' as first argument") + m.hook = reflect.MakeFunc(m.builder.hookType(), func(args []reflect.Value) (results []reflect.Value) { + results = tool.ReflectCall(reflect.ValueOf(to), m.adaptArgsForReflectCall(args, hasGeneric, hasReceiver)) + return + }).Interface() } -func checkReceiver(target reflect.Type, hook interface{}) bool { - hType := reflect.TypeOf(hook) - tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind()) - tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook) +// checkGenericAndReceiver check if typ has GenericsInfo and selfReceiver as argument +// +// The hook function will looks like func(_ GenericInfo, self *struct, arg0 int ...) +// When we use 'When' or 'To', our input hook function will looks like: +// 1. func(arg0 int ...) +// 2. func(info GenericInfo, arg0 int ...) +// 3. func(self *struct, arg0 int ...) +// 4. func(info GenericInfo, self *struct, arg0 int ...) +// +// All above input hooks are legal, but we need to make an adaptation when calling then +func (m *mockCondition) checkGenericAndReceiver(typ reflect.Type) (bool, bool) { + targetType := reflect.TypeOf(m.builder.target) + tool.Assert(typ.Kind() == reflect.Func, "Param(%v) a is not a func", typ.Kind()) + tool.Assert(targetType.IsVariadic() == typ.IsVariadic(), "target:%v, hook:%v args not match", targetType, typ) + + shiftTyp := 0 + if typ.NumIn() > 0 && typ.In(0) == genericInfoType { + shiftTyp = 1 + } + // has receiver - if tool.CheckFuncArgs(target, hType, 0) { - return false + if tool.CheckFuncArgs(targetType, typ, 0, shiftTyp) { + return shiftTyp == 1, true + } + + if tool.CheckFuncArgs(targetType, typ, 1, shiftTyp) { + return shiftTyp == 1, false + } + tool.Assert(false, "target:%v, hook:%v args not match", targetType, typ) + return false, false +} + +// adaptArgsForReflectCall makes an adaption for reflect call +// +// see (*mockCondition).checkGenericAndReceiver for more info +func (m *mockCondition) adaptArgsForReflectCall(args []reflect.Value, hasGeneric, hasReceiver bool) []reflect.Value { + adaption := []reflect.Value{} + if m.builder.generic { + if hasGeneric { + adaption = append(adaption, args[0]) + } + args = args[1:] } - if tool.CheckFuncArgs(target, hType, 1) { - return true + if !hasReceiver { + args = args[1:] } - tool.Assert(false, "target:%v, hook:%v args not match", target, hook) - return false + adaption = append(adaption, args...) + return adaption } diff --git a/mock_generics.go b/mock_generics.go index 571ddee..e0c9a21 100644 --- a/mock_generics.go +++ b/mock_generics.go @@ -16,6 +16,39 @@ package mockey +import ( + "reflect" + "unsafe" +) + +// MockGeneric mocks generic function +// +// Target must be generic method or method of generic types func MockGeneric(target interface{}) *MockBuilder { return Mock(target, OptGeneric) } + +type GenericInfo uintptr + +var genericInfoType = reflect.TypeOf(GenericInfo(0)) + +func (g GenericInfo) Equal(other GenericInfo) bool { + return g == other +} + +// UsedParamType get the type of used parameter in generic function/struct +// +// For example: assume we have generic function "f[int, float64](x int, y T1) T2" and derived type f[int, float64]: +// +// UsedParamType(0) == reflect.TypeOf(int(0)) +// UsedParamType(1) == reflect.TypeOf(float64(0)) +// +// If index n is out of range, or the derived types have more complex structure(for example: define an generic struct +// in a generic function using generic types, unused parameterized type etc.), this function may return unexpected value +// or cause unrecoverable runtime error . So it is NOT RECOMMENDED to use this function unless you actually knows what +// you are doing. +func (g GenericInfo) UsedParamType(n uintptr) reflect.Type { + var vt interface{} + *(*uintptr)(unsafe.Pointer(&vt)) = *(*uintptr)(unsafe.Pointer(uintptr(g) + 8*n)) + return reflect.TypeOf(vt) +} diff --git a/mock_generics_test.go b/mock_generics_test.go index bb919f3..e945d8f 100644 --- a/mock_generics_test.go +++ b/mock_generics_test.go @@ -20,6 +20,8 @@ package mockey import ( + "fmt" + "reflect" "testing" "github.com/smartystreets/goconvey/convey" @@ -37,29 +39,37 @@ func (g generic[T]) Value() T { return g.a } -func (g generic[T]) Value2() T { +func (g generic[T]) Value2(hint string) T { return g.a + g.a } func TestGeneric(t *testing.T) { PatchConvey("generic", t, func() { PatchConvey("func", func() { + arg1, arg2 := 1, 2 MockGeneric(sum[int]).To(func(a, b int) int { + convey.So(a, convey.ShouldEqual, arg1) + convey.So(b, convey.ShouldEqual, arg2) return 999 }).Build() - MockGeneric(sum[float64]).Return(888).Build() convey.So(sum[int](1, 2), convey.ShouldEqual, 999) + + MockGeneric(sum[float64]).Return(888).Build() convey.So(sum[float64](1, 2), convey.ShouldEqual, 888) }) PatchConvey("type", func() { Mock((generic[int]).Value, OptGeneric).Return(999).Build() - Mock(GetMethod(generic[string]{}, "Value2"), OptGeneric).To(func() string { + gi := generic[int]{a: 123} + convey.So(gi.Value(), convey.ShouldEqual, 999) + + arg1 := "hint" + Mock(GetMethod(generic[string]{}, "Value2"), OptGeneric).To(func(hint string) string { + convey.So(hint, convey.ShouldEqual, arg1) return "mock" }).Build() - gi := generic[int]{a: 123} gs := generic[string]{a: "abc"} convey.So(gi.Value(), convey.ShouldEqual, 999) - convey.So(gs.Value2(), convey.ShouldEqual, "mock") + convey.So(gs.Value2(arg1), convey.ShouldEqual, "mock") }) }) } @@ -104,6 +114,72 @@ func TestGenericArgRet(t *testing.T) { }) } +func TestGenericArgValues(t *testing.T) { + PatchConvey("args-value", t, func() { + PatchConvey("single", func() { + var arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15 uintptr = 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + MockGeneric(GenericsArg15[uintptr]).To(func(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15 uintptr) { + convey.So(_1, convey.ShouldEqual, arg1) + convey.So(_2, convey.ShouldEqual, arg2) + convey.So(_3, convey.ShouldEqual, arg3) + convey.So(_4, convey.ShouldEqual, arg4) + convey.So(_5, convey.ShouldEqual, arg5) + convey.So(_6, convey.ShouldEqual, arg6) + convey.So(_7, convey.ShouldEqual, arg7) + convey.So(_8, convey.ShouldEqual, arg8) + convey.So(_9, convey.ShouldEqual, arg9) + convey.So(_10, convey.ShouldEqual, arg10) + convey.So(_11, convey.ShouldEqual, arg11) + convey.So(_12, convey.ShouldEqual, arg12) + convey.So(_13, convey.ShouldEqual, arg13) + convey.So(_14, convey.ShouldEqual, arg14) + convey.So(_15, convey.ShouldEqual, arg15) + }).Build() + GenericsArg15(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15) + }) + PatchConvey("complex", func() { + var arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15 string = "1", "2", "3", "4", "5", " 6", "7", "8", "9", "10", "11", "12", "13", "14", "15" + MockGeneric(GenericsArg15[string]).To(func(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15 string) { + convey.So(_1, convey.ShouldEqual, arg1) + convey.So(_2, convey.ShouldEqual, arg2) + convey.So(_3, convey.ShouldEqual, arg3) + convey.So(_4, convey.ShouldEqual, arg4) + convey.So(_5, convey.ShouldEqual, arg5) + convey.So(_6, convey.ShouldEqual, arg6) + convey.So(_7, convey.ShouldEqual, arg7) + convey.So(_8, convey.ShouldEqual, arg8) + convey.So(_9, convey.ShouldEqual, arg9) + convey.So(_10, convey.ShouldEqual, arg10) + convey.So(_11, convey.ShouldEqual, arg11) + convey.So(_12, convey.ShouldEqual, arg12) + convey.So(_13, convey.ShouldEqual, arg13) + convey.So(_14, convey.ShouldEqual, arg14) + convey.So(_15, convey.ShouldEqual, arg15) + }).Build() + GenericsArg15(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15) + }) + PatchConvey("args-type", func() { + target := GenericsTemplate[int, float32, string, chan (int), []byte, struct{ _ int }] + MockGeneric(target).To( + func(info GenericInfo, t1 int, t2 float32, t3 string) (r1 chan (int), r2 []byte, r3 struct{ _ int }) { + convey.So(info.UsedParamType(0), convey.ShouldEqual, reflect.TypeOf(t1)) + convey.So(info.UsedParamType(1), convey.ShouldEqual, reflect.TypeOf(t2)) + convey.So(info.UsedParamType(2), convey.ShouldEqual, reflect.TypeOf(t3)) + convey.So(info.UsedParamType(3), convey.ShouldEqual, reflect.TypeOf(r1)) + convey.So(info.UsedParamType(4), convey.ShouldEqual, reflect.TypeOf(r2)) + convey.So(info.UsedParamType(5), convey.ShouldEqual, reflect.TypeOf(r3)) + return + }).Build() + target(1, 2, "3") + }) + }) +} + +func GenericsTemplate[T1, T2, T3, R1, R2, R3 any](t1 T1, t2 T2, t3 T3) (r1 R1, r2 R2, r3 R3) { + fmt.Println(t1, t2, t3, r1, r2, r3) + panic("not here") +} + func GenericsArg0[T any]() { panic("0") } func GenericsArg1[T any](_ T) { panic("1") } func GenericsArg2[T any](_, _ T) { panic("2") }