diff --git a/go.mod b/go.mod index edfb6d40..3d1baa9d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/golang/mock require ( + github.com/pkg/errors v0.9.1 golang.org/x/tools v0.0.0-20190425150028-36563e24a262 rsc.io/quote/v3 v3.1.0 ) diff --git a/go.sum b/go.sum index 21dbce53..86c252de 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/gomock/call.go b/gomock/call.go index 7345f654..c94e71cd 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -19,6 +19,8 @@ import ( "reflect" "strconv" "strings" + + "github.com/golang/mock/gomock/internal/calldo" ) // Call represents an expected call to a mock. @@ -106,9 +108,20 @@ func (c *Call) MaxTimes(n int) *Call { // The return values from this function are returned by the mocked function. // It takes an interface{} argument to support n-arity functions. func (c *Call) DoAndReturn(f interface{}) *Call { - // TODO: Check arity and types here, rather than dying badly elsewhere. v := reflect.ValueOf(f) + switch v.Kind() { + case reflect.Func: + mt := c.methodType + + ft := v.Type() + if err := calldo.ValidateInputAndOutputSig(ft, mt); err != nil { + panic(fmt.Sprintf("DoAndReturn: %s", err)) + } + default: + panic("DoAndReturn: argument must be a function") + } + c.addAction(func(args []interface{}) []interface{} { vargs := make([]reflect.Value, len(args)) ft := v.Type() @@ -135,9 +148,20 @@ func (c *Call) DoAndReturn(f interface{}) *Call { // return values call DoAndReturn. // It takes an interface{} argument to support n-arity functions. func (c *Call) Do(f interface{}) *Call { - // TODO: Check arity and types here, rather than dying badly elsewhere. v := reflect.ValueOf(f) + switch v.Kind() { + case reflect.Func: + mt := c.methodType + + ft := v.Type() + if err := calldo.ValidateInputAndOutputSig(ft, mt); err != nil { + panic(fmt.Sprintf("Do: %s", err)) + } + default: + panic("Do: argument must be a function") + } + c.addAction(func(args []interface{}) []interface{} { vargs := make([]reflect.Value, len(args)) ft := v.Type() diff --git a/gomock/call_test.go b/gomock/call_test.go index 3a8315b3..41c53a41 100644 --- a/gomock/call_test.go +++ b/gomock/call_test.go @@ -1,6 +1,7 @@ package gomock import ( + "reflect" "testing" ) @@ -49,3 +50,418 @@ func TestCall_After(t *testing.T) { } }) } + +func TestCall_Do(t *testing.T) { + t.Run("Do function matches Call function", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function matches Call function and is a interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function matches Call function and is a map[interface{}]interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x map[int]string) bool { + return true + } + + callFunc := func(x map[interface{}]interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function matches Call function and is variadic", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []int) bool { + return true + } + + callFunc := func(x ...int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function matches Call function and is variadic interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []int) bool { + return true + } + + callFunc := func(x ...interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("argument to Do is not a function", func(t *testing.T) { + tr := &mockTestReporter{} + + callFunc := func(x int, y int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do("meow") + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("number of args for Do func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int, y int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("arg types for Do func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x string) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function does not match Call function and is a slice", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []string) bool { + return true + } + + callFunc := func(x []int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function does not match Call function and is a slice interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []string) bool { + return true + } + + callFunc := func(x []interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function does not match Call function and is a composite struct", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x b) bool { + return true + } + + callFunc := func(x a) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function does not match Call function and is a map", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x map[int]string) bool { + return true + } + + callFunc := func(x map[interface{}]int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("number of return vals for Do func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int) (bool, error) { + return false, nil + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("return types for Do func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int) error { + return nil + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) +} + +type a struct { + name string +} + +func (testObj a) Name() string { + return testObj.name +} + +type b struct { + a + foo string +} + +func (testObj b) Foo() string { + return testObj.foo +} diff --git a/gomock/controller_test.go b/gomock/controller_test.go index c22908b8..1f6b09d3 100644 --- a/gomock/controller_test.go +++ b/gomock/controller_test.go @@ -492,9 +492,11 @@ func TestDo(t *testing.T) { doCalled := false var argument string ctrl.RecordCall(subject, "FooMethod", "argument").Do( - func(arg string) { + func(arg string) int { doCalled = true argument = arg + + return 0 }) if doCalled { t.Error("Do() callback called too early.") diff --git a/gomock/internal/calldo/validate.go b/gomock/internal/calldo/validate.go new file mode 100644 index 00000000..2ff28e55 --- /dev/null +++ b/gomock/internal/calldo/validate.go @@ -0,0 +1,184 @@ +// Copyright 2020 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package calldo + +import ( + "fmt" + "reflect" + + "github.com/pkg/errors" +) + +// ValidateInputAndOutputSig compares the argument and return signatures of the +// function passed to Do against those expected by Call. It returns an error +// unless everything matches. +func ValidateInputAndOutputSig(doFunc, callFunc reflect.Type) error { + // check number of arguments and type of each argument + if doFunc.NumIn() != callFunc.NumIn() { + return fmt.Errorf( + "Do: expected function to have %d arguments not %d", + callFunc.NumIn(), doFunc.NumIn()) + } + + lastIdx := callFunc.NumIn() + + // If the function has a variadic argument validate that one first so that + // we aren't checking for it while we iterate over the other args + if callFunc.IsVariadic() { + if ok := validateVariadicArg(lastIdx, doFunc, callFunc); !ok { + i := lastIdx - 1 + return fmt.Errorf( + "Do: expected function to have"+ + " arg of type %v at position %d"+ + " not type %v", + callFunc.In(i), i, doFunc.In(i), + ) + } + + lastIdx-- + } + + for i := 0; i < lastIdx; i++ { + callArg := callFunc.In(i) + doArg := doFunc.In(i) + + if err := validateArg(doArg, callArg); err != nil { + return fmt.Errorf("input argument at %d: %s", i, err) + } + } + + // check number of return vals and type of each val + if doFunc.NumOut() != callFunc.NumOut() { + return fmt.Errorf( + "Do: expected function to have %d return vals not %d", + callFunc.NumOut(), doFunc.NumOut()) + } + + for i := 0; i < callFunc.NumOut(); i++ { + callArg := callFunc.Out(i) + doArg := doFunc.Out(i) + + if err := validateArg(doArg, callArg); err != nil { + return errors.Wrapf(err, "return argument at %d", i) + } + } + + return nil +} + +func validateVariadicArg(lastIdx int, doFunc, callFunc reflect.Type) bool { + if doFunc.In(lastIdx-1) != callFunc.In(lastIdx-1) { + if doFunc.In(lastIdx-1).Kind() != reflect.Slice { + return false + } + + callArgT := callFunc.In(lastIdx - 1) + callElem := callArgT.Elem() + if callElem.Kind() != reflect.Interface { + return false + } + + doArgT := doFunc.In(lastIdx - 1) + doElem := doArgT.Elem() + + if ok := doElem.ConvertibleTo(callElem); !ok { + return false + } + + } + + return true +} + +func validateInterfaceArg(doArg, callArg reflect.Type) error { + if !doArg.ConvertibleTo(callArg) { + return fmt.Errorf( + "expected arg convertible to type %v not type %v", + callArg, doArg, + ) + } + + return nil +} + +func validateMapArg(doArg, callArg reflect.Type) error { + callKey := callArg.Key() + doKey := doArg.Key() + + switch callKey.Kind() { + case reflect.Interface: + if err := validateInterfaceArg(doKey, callKey); err != nil { + return errors.Wrap(err, "map key") + } + default: + if doKey != callKey { + return fmt.Errorf("expected map key of type %v not type %v", + callKey, doKey) + } + } + + callElem := callArg.Elem() + doElem := doArg.Elem() + + switch callElem.Kind() { + case reflect.Interface: + if err := validateInterfaceArg(doElem, callElem); err != nil { + return errors.Wrap(err, "map element") + } + default: + if doElem != callElem { + return fmt.Errorf("expected map element of type %v not type %v", + callElem, doElem) + } + } + + return nil +} + +func validateArg(doArg, callArg reflect.Type) error { + switch callArg.Kind() { + // If the Call arg is an interface we only care if the Do arg is convertible + // to that interface + case reflect.Interface: + if err := validateInterfaceArg(doArg, callArg); err != nil { + return err + } + default: + // If the Call arg is not an interface then first check to see if + // the Do arg is even the same reflect.Kind + if callArg.Kind() != doArg.Kind() { + return fmt.Errorf("expected arg of kind %v not %v", + callArg.Kind(), doArg.Kind()) + } + + switch callArg.Kind() { + // If the Call arg is a map then we need to handle the case where + // the map key or element type is an interface + case reflect.Map: + if err := validateMapArg(doArg, callArg); err != nil { + return err + } + default: + if doArg != callArg { + return fmt.Errorf( + "Expected arg of type %v not type %v", + callArg, doArg, + ) + } + } + } + + return nil +}