143 changes: 101 additions & 42 deletions ignite/metrics/regression/pearson_correlation.py
@@ -2,8 +2,9 @@

import torch

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
from ignite.metrics.metric import reinit__is_reduced

from ignite.metrics.regression._base import _BaseRegression

@@ -19,6 +20,11 @@ class PearsonCorrelation(_BaseRegression):

where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value.

Internally uses `Welford's online algorithm <https://en.wikipedia.org/wiki/
Algorithms_for_calculating_variance#Welford's_online_algorithm>`_ for numerically
stable computation, avoiding catastrophic cancellation that can occur with the
naive sum-of-squares formula in float32.

- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
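
As an illustration (not part of the diff), a minimal sketch of the cancellation the new docstring refers to, assuming float32 inputs:

import torch

x = torch.randn(1000) + 1.0e6               # large mean, unit variance, float32 by default
naive = x.square().mean() - x.mean() ** 2   # E[x^2] - E[x]^2: both terms are ~1e12, so the ~1.0 difference is lost
stable = (x - x.mean()).square().mean()     # center first, Welford-style; no cancellation
print(naive.item(), stable.item())          # naive is wildly off (often negative), stable is close to 1.0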

@@ -54,7 +60,10 @@ class PearsonCorrelation(_BaseRegression):

.. testoutput::

0.9768688678741455
0.9768687504744322

.. versionchanged:: 0.6.0
Uses Welford's online algorithm for improved numerical stability.
"""

def __init__(
@@ -68,56 +77,106 @@ def __init__(
self.eps = eps

_state_dict_all_req_keys = (
"_sum_of_y_preds",
"_sum_of_ys",
"_sum_of_y_pred_squares",
"_sum_of_y_squares",
"_sum_of_products",
"_num_examples",
"_mean_x",
"_mean_y",
"_m2_x",
"_m2_y",
"_cxy",
)

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_y_preds = torch.tensor(0.0, device=self._device)
self._sum_of_ys = torch.tensor(0.0, device=self._device)
self._sum_of_y_pred_squares = torch.tensor(0.0, device=self._device)
self._sum_of_y_squares = torch.tensor(0.0, device=self._device)
self._sum_of_products = torch.tensor(0.0, device=self._device)
self._num_examples = 0
self._mean_x = torch.tensor(0.0, dtype=torch.float64, device=self._device)
self._mean_y = torch.tensor(0.0, dtype=torch.float64, device=self._device)
self._m2_x = torch.tensor(0.0, dtype=torch.float64, device=self._device)
self._m2_y = torch.tensor(0.0, dtype=torch.float64, device=self._device)
self._cxy = torch.tensor(0.0, dtype=torch.float64, device=self._device)
Comment on lines +102 to +106 (Copilot AI, Apr 18, 2026):

The metric state is initialized and updated using torch.float64 tensors unconditionally. This will raise on MPS (float64 is not supported there), and available_device includes mps for this test suite. Consider selecting an internal accumulator dtype based on device capabilities (e.g., float64 normally, but float32 on MPS, similar to how SSIM handles this), and use that dtype consistently in reset, _update casts, and the distributed state tensor in compute.

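A minimal sketch of the dtype selection this comment proposes, assuming a helper of this shape (the name is hypothetical, not ignite's API; SSIM's device handling is the stated precedent):

import torch

def _accumulator_dtype(device: torch.device) -> torch.dtype:
    # MPS has no float64 support, so fall back to float32 there;
    # keep float64 elsewhere for numerical headroom.
    return torch.float32 if device.type == "mps" else torch.float64

reset, the _update casts, and the gathered state tensor in compute would then all reference this one dtype.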

def _update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()
self._sum_of_y_preds += y_pred.sum().to(self._device)
self._sum_of_ys += y.sum().to(self._device)
self._sum_of_y_pred_squares += y_pred.square().sum().to(self._device)
self._sum_of_y_squares += y.square().sum().to(self._device)
self._sum_of_products += (y_pred * y).sum().to(self._device)
self._num_examples += y.shape[0]

@sync_all_reduce(
"_sum_of_y_preds",
"_sum_of_ys",
"_sum_of_y_pred_squares",
"_sum_of_y_squares",
"_sum_of_products",
"_num_examples",
)
y_pred = y_pred.to(dtype=torch.float64)
y = y.to(dtype=torch.float64)

n_b = y.shape[0]
n_a = self._num_examples
n_ab = n_a + n_b

mean_x_b = y_pred.mean().to(self._device)
mean_y_b = y.mean().to(self._device)

# Within-batch second moments
dx_b = y_pred - mean_x_b
dy_b = y - mean_y_b
m2_x_b = dx_b.square().sum().to(self._device)
m2_y_b = dy_b.square().sum().to(self._device)
cxy_b = (dx_b * dy_b).sum().to(self._device)

Comment on lines +121 to +128 (Copilot AI, Apr 18, 2026):

In _update, mean_x_b/mean_y_b are moved to self._device before being used to center y_pred/y, which can trigger a device mismatch when the update tensors are on a different device than the metric accumulators (e.g., metric on CPU but y_pred/y on CUDA, or vice versa). This breaks common Ignite usage patterns (and the existing tests that pass CPU tensors with device='cuda'). Compute within-batch mean/deviations fully on the input tensor device first, then move only the resulting scalar statistics (mean_*_b, m2_*_b, cxy_b) to self._device before merging into the running state.

Suggested change:

-    mean_x_b = y_pred.mean().to(self._device)
-    mean_y_b = y.mean().to(self._device)
+    mean_x_b = y_pred.mean()
+    mean_y_b = y.mean()
     # Within-batch second moments
     dx_b = y_pred - mean_x_b
     dy_b = y - mean_y_b
-    m2_x_b = dx_b.square().sum().to(self._device)
-    m2_y_b = dy_b.square().sum().to(self._device)
-    cxy_b = (dx_b * dy_b).sum().to(self._device)
+    m2_x_b = dx_b.square().sum()
+    m2_y_b = dy_b.square().sum()
+    cxy_b = (dx_b * dy_b).sum()
+    mean_x_b = mean_x_b.to(self._device)
+    mean_y_b = mean_y_b.to(self._device)
+    m2_x_b = m2_x_b.to(self._device)
+    m2_y_b = m2_y_b.to(self._device)
+    cxy_b = cxy_b.to(self._device)

if n_a == 0:
self._mean_x = mean_x_b
self._mean_y = mean_y_b
self._m2_x = m2_x_b
self._m2_y = m2_y_b
self._cxy = cxy_b
else:
delta_x = mean_x_b - self._mean_x
delta_y = mean_y_b - self._mean_y

self._mean_x += delta_x * n_b / n_ab
self._mean_y += delta_y * n_b / n_ab

self._m2_x += m2_x_b + delta_x * delta_x * n_a * n_b / n_ab
self._m2_y += m2_y_b + delta_y * delta_y * n_a * n_b / n_ab
self._cxy += cxy_b + delta_x * delta_y * n_a * n_b / n_ab

self._num_examples = n_ab
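
For reference, the merge above is the pairwise (parallel) form of Welford's update, due to Chan et al. With :math:`n_{AB} = n_A + n_B`, :math:`\delta_x = \mu_{x,B} - \mu_{x,A}` and :math:`\delta_y` defined likewise, combining disjoint blocks :math:`A` and :math:`B` gives

.. math::

    \mu_{x,AB} = \mu_{x,A} + \delta_x \frac{n_B}{n_{AB}}, \qquad
    M_{2,x}^{AB} = M_{2,x}^{A} + M_{2,x}^{B} + \delta_x^2 \frac{n_A n_B}{n_{AB}}, \qquad
    C_{xy}^{AB} = C_{xy}^{A} + C_{xy}^{B} + \delta_x \delta_y \frac{n_A n_B}{n_{AB}}

compute below reuses exactly this merge to fold in the per-process states gathered via idist.all_gather.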

def compute(self) -> float:
n = self._num_examples
if n == 0:
if self._num_examples == 0:
raise NotComputableError("PearsonCorrelation must have at least one example before it can be computed.")

# cov = E[xy] - E[x]*E[y]
cov = self._sum_of_products / n - self._sum_of_y_preds * self._sum_of_ys / (n * n)

# var = E[x^2] - E[x]^2
y_pred_mean = self._sum_of_y_preds / n
y_pred_var = self._sum_of_y_pred_squares / n - y_pred_mean * y_pred_mean
y_pred_var = torch.clamp(y_pred_var, min=0.0)

y_mean = self._sum_of_ys / n
y_var = self._sum_of_y_squares / n - y_mean * y_mean
y_var = torch.clamp(y_var, min=0.0)

r = cov / torch.clamp(torch.sqrt(y_pred_var * y_var), min=self.eps)
n = self._num_examples
mean_x = self._mean_x
mean_y = self._mean_y
m2_x = self._m2_x
m2_y = self._m2_y
cxy = self._cxy

ws = idist.get_world_size()
if ws > 1:
# Gather Welford state from all processes and merge using
# the parallel variant of Welford's algorithm.
state = torch.stack([
torch.tensor(float(n), dtype=torch.float64, device=self._device),
mean_x, mean_y, m2_x, m2_y, cxy,
])
all_states = idist.all_gather(state)
all_states = all_states.reshape(ws, 6)

n = all_states[0, 0].item()
mean_x = all_states[0, 1]
mean_y = all_states[0, 2]
m2_x = all_states[0, 3]
m2_y = all_states[0, 4]
cxy = all_states[0, 5]

for i in range(1, ws):
n_i = all_states[i, 0].item()
n_combined = n + n_i
dx = all_states[i, 1] - mean_x
dy = all_states[i, 2] - mean_y

mean_x = mean_x + dx * n_i / n_combined
mean_y = mean_y + dy * n_i / n_combined
m2_x = m2_x + all_states[i, 3] + dx * dx * n * n_i / n_combined
m2_y = m2_y + all_states[i, 4] + dy * dy * n * n_i / n_combined
cxy = cxy + all_states[i, 5] + dx * dy * n * n_i / n_combined
n = n_combined

var_x = torch.clamp(m2_x / n, min=0.0)
var_y = torch.clamp(m2_y / n, min=0.0)
cov = cxy / n

r = cov / torch.clamp(torch.sqrt(var_x * var_y), min=self.eps)
return float(r.item())
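
For orientation, a quick usage sketch of the metric outside an Engine (assuming the usual re-export from ignite.metrics.regression; values are illustrative):

import torch
from ignite.metrics.regression import PearsonCorrelation

m = PearsonCorrelation()
m.update((torch.tensor([1.0, 2.0, 3.0, 4.0]), torch.tensor([1.1, 1.9, 3.2, 3.9])))
print(m.compute())  # close to 1.0 for near-linear data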
60 changes: 44 additions & 16 deletions tests/ignite/metrics/regression/test_pearson_correlation.py
@@ -44,7 +44,6 @@ def test_wrong_input_shapes():
def test_degenerated_sample(available_device):
if available_device == "mps":
pytest.skip(reason="PearsonCorrelation.compute returns nan on mps")
# r = cov / torch.clamp(torch.sqrt(y_pred_var * y_var), min=self.eps)

# one sample
m = PearsonCorrelation(device=available_device)
@@ -98,6 +97,35 @@ def test_pearson_correlation(available_device):
assert m.compute() == pytest.approx(expected, rel=1e-4)


def test_numerical_stability(available_device):
"""Test that Welford's algorithm handles large-mean, small-variance data
that would cause catastrophic cancellation with the naive formula."""
if available_device == "mps":
pytest.skip(reason="float64 not supported on mps")

m = PearsonCorrelation(device=available_device)

torch.manual_seed(42)
# Large offset with small variance: naive E[x^2] - E[x]^2 would fail in float32
offset = 1e6
n = 1000
x = torch.randn(n, dtype=torch.float32) + offset
y = x + torch.randn(n, dtype=torch.float32) * 0.1 # highly correlated

# Feed in small batches to stress accumulation
batch_size = 10
for i in range(0, n, batch_size):
m.update((x[i : i + batch_size], y[i : i + batch_size]))

result = m.compute()
expected = pearsonr(x.numpy(), y.numpy()).statistic

assert result == pytest.approx(expected, rel=1e-5), (
f"Numerical instability detected: got {result}, expected {expected}"
)
assert result > 0.99, f"Correlation should be near 1.0 for highly correlated data, got {result}"


@pytest.fixture(params=list(range(2)))
def test_case(request):
# correlated sample
@@ -148,11 +176,11 @@ def test_accumulator_detached(available_device):
assert all(
(not accumulator.requires_grad)
for accumulator in (
corr._sum_of_products,
corr._sum_of_y_pred_squares,
corr._sum_of_y_preds,
corr._sum_of_y_squares,
corr._sum_of_ys,
corr._mean_x,
corr._mean_y,
corr._m2_x,
corr._m2_y,
corr._cxy,
)
)

@@ -240,11 +268,11 @@ def test_accumulator_device(self):

devices = (
corr._device,
corr._sum_of_products.device,
corr._sum_of_y_pred_squares.device,
corr._sum_of_y_preds.device,
corr._sum_of_y_squares.device,
corr._sum_of_ys.device,
corr._mean_x.device,
corr._mean_y.device,
corr._m2_x.device,
corr._m2_y.device,
corr._cxy.device,
)
for dev in devices:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
@@ -255,11 +283,11 @@

devices = (
corr._device,
corr._sum_of_products.device,
corr._sum_of_y_pred_squares.device,
corr._sum_of_y_preds.device,
corr._sum_of_y_squares.device,
corr._sum_of_ys.device,
corr._mean_x.device,
corr._mean_y.device,
corr._m2_x.device,
corr._m2_y.device,
corr._cxy.device,
)
for dev in devices:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"