Skip to content
This repository has been archived by the owner on Jun 27, 2023. It is now read-only.

Commit

Permalink
Fix #71 Do signature change error msg
Browse files Browse the repository at this point in the history
Update the error handling for Call.Do in the case where
the argument passed to Call.Do does not match expectations.

* panic if the argument is not a function
* panic if the number of input arguments do not match those expected by Call
* panic if the types of the input arguments do not match those expected
by Call
* panic if the number of return arguments do not match those expected by
Call
* panic if the types of return arguments do not match those expected by
Call
  • Loading branch information
cvgw committed Feb 2, 2020
1 parent 5c85495 commit 0d51e72
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 2 deletions.
58 changes: 57 additions & 1 deletion gomock/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,70 @@ func (c *Call) DoAndReturn(f interface{}) *Call {
return c
}

// validateInputAndOutputSig compares the argument and return signatures of the
// function passed to Do against those expected by Call. It panics unless everything
// matches.
func validateInputAndOutputSig(doFunc, callFunc reflect.Type) {
// check number of arguments and type of each argument
if doFunc.NumIn() != callFunc.NumIn() {
panic(
fmt.Sprintf(
"Do: expected function to have %d arguments not %d",
callFunc.NumIn(), doFunc.NumIn()),
)
}

for i := 0; i < callFunc.NumIn(); i++ {
if doFunc.In(i) != callFunc.In(i) {
panic(
fmt.Sprintf(
"Do: expected function to have"+
" arg of type %v at position %d",
callFunc.In(i), i,
),
)
}
}

// check number of return vals and type of each val
if doFunc.NumOut() != callFunc.NumOut() {
panic(
fmt.Sprintf(
"Do: expected function to have %d return vals not %d",
callFunc.NumOut(), doFunc.NumOut()),
)
}

for i := 0; i < callFunc.NumOut(); i++ {
if doFunc.Out(i) != callFunc.Out(i) {
panic(
fmt.Sprintf(
"Do: expected function to have"+
" return val of type %v at position %d",
callFunc.Out(i), i,
),
)
}
}
}

// Do declares the action to run when the call is matched. The function's
// return values are ignored to retain backward compatibility. To use the
// return values call DoAndReturn.
// It takes an interface{} argument to support n-arity functions.
func (c *Call) Do(f interface{}) *Call {
// TODO: Check arity and types here, rather than dying badly elsewhere.
v := reflect.ValueOf(f)

switch v.Kind() {
case reflect.Func:
mt := c.methodType

ft := v.Type()
validateInputAndOutputSig(ft, mt)
default:
panic("Do: argument must be a function")
}

c.addAction(func(args []interface{}) []interface{} {
vargs := make([]reflect.Value, len(args))
ft := v.Type()
Expand Down
187 changes: 187 additions & 0 deletions gomock/call_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gomock

import (
"reflect"
"testing"
)

Expand Down Expand Up @@ -49,3 +50,189 @@ func TestCall_After(t *testing.T) {
}
})
}

func TestCall_Do(t *testing.T) {
t.Run("Do function matches Call function", func(t *testing.T) {
tr := &mockTestReporter{}

doFunc := func(x int) bool {
if x < 20 {
return false
}

return true
}

callFunc := func(x int) bool {
return false
}

c := &Call{
t: tr,
methodType: reflect.TypeOf(callFunc),
}

c.Do(doFunc)

if len(c.actions) != 1 {
t.Errorf("expected %d actions but got %d", 1, len(c.actions))
}
})

t.Run("argument to Do is not a function", func(t *testing.T) {
tr := &mockTestReporter{}

callFunc := func(x int, y int) bool {
return false
}

c := &Call{
t: tr,
methodType: reflect.TypeOf(callFunc),
}

defer func() {
if r := recover(); r == nil {
t.Error("expected Do to panic")
}
}()

c.Do("meow")

if len(c.actions) != 1 {
t.Errorf("expected %d actions but got %d", 1, len(c.actions))
}
})

t.Run("number of args for Do func don't match Call func", func(t *testing.T) {
tr := &mockTestReporter{}

doFunc := func(x int) bool {
if x < 20 {
return false
}

return true
}

callFunc := func(x int, y int) bool {
return false
}

c := &Call{
t: tr,
methodType: reflect.TypeOf(callFunc),
}

defer func() {
if r := recover(); r == nil {
t.Error("expected Do to panic")
}
}()

c.Do(doFunc)

if len(c.actions) != 1 {
t.Errorf("expected %d actions but got %d", 1, len(c.actions))
}
})

t.Run("arg types for Do func don't match Call func", func(t *testing.T) {
tr := &mockTestReporter{}

doFunc := func(x int) bool {
if x < 20 {
return false
}

return true
}

callFunc := func(x string) bool {
return false
}

c := &Call{
t: tr,
methodType: reflect.TypeOf(callFunc),
}

defer func() {
if r := recover(); r == nil {
t.Error("expected Do to panic")
}
}()

c.Do(doFunc)

if len(c.actions) != 1 {
t.Errorf("expected %d actions but got %d", 1, len(c.actions))
}
})

t.Run("number of return vals for Do func don't match Call func", func(t *testing.T) {
tr := &mockTestReporter{}

doFunc := func(x int) bool {
if x < 20 {
return false
}

return true
}

callFunc := func(x int) (bool, error) {
return false, nil
}

c := &Call{
t: tr,
methodType: reflect.TypeOf(callFunc),
}

defer func() {
if r := recover(); r == nil {
t.Error("expected Do to panic")
}
}()

c.Do(doFunc)

if len(c.actions) != 1 {
t.Errorf("expected %d actions but got %d", 1, len(c.actions))
}
})

t.Run("return types for Do func don't match Call func", func(t *testing.T) {
tr := &mockTestReporter{}

doFunc := func(x int) bool {
if x < 20 {
return false
}

return true
}

callFunc := func(x int) error {
return nil
}

c := &Call{
t: tr,
methodType: reflect.TypeOf(callFunc),
}

defer func() {
if r := recover(); r == nil {
t.Error("expected Do to panic")
}
}()

c.Do(doFunc)

if len(c.actions) != 1 {
t.Errorf("expected %d actions but got %d", 1, len(c.actions))
}
})
}
4 changes: 3 additions & 1 deletion gomock/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,11 @@ func TestDo(t *testing.T) {
doCalled := false
var argument string
ctrl.RecordCall(subject, "FooMethod", "argument").Do(
func(arg string) {
func(arg string) int {
doCalled = true
argument = arg

return 0
})
if doCalled {
t.Error("Do() callback called too early.")
Expand Down

0 comments on commit 0d51e72

Please sign in to comment.