Skip to content

Commit

Permalink
Check more carefully for nils in WithTransform (#423)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbandy authored Feb 28, 2021
1 parent b82522a commit 3c60a15
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 14 deletions.
23 changes: 11 additions & 12 deletions matchers/with_transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,17 @@ func NewWithTransformMatcher(transform interface{}, matcher types.GomegaMatcher)
func (m *WithTransformMatcher) Match(actual interface{}) (bool, error) {
// prepare a parameter to pass to the Transform function
var param reflect.Value
{
if actual != nil {
// return error if actual's type is incompatible with Transform function's argument type
actualType := reflect.TypeOf(actual)
if !actualType.AssignableTo(m.transformArgType) {
return false, fmt.Errorf("Transform function expects '%s' but we have '%s'", m.transformArgType, actualType)
}
param = reflect.ValueOf(actual)
} else {
// make a nil value of the expected type
param = reflect.New(m.transformArgType).Elem()
}
if actual != nil && reflect.TypeOf(actual).AssignableTo(m.transformArgType) {
// The dynamic type of actual is compatible with the transform argument.
param = reflect.ValueOf(actual)

} else if actual == nil && m.transformArgType.Kind() == reflect.Interface {
// The dynamic type of actual is unknown, so there's no way to make its
// reflect.Value. Create a nil of the transform argument, which is known.
param = reflect.Zero(m.transformArgType)

} else {
return false, fmt.Errorf("Transform function expects '%s' but we have '%T'", m.transformArgType, actual)
}

// call the Transform function with `actual`
Expand Down
42 changes: 40 additions & 2 deletions matchers/with_transform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,44 @@ var _ = Describe("WithTransformMatcher", func() {
})
})

When("the actual value is incompatible", func() {
It("fails to pass int to func(string)", func() {
actual, transform := int(0), func(string) int { return 0 }
success, err := WithTransform(transform, Equal(0)).Match(actual)
Expect(success).To(BeFalse())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("function expects 'string'"))
Expect(err.Error()).To(ContainSubstring("have 'int'"))
})

It("fails to pass string to func(interface)", func() {
actual, transform := "bang", func(error) int { return 0 }
success, err := WithTransform(transform, Equal(0)).Match(actual)
Expect(success).To(BeFalse())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("function expects 'error'"))
Expect(err.Error()).To(ContainSubstring("have 'string'"))
})

It("fails to pass nil interface to func(int)", func() {
actual, transform := error(nil), func(int) int { return 0 }
success, err := WithTransform(transform, Equal(0)).Match(actual)
Expect(success).To(BeFalse())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("function expects 'int'"))
Expect(err.Error()).To(ContainSubstring("have '<nil>'"))
})

It("fails to pass nil interface to func(pointer)", func() {
actual, transform := error(nil), func(*string) int { return 0 }
success, err := WithTransform(transform, Equal(0)).Match(actual)
Expect(success).To(BeFalse())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("function expects '*string'"))
Expect(err.Error()).To(ContainSubstring("have '<nil>'"))
})
})

It("works with positive cases", func() {
Expect(1).To(WithTransform(plus1, Equal(2)))
Expect(1).To(WithTransform(plus1, WithTransform(plus1, Equal(3))))
Expand All @@ -53,11 +91,11 @@ var _ = Describe("WithTransformMatcher", func() {
// transform expects interface
errString := func(e error) string {
if e == nil {
return "<nil>"
return "safe"
}
return e.Error()
}
Expect(nil).To(WithTransform(errString, Equal("<nil>")), "handles nil actual values")
Expect(nil).To(WithTransform(errString, Equal("safe")), "handles nil actual values")
Expect(errors.New("abc")).To(WithTransform(errString, Equal("abc")))
})

Expand Down

0 comments on commit 3c60a15

Please sign in to comment.