diff --git a/assert/assertions.go b/assert/assertions.go index 909daa5e2..b3b309b8d 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -165,17 +165,40 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool { return true } - actualType := reflect.TypeOf(actual) - if actualType == nil { + expectedValue := reflect.ValueOf(expected) + actualValue := reflect.ValueOf(actual) + if !expectedValue.IsValid() || !actualValue.IsValid() { return false } - expectedValue := reflect.ValueOf(expected) - if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { + + expectedType := expectedValue.Type() + actualType := actualValue.Type() + if !expectedType.ConvertibleTo(actualType) { + return false + } + + if !isNumericType(expectedType) || !isNumericType(actualType) { // Attempt comparison after type conversion - return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) + return reflect.DeepEqual( + expectedValue.Convert(actualType).Interface(), actual, + ) } - return false + // If BOTH values are numeric, there are chances of false positives due + // to overflow or underflow. So, we need to make sure to always convert + // the smaller type to a larger type before comparing. + if expectedType.Size() >= actualType.Size() { + return actualValue.Convert(expectedType).Interface() == expected + } + + return expectedValue.Convert(actualType).Interface() == actual +} + +// isNumericType returns true if the type is one of: +// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, +// float32, float64, complex64, complex128 +func isNumericType(t reflect.Type) bool { + return t.Kind() >= reflect.Int && t.Kind() <= reflect.Complex128 } /* CallerInfo is necessary because the assert functions use the testing object diff --git a/assert/assertions_test.go b/assert/assertions_test.go index d65dd0218..a06d6e9e0 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -135,24 +135,42 @@ func TestObjectsAreEqual(t *testing.T) { }) } +} - // Cases where type differ but values are equal - if !ObjectsAreEqualValues(uint32(10), int32(10)) { - t.Error("ObjectsAreEqualValues should return true") - } - if ObjectsAreEqualValues(0, nil) { - t.Fail() - } - if ObjectsAreEqualValues(nil, 0) { - t.Fail() - } +func TestObjectsAreEqualValues(t *testing.T) { + now := time.Now() - tm := time.Now() - tz := tm.In(time.Local) - if !ObjectsAreEqualValues(tm, tz) { - t.Error("ObjectsAreEqualValues should return true for time.Time objects with different time zones") + cases := []struct { + expected interface{} + actual interface{} + result bool + }{ + {uint32(10), int32(10), true}, + {0, nil, false}, + {nil, 0, false}, + {now, now.In(time.Local), true}, // should be time zone independent + {int(270), int8(14), false}, // should handle overflow/underflow + {int8(14), int(270), false}, + {[]int{270, 270}, []int8{14, 14}, false}, + {complex128(1e+100 + 1e+100i), complex64(complex(math.Inf(0), math.Inf(0))), false}, + {complex64(complex(math.Inf(0), math.Inf(0))), complex128(1e+100 + 1e+100i), false}, + {complex128(1e+100 + 1e+100i), 270, false}, + {270, complex128(1e+100 + 1e+100i), false}, + {complex128(1e+100 + 1e+100i), 3.14, false}, + {3.14, complex128(1e+100 + 1e+100i), false}, + {complex128(1e+10 + 1e+10i), complex64(1e+10 + 1e+10i), true}, + {complex64(1e+10 + 1e+10i), complex128(1e+10 + 1e+10i), true}, } + for _, c := range cases { + t.Run(fmt.Sprintf("ObjectsAreEqualValues(%#v, %#v)", c.expected, c.actual), func(t *testing.T) { + res := ObjectsAreEqualValues(c.expected, c.actual) + + if res != c.result { + t.Errorf("ObjectsAreEqualValues(%#v, %#v) should return %#v", c.expected, c.actual, c.result) + } + }) + } } type Nested struct {