Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions ignite/metrics/regression/pearson_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,20 @@ def __init__(

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_y_preds = torch.tensor(0.0, device=self._device)
self._sum_of_ys = torch.tensor(0.0, device=self._device)
self._sum_of_y_pred_squares = torch.tensor(0.0, device=self._device)
self._sum_of_y_squares = torch.tensor(0.0, device=self._device)
self._sum_of_products = torch.tensor(0.0, device=self._device)
# Use float64 accumulators to avoid catastrophic cancellation in
# E[X^2] - (E[X])^2 when values have large magnitude. MPS does not
# support float64, so fall back to float32 there.
acc_dtype = torch.float64 if self._device.type != "mps" else torch.float32
self._sum_of_y_preds = torch.tensor(0.0, dtype=acc_dtype, device=self._device)
self._sum_of_ys = torch.tensor(0.0, dtype=acc_dtype, device=self._device)
self._sum_of_y_pred_squares = torch.tensor(0.0, dtype=acc_dtype, device=self._device)
self._sum_of_y_squares = torch.tensor(0.0, dtype=acc_dtype, device=self._device)
self._sum_of_products = torch.tensor(0.0, dtype=acc_dtype, device=self._device)
self._num_examples = 0

def _update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()
y_pred = output[0].detach().to(dtype=self._sum_of_y_preds.dtype)
y = output[1].detach().to(dtype=self._sum_of_y_preds.dtype)
self._sum_of_y_preds += y_pred.sum().to(self._device)
self._sum_of_ys += y.sum().to(self._device)
self._sum_of_y_pred_squares += y_pred.square().sum().to(self._device)
Expand Down
18 changes: 18 additions & 0 deletions tests/ignite/metrics/regression/test_pearson_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,24 @@ def update_fn(engine: Engine, batch):
assert pytest.approx(np_ans, rel=2e-4) == corr


def test_numerical_stability_large_offset():
# float32 accumulators suffer catastrophic cancellation in E[X^2]-(E[X])^2
# when values have large magnitude relative to their variance: both E[X^2]
# and (E[X])^2 are ~1e16 but their difference (the variance) is ~1, which
# falls below float32's ULP at that scale. float64 accumulators preserve
# the precision. MPS is excluded because it does not support float64.
offset = 1e8
# y = 2*y_pred => perfect positive correlation; expected r = 1.0
y_pred = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float64) + offset
y = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], dtype=torch.float64) + offset

m = PearsonCorrelation() # CPU device (float64 accumulators)
m.update((y_pred, y))
result = m.compute()

assert pytest.approx(1.0, abs=1e-6) == result


def test_accumulator_detached(available_device):
corr = PearsonCorrelation(device=available_device)
assert corr._device == torch.device(available_device)
Expand Down