diff --git a/assert/assertions.go b/assert/assertions.go index 6950636d3..5a8ed649b 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -177,21 +177,53 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool { return false } + convertedExpectedValue := expectedValue.Convert(actualType).Interface() + if !isNumericType(expectedType) || !isNumericType(actualType) { // Attempt comparison after type conversion - return reflect.DeepEqual( - expectedValue.Convert(actualType).Interface(), actual, - ) + return reflect.DeepEqual(convertedExpectedValue, actual) } // If BOTH values are numeric, there are chances of false positives due // to overflow or underflow. So, we need to make sure to always convert // the smaller type to a larger type before comparing. - if expectedType.Size() >= actualType.Size() { - return actualValue.Convert(expectedType).Interface() == expected + // Assume smaller is expected value and larger is actual value + smallerTypeValue, largerTypeValue := expectedValue, actualValue + smallerValueCmp, largerValueCmp := convertedExpectedValue, actual + + // Actual value is smaller than expected value, converting actual value to expected value type + if actualType.Size() < expectedType.Size() { + smallerTypeValue, largerTypeValue = actualValue, expectedValue + + if !actualType.ConvertibleTo(expectedType) { + return false + } + smallerValueCmp = actualValue.Convert(expectedType).Interface() + largerValueCmp = expected + } + + // Quick comparison after type conversion to see if overflow or underflow is resolved + if smallerValueCmp == largerValueCmp { + return true + } + + // We want to allow comparison between float32(10.1) and float64(10.1). + // The problem here is when converting from float32 to float64, the converted + // value will have trailing non zero decimals due to binary representation differences + // For example: float64(float32(10.1)) = 10.100000381469727 + // To remove the trailing decimals we can round the 64-bit value to + // expected precision of 32-bit which is 6 decimal places + if smallerTypeValue.Kind() == reflect.Float32 && largerTypeValue.Kind() == reflect.Float64 { + float := smallerValueCmp.(float64) + integerPart := math.Floor(float) + decimalPart := float - integerPart + + scale := math.Pow(10, 6) + decimalPart = math.Round(decimalPart*scale) / scale + smallerValueCmp = integerPart + decimalPart } - return expectedValue.Convert(actualType).Interface() == actual + return smallerValueCmp == largerValueCmp } // isNumericType returns true if the type is one of: diff --git a/assert/assertions_test.go b/assert/assertions_test.go index 4975f5e41..a3ee918ab 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -123,6 +123,11 @@ func TestObjectsAreEqual(t *testing.T) { {time.Now, time.Now, false}, {func() {}, func() {}, false}, {uint32(10), int32(10), false}, + {math.NaN(), math.NaN(), false}, + {math.Inf(1), math.Inf(1), true}, + {math.Inf(-1), math.Inf(-1), true}, + {math.Inf(1), math.Inf(-1), false}, + {math.Copysign(0, -1), 0.0, true}, // -0 should compare equal to 0 } for _, c := range cases { @@ -132,6 +137,10 @@ func TestObjectsAreEqual(t *testing.T) { if res != c.result { t.Errorf("ObjectsAreEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result) } + + if ObjectsAreEqual(c.actual, c.expected) != res { + t.Errorf("ObjectsAreEqual should be symmetric: ObjectsAreEqual(%#v, %#v) should return the same as ObjectsAreEqual(%#v, %#v)", c.expected, c.actual, c.actual, c.expected) + } }) } } @@ -148,19 +157,36 @@ func TestObjectsAreEqualValues(t *testing.T) { }{ {uint32(10), int32(10), true}, {0, nil, false}, - {nil, 0, false}, {now, now.In(time.Local), false}, // should not be time zone independent {int(270), int8(14), false}, // should handle overflow/underflow - {int8(14), int(270), false}, {[]int{270, 270}, []int8{14, 14}, false}, {complex128(1e+100 + 1e+100i), complex64(complex(math.Inf(0), math.Inf(0))), false}, {complex64(complex(math.Inf(0), math.Inf(0))), complex128(1e+100 + 1e+100i), false}, {complex128(1e+100 + 1e+100i), 270, false}, {270, complex128(1e+100 + 1e+100i), false}, {complex128(1e+100 + 1e+100i), 3.14, false}, - {3.14, complex128(1e+100 + 1e+100i), false}, {complex128(1e+10 + 1e+10i), complex64(1e+10 + 1e+10i), true}, - {complex64(1e+10 + 1e+10i), complex128(1e+10 + 1e+10i), true}, + {float64(10.1), float32(10.1), true}, + {float32(10.123456), float64(10.12345600), true}, + {float32(10.123456), float64(10.12345678), false}, + {float32(1.0 / 3.0), float64(1.0 / 3.0), false}, + + // Something near overflow should work + {float32(math.MaxFloat32), float64(math.MaxFloat32), true}, + + // NaN should remain unequal, even across float32/float64. + {float32(math.NaN()), float64(math.NaN()), false}, + + // Infinity should compare like ordinary equality. + {float32(math.Inf(1)), float64(math.Inf(1)), true}, + {float32(math.Inf(-1)), float64(math.Inf(-1)), true}, + {float64(math.Inf(1)), float32(math.Inf(-1)), false}, + + // zero should not lead to division by zero error + {float32(0), float64(0), true}, + + // Signed zero should still compare equal. + {float32(math.Copysign(0, -1)), float64(0), true}, } for _, c := range cases { @@ -170,6 +196,10 @@ func TestObjectsAreEqualValues(t *testing.T) { if res != c.result { t.Errorf("ObjectsAreEqualValues(%#v, %#v) should return %#v", c.expected, c.actual, c.result) } + + if ObjectsAreEqualValues(c.actual, c.expected) != res { + t.Errorf("ObjectsAreEqualValues should be symmetric: ObjectsAreEqualValues(%#v, %#v) should return the same as ObjectsAreEqualValues(%#v, %#v)", c.expected, c.actual, c.actual, c.expected) + } }) } }