From ec1f186861bc4fcbe3f2cdb2a1442f81a710ec3e Mon Sep 17 00:00:00 2001 From: thediveo Date: Thu, 18 Apr 2024 13:58:25 +0200 Subject: [PATCH] feat: receiver matcher accepting (POINTER, MATCHER), includes unit tests Signed-off-by: thediveo --- matchers.go | 15 +++---- matchers/receive_matcher.go | 70 ++++++++++++++++++++++++-------- matchers/receive_matcher_test.go | 53 ++++++++++++++++++++++-- 3 files changed, 110 insertions(+), 28 deletions(-) diff --git a/matchers.go b/matchers.go index 8860d677f..7ef27dc9c 100644 --- a/matchers.go +++ b/matchers.go @@ -194,20 +194,21 @@ func BeClosed() types.GomegaMatcher { // // will repeatedly attempt to pull values out of `c` until a value matching "bar" is received. // -// Finally, if you want to have a reference to the value *sent* to the channel you can pass the `Receive` matcher a pointer to a variable of the appropriate type: +// Furthermore, if you want to have a reference to the value *sent* to the channel you can pass the `Receive` matcher a pointer to a variable of the appropriate type: // // var myThing thing // Eventually(thingChan).Should(Receive(&myThing)) // Expect(myThing.Sprocket).Should(Equal("foo")) // Expect(myThing.IsValid()).Should(BeTrue()) +// +// Finally, if you want to match the received object as well as get the actual received value into a variable, so you can reason further about the value received, +// you can pass a pointer to a variable of the approriate type first, and second a matcher: +// +// var myThing thing +// Eventually(thingChan).Should(Receive(&myThing, ContainSubstring("bar"))) func Receive(args ...interface{}) types.GomegaMatcher { - var arg interface{} - if len(args) > 0 { - arg = args[0] - } - return &matchers.ReceiveMatcher{ - Arg: arg, + Args: args, } } diff --git a/matchers/receive_matcher.go b/matchers/receive_matcher.go index 1936a2ba5..948164eaf 100644 --- a/matchers/receive_matcher.go +++ b/matchers/receive_matcher.go @@ -3,6 +3,7 @@ package matchers import ( + "errors" "fmt" "reflect" @@ -10,7 +11,7 @@ import ( ) type ReceiveMatcher struct { - Arg interface{} + Args []interface{} receivedValue reflect.Value channelClosed bool } @@ -29,15 +30,38 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro var subMatcher omegaMatcher var hasSubMatcher bool - - if matcher.Arg != nil { - subMatcher, hasSubMatcher = (matcher.Arg).(omegaMatcher) + var resultReference interface{} + + // Valid arg formats are as follows, always with optional POINTER before + // optional MATCHER: + // - Receive() + // - Receive(POINTER) + // - Receive(MATCHER) + // - Receive(POINTER, MATCHER) + args := matcher.Args + if len(args) > 0 { + arg := args[0] + _, isSubMatcher := arg.(omegaMatcher) + if !isSubMatcher && reflect.ValueOf(arg).Kind() == reflect.Ptr { + // Consume optional POINTER arg first, if it ain't no matcher ;) + resultReference = arg + args = args[1:] + } + } + if len(args) > 0 { + arg := args[0] + subMatcher, hasSubMatcher = arg.(omegaMatcher) if !hasSubMatcher { - argType := reflect.TypeOf(matcher.Arg) - if argType.Kind() != reflect.Ptr { - return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nTo:\n%s\nYou need to pass a pointer!", format.Object(actual, 1), format.Object(matcher.Arg, 1)) - } + // At this point we assume the dev user wanted to assign a received + // value, so [POINTER,]MATCHER. + return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nTo:\n%s\nYou need to pass a pointer!", format.Object(actual, 1), format.Object(arg, 1)) } + // Consume optional MATCHER arg. + args = args[1:] + } + if len(args) > 0 { + // If there are still args present, reject all. + return false, errors.New("Receive matcher expects at most an optional pointer and/or an optional matcher") } winnerIndex, value, open := reflect.Select([]reflect.SelectCase{ @@ -58,16 +82,20 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro } if hasSubMatcher { - if didReceive { - matcher.receivedValue = value - return subMatcher.Match(matcher.receivedValue.Interface()) + if !didReceive { + return false, nil } - return false, nil + matcher.receivedValue = value + if match, err := subMatcher.Match(matcher.receivedValue.Interface()); err != nil || !match { + return match, err + } + // if we received a match, then fall through in order to handle an + // optional assignment of the received value to the specified reference. } if didReceive { - if matcher.Arg != nil { - outValue := reflect.ValueOf(matcher.Arg) + if resultReference != nil { + outValue := reflect.ValueOf(resultReference) if value.Type().AssignableTo(outValue.Elem().Type()) { outValue.Elem().Set(value) @@ -77,7 +105,7 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro outValue.Elem().Set(value.Elem()) return true, nil } else { - return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nType:\n%s\nTo:\n%s", format.Object(actual, 1), format.Object(value.Interface(), 1), format.Object(matcher.Arg, 1)) + return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nType:\n%s\nTo:\n%s", format.Object(actual, 1), format.Object(value.Interface(), 1), format.Object(resultReference, 1)) } } @@ -88,7 +116,11 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro } func (matcher *ReceiveMatcher) FailureMessage(actual interface{}) (message string) { - subMatcher, hasSubMatcher := (matcher.Arg).(omegaMatcher) + var matcherArg interface{} + if len(matcher.Args) > 0 { + matcherArg = matcher.Args[len(matcher.Args)-1] + } + subMatcher, hasSubMatcher := (matcherArg).(omegaMatcher) closedAddendum := "" if matcher.channelClosed { @@ -105,7 +137,11 @@ func (matcher *ReceiveMatcher) FailureMessage(actual interface{}) (message strin } func (matcher *ReceiveMatcher) NegatedFailureMessage(actual interface{}) (message string) { - subMatcher, hasSubMatcher := (matcher.Arg).(omegaMatcher) + var matcherArg interface{} + if len(matcher.Args) > 0 { + matcherArg = matcher.Args[len(matcher.Args)-1] + } + subMatcher, hasSubMatcher := (matcherArg).(omegaMatcher) closedAddendum := "" if matcher.channelClosed { diff --git a/matchers/receive_matcher_test.go b/matchers/receive_matcher_test.go index 7837dc94e..d1b2ba810 100644 --- a/matchers/receive_matcher_test.go +++ b/matchers/receive_matcher_test.go @@ -54,6 +54,39 @@ var _ = Describe("ReceiveMatcher", func() { }) }) + Context("with too many arguments", func() { + It("should error", func() { + channel := make(chan bool, 1) + var actual bool + + channel <- true + + success, err := (&ReceiveMatcher{Args: []interface{}{ + &actual, + Equal(true), + 42, + }}).Match(channel) + Expect(success).To(BeFalse()) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("with swapped arguments", func() { + It("should error", func() { + channel := make(chan bool, 1) + var actual bool + + channel <- true + + success, err := (&ReceiveMatcher{Args: []interface{}{ + Equal(true), + &actual, + }}).Match(channel) + Expect(success).To(BeFalse()) + Expect(err).To(HaveOccurred()) + }) + }) + Context("with a pointer argument", func() { Context("of the correct type", func() { When("the channel has an interface type", func() { @@ -134,12 +167,12 @@ var _ = Describe("ReceiveMatcher", func() { var incorrectType bool - success, err := (&ReceiveMatcher{Arg: &incorrectType}).Match(channel) + success, err := (&ReceiveMatcher{Args: []interface{}{&incorrectType}}).Match(channel) Expect(success).Should(BeFalse()) Expect(err).Should(HaveOccurred()) var notAPointer int - success, err = (&ReceiveMatcher{Arg: notAPointer}).Match(channel) + success, err = (&ReceiveMatcher{Args: []interface{}{notAPointer}}).Match(channel) Expect(success).Should(BeFalse()) Expect(err).Should(HaveOccurred()) }) @@ -192,7 +225,7 @@ var _ = Describe("ReceiveMatcher", func() { It("should error", func() { channel := make(chan int, 1) channel <- 3 - success, err := (&ReceiveMatcher{Arg: ContainSubstring("three")}).Match(channel) + success, err := (&ReceiveMatcher{Args: []interface{}{ContainSubstring("three")}}).Match(channel) Expect(success).Should(BeFalse()) Expect(err).Should(HaveOccurred()) }) @@ -201,13 +234,25 @@ var _ = Describe("ReceiveMatcher", func() { Context("if nothing is received", func() { It("should fail", func() { channel := make(chan int, 1) - success, err := (&ReceiveMatcher{Arg: Equal(1)}).Match(channel) + success, err := (&ReceiveMatcher{Args: []interface{}{Equal(1)}}).Match(channel) Expect(success).Should(BeFalse()) Expect(err).ShouldNot(HaveOccurred()) }) }) }) + Context("with a pointer and a matcher argument", func() { + It("should succeed", func() { + channel := make(chan bool, 1) + channel <- true + + var received bool + + Expect(channel).Should(Receive(&received, Equal(true))) + Expect(received).Should(BeTrue()) + }) + }) + Context("When actual is a *closed* channel", func() { Context("for a buffered channel", func() { It("should work until it hits the end of the buffer", func() {