Skip to content

Commit d62101f

Browse files
committed
Fix the issue with float32 and float64 comparison
1 parent 19cf96a commit d62101f

File tree

2 files changed

+45
-29
lines changed

2 files changed

+45
-29
lines changed

assert/assertions.go

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -176,40 +176,57 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool {
176176
if !expectedType.ConvertibleTo(actualType) {
177177
return false
178178
}
179+
expectedValueCmp := expectedValue.Convert(actualType).Interface()
180+
actualValueCmp := actual
179181

180182
if !isNumericType(expectedType) || !isNumericType(actualType) {
181183
// Attempt comparison after type conversion
182184
return reflect.DeepEqual(
183-
expectedValue.Convert(actualType).Interface(), actual,
185+
expectedValueCmp, actualValueCmp,
184186
)
185187
}
186188

187-
// If BOTH values are numeric, there are chances of false positives due
188-
// to overflow or underflow. So, we need to make sure to always convert
189-
// the smaller type to a larger type before comparing.
190-
fromType := actualType
191-
toType := expectedType
192-
fromValue := actualValue
193-
toValue := expectedValue
194-
if expectedType.Size() < actualType.Size() {
195-
fromType = expectedType
196-
toType = actualType
197-
fromValue = expectedValue
198-
toValue = actualValue
199-
}
200-
201-
// If we are converting from float32 to float64, the converted value will
202-
// have trailing non zero decimals due to binary representation differences
203-
// For example: float64(float32(10.1)) = 10.100000381469727
204-
// To remove the trailing decimals we can round the 64-bit value to
205-
// expected precision of 32-bit which is 6 decimal places
206-
newValue := fromValue.Convert(toType).Interface()
207-
if fromType.Kind() == reflect.Float32 && toType.Kind() == reflect.Float64 {
208-
scale := math.Pow(10, 6)
209-
newValue = math.Round(newValue.(float64)*scale) / scale
210-
}
211-
212-
return newValue == toValue.Interface()
189+
if actual != actual || expected != expected {
190+
// NaN is not equal to NaN
191+
return false
192+
}
193+
194+
// Both are numeric but their types are different, otherwise ObjectsAreEqual would have returned true already.
195+
// We need to convert the smaller type to the larger type and then compare.
196+
197+
smallestTypeValue, largestTypeValue := expectedValue, actualValue
198+
if actualType.Size() < expectedType.Size() {
199+
smallestTypeValue, largestTypeValue = largestTypeValue, smallestTypeValue
200+
201+
if !actualType.ConvertibleTo(expectedType) {
202+
return false
203+
}
204+
actualValueCmp = actualValue.Convert(expectedType).Interface()
205+
expectedValueCmp = expected
206+
}
207+
208+
if actualValueCmp == expectedValueCmp {
209+
// fast path
210+
return true
211+
}
212+
213+
if smallestTypeValue.Kind() == reflect.Float32 && largestTypeValue.Kind() == reflect.Float64 {
214+
a, b := expected, actual
215+
if af, aIsNumber := toFloat(expected); aIsNumber && af == 0 {
216+
// avoid division by zero in calcRelativeError
217+
a, b = b, a
218+
}
219+
220+
epsilon, err := calcRelativeError(a, b)
221+
if err != nil {
222+
return false
223+
}
224+
225+
// The threshold of 1e-6 is somewhat arbitrary, but it is a common choice for comparing floating point numbers.
226+
return epsilon <= 1e-6
227+
}
228+
229+
return false
213230
}
214231

215232
// isNumericType returns true if the type is one of:

assert/assertions_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,7 @@ func TestObjectsAreEqualValues(t *testing.T) {
171171
{complex128(1e+10 + 1e+10i), complex64(1e+10 + 1e+10i), true},
172172
{complex64(1e+10 + 1e+10i), complex128(1e+10 + 1e+10i), true},
173173

174-
// cases for float32/float64 comparison, which should not be equal due to precision differences
175-
{float32(1.0 / 3.0), float64(1.0 / 3.0), false},
174+
{float32(1.0 / 3.0), float64(1.0 / 3.0), true},
176175

177176
// cases for float32/float64 comparison, which should be equal within float32 precision
178177
{float32(10.1), float64(10.1), true},

0 commit comments

Comments
 (0)