From d2004fe2e720d7d128fe223b3ea6814e089b057a Mon Sep 17 00:00:00 2001 From: SYC_ Date: Mon, 23 Oct 2023 12:00:29 +0000 Subject: [PATCH] feat: support private struct-method mock Change-Id: I08200a4403b7f25aaf1df984b7b622e43c2ef812 --- exp/utils_above_1_18.go | 47 ------ exp/utils_above_1_18_test.go | 69 --------- internal/unsafereflect/type.go | 136 +++++++----------- .../unsafereflect/type_above_1_17_test.go | 63 -------- internal/unsafereflect/type_test.go | 33 ++--- utils.go | 43 ++++++ utils_test.go | 55 +++++++ 7 files changed, 161 insertions(+), 285 deletions(-) delete mode 100644 exp/utils_above_1_18.go delete mode 100644 exp/utils_above_1_18_test.go delete mode 100644 internal/unsafereflect/type_above_1_17_test.go diff --git a/exp/utils_above_1_18.go b/exp/utils_above_1_18.go deleted file mode 100644 index 011b4e3..0000000 --- a/exp/utils_above_1_18.go +++ /dev/null @@ -1,47 +0,0 @@ -//go:build go1.18 -// +build go1.18 - -/* - * Copyright 2023 ByteDance Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package exp - -import ( - "unsafe" - - "github.com/bytedance/mockey/internal/tool" - "github.com/bytedance/mockey/internal/unsafereflect" -) - -// GetPrivateMemberMethod resolve a method from an instance, include private method. -// -// F must fit the shape of specific method, include receiver as the first argument. -// Especially, the receiver can be replaced as interface when F is declaring, -// this will be very useful when receiver type is not exported for other packages. -// -// for example: -// -// GetPrivateMemberMethod[func(*bytes.Buffer) bool](&bytes.Buffer{}, "empty") -// GetPrivateMemberMethod[func(hash.Hash) [sha256.Size]byte](sha256.New(), "checkSum") -func GetPrivateMemberMethod[F interface{}](instance interface{}, methodName string) interface{} { - tfn, ok := unsafereflect.MethodByName(instance, methodName) - if !ok { - tool.Assert(false, "can't reflect instance method :%v", methodName) - return nil - } - // return with (unsafe) function type cast - return *(*F)(unsafe.Pointer(&tfn)) -} diff --git a/exp/utils_above_1_18_test.go b/exp/utils_above_1_18_test.go deleted file mode 100644 index c9d24b6..0000000 --- a/exp/utils_above_1_18_test.go +++ /dev/null @@ -1,69 +0,0 @@ -//go:build go1.18 -// +build go1.18 - -/* - * Copyright 2023 ByteDance Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package exp - -import ( - "bytes" - "testing" - - "github.com/bytedance/mockey" - "github.com/bytedance/mockey/internal/tool" - "github.com/smartystreets/goconvey/convey" -) - -func TestGetPrivateMemberMethod(t *testing.T) { - mockey.PatchConvey("FakeMethod", t, func() { - convey.So(func() { - GetPrivateMemberMethod[func()](&bytes.Buffer{}, "FakeMethod") - }, convey.ShouldPanicWith, "can't reflect instance method :FakeMethod") - }) - - mockey.PatchConvey("OriginalGetMethod", t, func() { - convey.So(func() { - mockey.GetMethod(&bytes.Buffer{}, "empty") - }, convey.ShouldPanicWith, "can't reflect instance method :empty") - }) - - mockey.PatchConvey("ExportFunc", t, func() { - convey.So(func() { - exportedFunc := GetPrivateMemberMethod[func(*bytes.Buffer) int](&bytes.Buffer{}, "Len") - var mocked bool - mockey.Mock(exportedFunc).To(func(buffer *bytes.Buffer) int { - mocked = true - return 0 - }).Build() - _ = new(bytes.Buffer).Len() - tool.Assert(mocked, "function should be mocked") - }, convey.ShouldNotPanic) - }) - - mockey.PatchConvey("PrivateFunc", t, func() { - convey.So(func() { - privateFunc := GetPrivateMemberMethod[func(*bytes.Buffer) bool](&bytes.Buffer{}, "empty") - var mocked bool - mockey.Mock(privateFunc).To(func(buffer *bytes.Buffer) bool { - mocked = true - return true - }).Build() - _, _ = new(bytes.Buffer).ReadByte() - tool.Assert(mocked, "function should be mocked") - }, convey.ShouldNotPanic) - }) -} diff --git a/internal/unsafereflect/type.go b/internal/unsafereflect/type.go index af1209d..821c736 100644 --- a/internal/unsafereflect/type.go +++ b/internal/unsafereflect/type.go @@ -24,21 +24,22 @@ import ( "unsafe" ) -func MethodByName(target interface{}, name string) (fn unsafe.Pointer, ok bool) { - r := castRType(target) - rt := toRType(r) - if r.Kind() == reflect.Interface { - return funcPointer(r.MethodByName(name)) - } +func MethodByName(target interface{}, name string) (typ reflect.Type, fn unsafe.Pointer, ok bool) { + r := reflect.TypeOf(target) + rt := (*rtype)((*struct { + _ uintptr + data unsafe.Pointer + })(unsafe.Pointer(&r)).data) for _, p := range rt.methods() { if rt.nameOff(p.name).name() == name { - return rt.Method(p), true + return toType(rt.typeOff(p.mtyp)), rt.Method(p), true } } - return nil, false + return nil, nil, false } +// copy from src/reflect/type.go // rtype is the common implementation of most values. // It is embedded in other struct types. // @@ -51,36 +52,17 @@ type rtype struct { align uint8 // alignment of variable with this type fieldAlign uint8 // alignment of struct field with this type kind uint8 // enumeration for C - // function for comparing objects of this type - // (ptr to object A, ptr to object B) -> ==? + + // In go 1.13 equal was replaced with "alg *typeAlg". + // Since size(func) == size(ptr), the total size of rtype + // and alignment of other field keeps the same, we do not + // need to make an adaption for go1.13. equal func(unsafe.Pointer, unsafe.Pointer) bool gcdata *byte // garbage collection data str nameOff // string form ptrToThis typeOff // type for pointer to this type, may be zero } -func castRType(val interface{}) reflect.Type { - if rTypeVal, ok := val.(reflect.Type); ok { - return rTypeVal - } - return reflect.TypeOf(val) -} - -func toRType(t reflect.Type) *rtype { - i := *(*funcValue)(unsafe.Pointer(&t)) - r := (*rtype)(i.p) - return r -} - -type funcValue struct { - _ uintptr - p unsafe.Pointer -} - -func funcPointer(v reflect.Method, ok bool) (unsafe.Pointer, bool) { - return (*funcValue)(unsafe.Pointer(&v.Func)).p, ok -} - func (t *rtype) Method(p method) (fn unsafe.Pointer) { tfn := t.textOff(p.tfn) fn = unsafe.Pointer(&tfn) @@ -91,10 +73,12 @@ const kindMask = (1 << 5) - 1 func (t *rtype) Kind() reflect.Kind { return reflect.Kind(t.kind & kindMask) } -type tflag uint8 -type nameOff int32 // offset to a name -type typeOff int32 // offset to an *rtype -type textOff int32 // offset from top of text section +type ( + tflag uint8 + nameOff int32 // offset to a name + typeOff int32 // offset to an *rtype + textOff int32 // offset from top of text section +) // resolveNameOff resolves a name offset from a base pointer. // The (*rtype).nameOff method is a convenience wrapper for this function. @@ -107,6 +91,26 @@ func (t *rtype) nameOff(off nameOff) name { return name{(*byte)(resolveNameOff(unsafe.Pointer(t), int32(off)))} } +// resolveTypeOff resolves an *rtype offset from a base type. +// The (*rtype).typeOff method is a convenience wrapper for this function. +// +//go:linkname resolveTypeOff reflect.resolveTypeOff +func resolveTypeOff(rtype unsafe.Pointer, off int32) unsafe.Pointer + +func (t *rtype) typeOff(off typeOff) *rtype { + return (*rtype)(resolveTypeOff(unsafe.Pointer(t), int32(off))) +} + +// toType convert rtype to reflect.Type +// +// The conversion is not guaranteed to be successful. +// If conversion failed, response will be nil +func toType(r *rtype) reflect.Type { + var vt interface{} + *(*uintptr)(unsafe.Pointer(&vt)) = uintptr(unsafe.Pointer(r)) + return reflect.TypeOf(vt) +} + // resolveTextOff resolves a function pointer offset from a base type. // The (*rtype).textOff method is a convenience wrapper for this function. // Implemented in the runtime package. @@ -142,34 +146,6 @@ type funcType struct { outCount uint16 // top bit is set if last input parameter is ... } -func (t *funcType) in() []*rtype { - uadd := unsafe.Sizeof(*t) - if t.tflag&tflagUncommon != 0 { - uadd += unsafe.Sizeof(uncommonType{}) - } - if t.inCount == 0 { - return nil - } - return (*[1 << 20]*rtype)(add(unsafe.Pointer(t), uadd, "t.inCount > 0"))[:t.inCount:t.inCount] -} - -func (t *funcType) out() []*rtype { - uadd := unsafe.Sizeof(*t) - if t.tflag&tflagUncommon != 0 { - uadd += unsafe.Sizeof(uncommonType{}) - } - outCount := t.outCount & (1<<15 - 1) - if outCount == 0 { - return nil - } - return (*[1 << 20]*rtype)(add(unsafe.Pointer(t), uadd, "outCount > 0"))[t.inCount : t.inCount+outCount : t.inCount+outCount] -} - -func (t *rtype) IsVariadic() bool { - tt := (*funcType)(unsafe.Pointer(t)) - return tt.outCount&(1<<15) != 0 -} - func add(p unsafe.Pointer, x uintptr, whySafe string) unsafe.Pointer { return unsafe.Pointer(uintptr(p) + x) } @@ -182,8 +158,8 @@ type interfaceType struct { } type imethod struct { - name nameOff // name of method - typ typeOff // .(*FuncType) underneath + _ nameOff // unused name of method + _ typeOff // unused .(*FuncType) underneath } func (t *rtype) methods() []method { @@ -192,29 +168,25 @@ func (t *rtype) methods() []method { } switch t.Kind() { case reflect.Ptr: - type u struct { + return (*struct { ptrType u uncommonType - } - return (*u)(unsafe.Pointer(t)).u.methods() + })(unsafe.Pointer(t)).u.methods() case reflect.Func: - type u struct { + return (*struct { funcType u uncommonType - } - return (*u)(unsafe.Pointer(t)).u.methods() + })(unsafe.Pointer(t)).u.methods() case reflect.Interface: - type u struct { + return (*struct { interfaceType u uncommonType - } - return (*u)(unsafe.Pointer(t)).u.methods() + })(unsafe.Pointer(t)).u.methods() case reflect.Struct: - type u struct { + return (*struct { structType u uncommonType - } - return (*u)(unsafe.Pointer(t)).u.methods() + })(unsafe.Pointer(t)).u.methods() default: return nil } @@ -224,7 +196,7 @@ func (t *rtype) methods() []method { type method struct { name nameOff // name of method mtyp typeOff // method type (without receiver), not valid for private methods - ifn textOff // fn used in interface call (one-word receiver) + _ textOff // unused fn used in interface call (one-word receiver) tfn textOff // fn used for normal method call } @@ -237,9 +209,9 @@ func (t *uncommonType) methods() []method { // Struct field type structField struct { - name name // name is always non-empty - typ *rtype // type of field - offset uintptr // byte offset of field + _ name // unused name is always non-empty + _ *rtype // unused type of field + _ uintptr // unused byte offset of field } // structType diff --git a/internal/unsafereflect/type_above_1_17_test.go b/internal/unsafereflect/type_above_1_17_test.go deleted file mode 100644 index 7f5d459..0000000 --- a/internal/unsafereflect/type_above_1_17_test.go +++ /dev/null @@ -1,63 +0,0 @@ -//go:build go1.17 -// +build go1.17 - -/* - * Copyright 2023 ByteDance Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package unsafereflect_test - -import ( - "crypto/sha256" - "hash" - "reflect" - "testing" - "unsafe" - - "github.com/bytedance/mockey" - "github.com/bytedance/mockey/internal/tool" - "github.com/bytedance/mockey/internal/unsafereflect" -) - -func TestMethodByNameV17(t *testing.T) { - // private structure private method: *sha256.digest.checkSum - tfn, ok := unsafereflect.MethodByName(sha256.New(), "checkSum") - tool.Assert(ok, "private member of private structure is allowed") - // type of `func(*sha256.digest, []byte) [32]byte` - pFn := unsafe.Pointer(&tfn) - - mockey.PatchConvey("InterfaceFuncReturn", t, func() { - fn := *(*func(hash.Hash) [sha256.Size]byte)(pFn) - // Interface to fit the function shape is allowed here - mockey.Mock(fn).Return([sha256.Size]byte{1: 1}).Build() - rets := sha256.New().Sum(nil) - want := make([]byte, sha256.Size) - want[1] = 1 - tool.Assert(reflect.DeepEqual(want, rets), "the method should be mocked") - }) - - mockey.PatchConvey("InterfaceFuncTo", t, func() { - fn := *(*func(hash.Hash) [sha256.Size]byte)(pFn) - // Interface to fit the function shape is allowed here, - // since the receiver's type is interface, To API can be used here - mockey.Mock(fn).To(func(hash.Hash) [sha256.Size]byte { - return [sha256.Size]byte{1: 1} - }).Build() - rets := sha256.New().Sum(nil) - want := make([]byte, sha256.Size) - want[1] = 1 - tool.Assert(reflect.DeepEqual(want, rets), "the method should be mocked") - }) -} diff --git a/internal/unsafereflect/type_test.go b/internal/unsafereflect/type_test.go index 1cf8f12..a15b54a 100644 --- a/internal/unsafereflect/type_test.go +++ b/internal/unsafereflect/type_test.go @@ -14,38 +14,23 @@ * limitations under the License. */ -package unsafereflect_test +package unsafereflect import ( "crypto/sha256" "reflect" "testing" - "unsafe" - "github.com/bytedance/mockey" - "github.com/bytedance/mockey/internal/tool" - "github.com/bytedance/mockey/internal/unsafereflect" + "github.com/smartystreets/goconvey/convey" ) func TestMethodByName(t *testing.T) { - // private structure private method: *sha256.digest.checkSum - tfn, ok := unsafereflect.MethodByName(sha256.New(), "checkSum") - tool.Assert(ok, "private member of private structure is allowed") - // type of `func(*sha256.digest, []byte) [32]byte` - pFn := unsafe.Pointer(&tfn) - - mockey.PatchConvey("ReflectFuncReturn", t, func() { - f := reflect.FuncOf([]reflect.Type{reflect.TypeOf(sha256.New())}, - []reflect.Type{reflect.TypeOf([sha256.Size]byte{})}, false) - fn := reflect.NewAt(f, pFn).Elem().Interface() - // Such function cannot be exported as `(*sha256.digest).checkSum`, - // since the receiver's type is *sha256.digest, only Return API can be used - mockey.Mock(fn).Return([sha256.Size]byte{1: 1}).Build() - rets := sha256.New().Sum(nil) - want := make([]byte, sha256.Size) - want[1] = 1 - tool.Assert(reflect.DeepEqual(want, rets), "the method should be mocked") + convey.Convey("MethodByName", t, func() { + inst := sha256.New() + // private structure private method: *sha256.digest.checkSum + typ, fn, ok := MethodByName(inst, "checkSum") + convey.So(ok, convey.ShouldBeTrue) + convey.So(fn, convey.ShouldNotBeNil) + convey.So(typ, convey.ShouldEqual, reflect.TypeOf(func() [sha256.Size]byte { return [sha256.Size]byte{} })) }) - - // See also TestMethodByNameV17 while go version above 1.17 } diff --git a/utils.go b/utils.go index 48a9c18..4f6bcac 100644 --- a/utils.go +++ b/utils.go @@ -21,6 +21,7 @@ import ( "unsafe" "github.com/bytedance/mockey/internal/tool" + "github.com/bytedance/mockey/internal/unsafereflect" ) // GetMethod resolve a certain public method from an instance. @@ -35,6 +36,10 @@ func GetMethod(instance interface{}, methodName string) interface{} { if m, ok := getFieldMethod(instance, methodName); ok { return m } + ch0 := methodName[0] + if !(ch0 >= 'A' && ch0 <= 'Z') { + return unsafeMethodByName(instance, methodName) + } } tool.Assert(false, "can't reflect instance method :%v", methodName) return nil @@ -127,6 +132,44 @@ func getNestedMethod(val reflect.Value, methodName string) (reflect.Method, bool return reflect.PtrTo(typ).MethodByName(methodName) } +// unsafeMethodByName resolve a method from an instance, include private method. +// +// THIS IS UNSAFE FOR LOWER GO VERSION(<1.12) +// +// for example: +// +// unsafeMethodByName(&bytes.Buffer{}, "empty") +// unsafeMethodByName(sha256.New(), "checkSum") +func unsafeMethodByName(instance interface{}, methodName string) interface{} { + typ, tfn, ok := unsafereflect.MethodByName(instance, methodName) + if !ok { + tool.Assert(false, "can't reflect instance method :%v", methodName) + return nil + } + if typ == nil { + tool.Assert(false, "failed to determine %v's type", methodName) + } + + if typ.Kind() != reflect.Func { + tool.Assert(false, "invalid instance method type: %v,%v", methodName, typ.Kind().String()) + return nil + } + + in := []reflect.Type{reflect.TypeOf(instance)} + out := []reflect.Type{} + for i := 0; i < typ.NumIn(); i++ { + in = append(in, typ.In(i)) + } + for i := 0; i < typ.NumOut(); i++ { + out = append(out, typ.Out(i)) + } + + hook := reflect.FuncOf(in, out, typ.IsVariadic()) + vt := reflect.Zero(hook).Interface() + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&vt)) + 8)) = uintptr(unsafe.Pointer(tfn)) + return vt +} + // GetGoroutineId gets the current goroutine ID func GetGoroutineId() int64 { return tool.GetGoroutineID() diff --git a/utils_test.go b/utils_test.go index d9ec822..119739f 100644 --- a/utils_test.go +++ b/utils_test.go @@ -17,7 +17,10 @@ package mockey import ( + "bytes" + "crypto/sha256" "fmt" + "io" "reflect" "testing" @@ -391,3 +394,55 @@ func TestGetNested(t *testing.T) { }) }) } + +func TestPrivateMethod(t *testing.T) { + PatchConvey("PrivateMethod", t, func() { + PatchConvey("unsafeMethodByName", func() { + PatchConvey("struct method", func() { + fn := unsafeMethodByName(&bytes.Buffer{}, "empty") + targetType := reflect.TypeOf(func(*bytes.Buffer) bool { return false }) + + convey.So(reflect.TypeOf(fn), convey.ShouldEqual, targetType) + + buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) + b, err := buf.ReadByte() + convey.So(b, convey.ShouldEqual, 1) + convey.So(err, convey.ShouldBeNil) + + mocker := Mock(fn).Return(true).Build() + _, err = buf.ReadByte() + convey.So(err, convey.ShouldEqual, io.EOF) + convey.So(mocker.MockTimes(), convey.ShouldEqual, 1) + }) + PatchConvey("struct method mock", func() { + buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) + b, err := buf.ReadByte() + convey.So(b, convey.ShouldEqual, 1) + convey.So(err, convey.ShouldBeNil) + + mocker := Mock(GetMethod(bytes.NewBuffer(nil), "empty")).Return(true).Build() + _, err = buf.ReadByte() + convey.So(err, convey.ShouldEqual, io.EOF) + convey.So(mocker.MockTimes(), convey.ShouldEqual, 1) + }) + + PatchConvey("interface method mock", func() { + mocker := Mock(GetMethod(sha256.New(), "checkSum")).Return([sha256.Size]byte{}).Build() + convey.So(sha256.New().Sum([]byte{}), convey.ShouldResemble, make([]byte, 32)) + convey.So(mocker.MockTimes(), convey.ShouldEqual, 1) + }) + PatchConvey("interface method", func() { + targetType := reflect.FuncOf([]reflect.Type{reflect.TypeOf(sha256.New())}, []reflect.Type{reflect.TypeOf([sha256.Size]byte{})}, false) + fn := unsafeMethodByName(sha256.New(), "checkSum") + convey.So(reflect.TypeOf(fn), convey.ShouldEqual, targetType) + + convey.So(sha256.New().Sum([]byte{}), convey.ShouldResemble, []byte{227, 176, 196, 66, 152, 252, 28, 20, 154, 251, 244, 200, 153, 111, 185, 36, 39, 174, 65, 228, 100, 155, 147, 76, 164, 149, 153, 27, 120, 82, 184, 85}) + mocker := Mock(fn).To(func() [sha256.Size]byte { + return [sha256.Size]byte{} + }).Build() + convey.So(sha256.New().Sum([]byte{}), convey.ShouldResemble, make([]byte, 32)) + convey.So(mocker.MockTimes(), convey.ShouldEqual, 1) + }) + }) + }) +}