diff --git a/ignite/metrics/regression/pearson_correlation.py b/ignite/metrics/regression/pearson_correlation.py index 1b4e8fa5d04b..8002f4d25c6d 100644 --- a/ignite/metrics/regression/pearson_correlation.py +++ b/ignite/metrics/regression/pearson_correlation.py @@ -78,15 +78,20 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._sum_of_y_preds = torch.tensor(0.0, device=self._device) - self._sum_of_ys = torch.tensor(0.0, device=self._device) - self._sum_of_y_pred_squares = torch.tensor(0.0, device=self._device) - self._sum_of_y_squares = torch.tensor(0.0, device=self._device) - self._sum_of_products = torch.tensor(0.0, device=self._device) + # Use float64 accumulators to avoid catastrophic cancellation in + # E[X^2] - (E[X])^2 when values have large magnitude. MPS does not + # support float64, so fall back to float32 there. + acc_dtype = torch.float64 if self._device.type != "mps" else torch.float32 + self._sum_of_y_preds = torch.tensor(0.0, dtype=acc_dtype, device=self._device) + self._sum_of_ys = torch.tensor(0.0, dtype=acc_dtype, device=self._device) + self._sum_of_y_pred_squares = torch.tensor(0.0, dtype=acc_dtype, device=self._device) + self._sum_of_y_squares = torch.tensor(0.0, dtype=acc_dtype, device=self._device) + self._sum_of_products = torch.tensor(0.0, dtype=acc_dtype, device=self._device) self._num_examples = 0 def _update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: - y_pred, y = output[0].detach(), output[1].detach() + y_pred = output[0].detach().to(dtype=self._sum_of_y_preds.dtype) + y = output[1].detach().to(dtype=self._sum_of_y_preds.dtype) self._sum_of_y_preds += y_pred.sum().to(self._device) self._sum_of_ys += y.sum().to(self._device) self._sum_of_y_pred_squares += y_pred.square().sum().to(self._device) diff --git a/tests/ignite/metrics/regression/test_pearson_correlation.py b/tests/ignite/metrics/regression/test_pearson_correlation.py index 599e6fae2033..d88eaa4cadb7 100644 --- a/tests/ignite/metrics/regression/test_pearson_correlation.py +++ b/tests/ignite/metrics/regression/test_pearson_correlation.py @@ -137,6 +137,24 @@ def update_fn(engine: Engine, batch): assert pytest.approx(np_ans, rel=2e-4) == corr +def test_numerical_stability_large_offset(): + # float32 accumulators suffer catastrophic cancellation in E[X^2]-(E[X])^2 + # when values have large magnitude relative to their variance: both E[X^2] + # and (E[X])^2 are ~1e16 but their difference (the variance) is ~1, which + # falls below float32's ULP at that scale. float64 accumulators preserve + # the precision. MPS is excluded because it does not support float64. + offset = 1e8 + # y = 2*y_pred => perfect positive correlation; expected r = 1.0 + y_pred = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float64) + offset + y = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], dtype=torch.float64) + offset + + m = PearsonCorrelation() # CPU device (float64 accumulators) + m.update((y_pred, y)) + result = m.compute() + + assert pytest.approx(1.0, abs=1e-6) == result + + def test_accumulator_detached(available_device): corr = PearsonCorrelation(device=available_device) assert corr._device == torch.device(available_device)