diff --git a/assert/assertions.go b/assert/assertions.go index 2c978b3a3..2975081d4 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -176,40 +176,57 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool { if !expectedType.ConvertibleTo(actualType) { return false } + expectedValueCmp := expectedValue.Convert(actualType).Interface() + actualValueCmp := actual if !isNumericType(expectedType) || !isNumericType(actualType) { // Attempt comparison after type conversion return reflect.DeepEqual( - expectedValue.Convert(actualType).Interface(), actual, + expectedValueCmp, actualValueCmp, ) } - // If BOTH values are numeric, there are chances of false positives due - // to overflow or underflow. So, we need to make sure to always convert - // the smaller type to a larger type before comparing. - fromType := actualType - toType := expectedType - fromValue := actualValue - toValue := expectedValue - if expectedType.Size() < actualType.Size() { - fromType = expectedType - toType = actualType - fromValue = expectedValue - toValue = actualValue - } - - // If we are converting from float32 to float64, the converted value will - // have trailing non zero decimals due to binary representation differences - // For example: float64(float32(10.1)) = 10.100000381469727 - // To remove the trailing decimals we can round the 64-bit value to - // expected precision of 32-bit which is 6 decimal places - newValue := fromValue.Convert(toType).Interface() - if fromType.Kind() == reflect.Float32 && toType.Kind() == reflect.Float64 { - scale := math.Pow(10, 6) - newValue = math.Round(newValue.(float64)*scale) / scale - } - - return newValue == toValue.Interface() + if actual != actual || expected != expected { + // NaN is not equal to NaN + return false + } + + // Both are numeric but their types are different, otherwise ObjectsAreEqual would have returned true already. + // We need to convert the smaller type to the larger type and then compare. + + smallestTypeValue, largestTypeValue := expectedValue, actualValue + if actualType.Size() < expectedType.Size() { + smallestTypeValue, largestTypeValue = largestTypeValue, smallestTypeValue + + if !actualType.ConvertibleTo(expectedType) { + return false + } + actualValueCmp = actualValue.Convert(expectedType).Interface() + expectedValueCmp = expected + } + + if actualValueCmp == expectedValueCmp { + // fast path + return true + } + + if smallestTypeValue.Kind() == reflect.Float32 && largestTypeValue.Kind() == reflect.Float64 { + a, b := expected, actual + if af, aIsNumber := toFloat(expected); aIsNumber && af == 0 { + // avoid division by zero in calcRelativeError + a, b = b, a + } + + epsilon, err := calcRelativeError(a, b) + if err != nil { + return false + } + + // The threshold of 1e-6 is somewhat arbitrary, but it is a common choice for comparing floating point numbers. + return epsilon <= 1e-6 + } + + return false } // isNumericType returns true if the type is one of: diff --git a/assert/assertions_test.go b/assert/assertions_test.go index da3eb4228..b6700f1d0 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -123,6 +123,11 @@ func TestObjectsAreEqual(t *testing.T) { {time.Now, time.Now, false}, {func() {}, func() {}, false}, {uint32(10), int32(10), false}, + {math.NaN(), math.NaN(), false}, + {math.Inf(1), math.Inf(1), true}, + {math.Inf(-1), math.Inf(-1), true}, + {math.Inf(1), math.Inf(-1), false}, + {math.Copysign(0, -1), 0.0, true}, // -0 should compare equal to 0 } for _, c := range cases { @@ -132,6 +137,10 @@ func TestObjectsAreEqual(t *testing.T) { if res != c.result { t.Errorf("ObjectsAreEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result) } + + if ObjectsAreEqual(c.actual, c.expected) != res { + t.Errorf("ObjectsAreEqual should be symmetric: ObjectsAreEqual(%#v, %#v) should return the same as ObjectsAreEqual(%#v, %#v)", c.expected, c.actual, c.actual, c.expected) + } }) } } @@ -161,11 +170,36 @@ func TestObjectsAreEqualValues(t *testing.T) { {3.14, complex128(1e+100 + 1e+100i), false}, {complex128(1e+10 + 1e+10i), complex64(1e+10 + 1e+10i), true}, {complex64(1e+10 + 1e+10i), complex128(1e+10 + 1e+10i), true}, + + {float32(1.0 / 3.0), float64(1.0 / 3.0), true}, + + // cases for float32/float64 comparison, which should be equal within float32 precision {float32(10.1), float64(10.1), true}, - {float64(10.1), float32(10.1), true}, - {float32(10.123456), float64(10.12345600), true}, - {float32(10.123456), float64(10.12345678), false}, - {float32(1.0 / 3.0), float64(1.0 / 3.0), false}, + + {float32(10.12345), float64(10.12345), true}, + + // cases that are close but should not be equal at float32 precision + {float32(10.1234), float64(10.1235), false}, + + // so anything beyond 7 decimal digits should be ignored when comparing at float32 precision, so these should still be equal + {float32(10.12345600), float64(10.123456789), true}, + + // Something near overflow should work + {float32(math.MaxFloat32), float64(math.MaxFloat32), true}, + + // NaN should remain unequal, even across float32/float64. + {float32(math.NaN()), float64(math.NaN()), false}, + + // Infinity should compare like ordinary equality. + {float32(math.Inf(1)), float64(math.Inf(1)), true}, + {float32(math.Inf(-1)), float64(math.Inf(-1)), true}, + {float64(math.Inf(1)), float32(math.Inf(-1)), false}, + + // zero should not lead to division by zero error + {float32(0), float64(0), true}, + + // Signed zero should still compare equal. + {float32(math.Copysign(0, -1)), float64(0), true}, } for _, c := range cases { @@ -175,6 +209,10 @@ func TestObjectsAreEqualValues(t *testing.T) { if res != c.result { t.Errorf("ObjectsAreEqualValues(%#v, %#v) should return %#v", c.expected, c.actual, c.result) } + + if ObjectsAreEqualValues(c.actual, c.expected) != res { + t.Errorf("ObjectsAreEqualValues should be symmetric: ObjectsAreEqualValues(%#v, %#v) should return the same as ObjectsAreEqualValues(%#v, %#v)", c.expected, c.actual, c.actual, c.expected) + } }) } }