From bcf45331b0ef3e8225296442fdd80b04fbd0f252 Mon Sep 17 00:00:00 2001
From: Joe Munene
Date: Thu, 16 Apr 2026 03:47:30 +0300
Subject: [PATCH 1/2] fix(metrics): use Welford's algorithm in PearsonCorrelation for numerical stability
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Replace the naive sum-of-squares variance formula with Welford's online
algorithm to fix catastrophic cancellation in float32. The old formula
E[x²] - E[x]² loses precision when values have large magnitudes relative
to their variance (e.g. mean=1e6, std=1), producing errors of 10%+ in
the computed correlation.

Closes #3662
---
 .../metrics/regression/pearson_correlation.py | 143 +++++++++++++-----
 .../regression/test_pearson_correlation.py    |  60 ++++++--
 2 files changed, 145 insertions(+), 58 deletions(-)

diff --git a/ignite/metrics/regression/pearson_correlation.py b/ignite/metrics/regression/pearson_correlation.py
index 1b4e8fa5d04b..73cc05ed9f64 100644
--- a/ignite/metrics/regression/pearson_correlation.py
+++ b/ignite/metrics/regression/pearson_correlation.py
@@ -2,8 +2,9 @@
 
 import torch
 
+import ignite.distributed as idist
 from ignite.exceptions import NotComputableError
-from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
+from ignite.metrics.metric import reinit__is_reduced
 from ignite.metrics.regression._base import _BaseRegression
 
 
@@ -19,6 +20,11 @@ class PearsonCorrelation(_BaseRegression):
 
     where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value.
 
+    Internally uses `Welford's online algorithm
+    <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm>`_
+    for numerically stable computation, avoiding catastrophic cancellation that
+    can occur with the naive sum-of-squares formula in float32.
+
     - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
     - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
 
@@ -54,7 +60,10 @@ class PearsonCorrelation(_BaseRegression):
 
     .. testoutput::
 
-        0.9768688678741455
+        0.9768687504744322
+
+    .. versionchanged:: 0.6.0
+        Uses Welford's online algorithm for improved numerical stability.
""" def __init__( @@ -68,56 +77,106 @@ def __init__( self.eps = eps _state_dict_all_req_keys = ( - "_sum_of_y_preds", - "_sum_of_ys", - "_sum_of_y_pred_squares", - "_sum_of_y_squares", - "_sum_of_products", "_num_examples", + "_mean_x", + "_mean_y", + "_m2_x", + "_m2_y", + "_cxy", ) @reinit__is_reduced def reset(self) -> None: - self._sum_of_y_preds = torch.tensor(0.0, device=self._device) - self._sum_of_ys = torch.tensor(0.0, device=self._device) - self._sum_of_y_pred_squares = torch.tensor(0.0, device=self._device) - self._sum_of_y_squares = torch.tensor(0.0, device=self._device) - self._sum_of_products = torch.tensor(0.0, device=self._device) self._num_examples = 0 + self._mean_x = torch.tensor(0.0, dtype=torch.float64, device=self._device) + self._mean_y = torch.tensor(0.0, dtype=torch.float64, device=self._device) + self._m2_x = torch.tensor(0.0, dtype=torch.float64, device=self._device) + self._m2_y = torch.tensor(0.0, dtype=torch.float64, device=self._device) + self._cxy = torch.tensor(0.0, dtype=torch.float64, device=self._device) def _update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output[0].detach(), output[1].detach() - self._sum_of_y_preds += y_pred.sum().to(self._device) - self._sum_of_ys += y.sum().to(self._device) - self._sum_of_y_pred_squares += y_pred.square().sum().to(self._device) - self._sum_of_y_squares += y.square().sum().to(self._device) - self._sum_of_products += (y_pred * y).sum().to(self._device) - self._num_examples += y.shape[0] - - @sync_all_reduce( - "_sum_of_y_preds", - "_sum_of_ys", - "_sum_of_y_pred_squares", - "_sum_of_y_squares", - "_sum_of_products", - "_num_examples", - ) + y_pred = y_pred.to(dtype=torch.float64) + y = y.to(dtype=torch.float64) + + n_b = y.shape[0] + n_a = self._num_examples + n_ab = n_a + n_b + + mean_x_b = y_pred.mean().to(self._device) + mean_y_b = y.mean().to(self._device) + + # Within-batch second moments + dx_b = y_pred - mean_x_b + dy_b = y - mean_y_b + m2_x_b = dx_b.square().sum().to(self._device) + m2_y_b = dy_b.square().sum().to(self._device) + cxy_b = (dx_b * dy_b).sum().to(self._device) + + if n_a == 0: + self._mean_x = mean_x_b + self._mean_y = mean_y_b + self._m2_x = m2_x_b + self._m2_y = m2_y_b + self._cxy = cxy_b + else: + delta_x = mean_x_b - self._mean_x + delta_y = mean_y_b - self._mean_y + + self._mean_x += delta_x * n_b / n_ab + self._mean_y += delta_y * n_b / n_ab + + self._m2_x += m2_x_b + delta_x * delta_x * n_a * n_b / n_ab + self._m2_y += m2_y_b + delta_y * delta_y * n_a * n_b / n_ab + self._cxy += cxy_b + delta_x * delta_y * n_a * n_b / n_ab + + self._num_examples = n_ab + def compute(self) -> float: - n = self._num_examples - if n == 0: + if self._num_examples == 0: raise NotComputableError("PearsonCorrelation must have at least one example before it can be computed.") - # cov = E[xy] - E[x]*E[y] - cov = self._sum_of_products / n - self._sum_of_y_preds * self._sum_of_ys / (n * n) - - # var = E[x^2] - E[x]^2 - y_pred_mean = self._sum_of_y_preds / n - y_pred_var = self._sum_of_y_pred_squares / n - y_pred_mean * y_pred_mean - y_pred_var = torch.clamp(y_pred_var, min=0.0) - - y_mean = self._sum_of_ys / n - y_var = self._sum_of_y_squares / n - y_mean * y_mean - y_var = torch.clamp(y_var, min=0.0) - - r = cov / torch.clamp(torch.sqrt(y_pred_var * y_var), min=self.eps) + n = self._num_examples + mean_x = self._mean_x + mean_y = self._mean_y + m2_x = self._m2_x + m2_y = self._m2_y + cxy = self._cxy + + ws = idist.get_world_size() + if ws > 1: + # Gather Welford state from all processes 
and merge using + # the parallel variant of Welford's algorithm. + state = torch.stack([ + torch.tensor(float(n), dtype=torch.float64, device=self._device), + mean_x, mean_y, m2_x, m2_y, cxy, + ]) + all_states = idist.all_gather(state) + all_states = all_states.reshape(ws, 6) + + n = all_states[0, 0].item() + mean_x = all_states[0, 1] + mean_y = all_states[0, 2] + m2_x = all_states[0, 3] + m2_y = all_states[0, 4] + cxy = all_states[0, 5] + + for i in range(1, ws): + n_i = all_states[i, 0].item() + n_combined = n + n_i + dx = all_states[i, 1] - mean_x + dy = all_states[i, 2] - mean_y + + mean_x = mean_x + dx * n_i / n_combined + mean_y = mean_y + dy * n_i / n_combined + m2_x = m2_x + all_states[i, 3] + dx * dx * n * n_i / n_combined + m2_y = m2_y + all_states[i, 4] + dy * dy * n * n_i / n_combined + cxy = cxy + all_states[i, 5] + dx * dy * n * n_i / n_combined + n = n_combined + + var_x = torch.clamp(m2_x / n, min=0.0) + var_y = torch.clamp(m2_y / n, min=0.0) + cov = cxy / n + + r = cov / torch.clamp(torch.sqrt(var_x * var_y), min=self.eps) return float(r.item()) diff --git a/tests/ignite/metrics/regression/test_pearson_correlation.py b/tests/ignite/metrics/regression/test_pearson_correlation.py index 599e6fae2033..7472434f557b 100644 --- a/tests/ignite/metrics/regression/test_pearson_correlation.py +++ b/tests/ignite/metrics/regression/test_pearson_correlation.py @@ -44,7 +44,6 @@ def test_wrong_input_shapes(): def test_degenerated_sample(available_device): if available_device == "mps": pytest.skip(reason="PearsonCorrelation.compute returns nan on mps") - # r = cov / torch.clamp(torch.sqrt(y_pred_var * y_var), min=self.eps) # one sample m = PearsonCorrelation(device=available_device) @@ -98,6 +97,35 @@ def test_pearson_correlation(available_device): assert m.compute() == pytest.approx(expected, rel=1e-4) +def test_numerical_stability(available_device): + """Test that Welford's algorithm handles large-mean, small-variance data + that would cause catastrophic cancellation with the naive formula.""" + if available_device == "mps": + pytest.skip(reason="float64 not supported on mps") + + m = PearsonCorrelation(device=available_device) + + torch.manual_seed(42) + # Large offset with small variance: naive E[x^2] - E[x]^2 would fail in float32 + offset = 1e6 + n = 1000 + x = torch.randn(n, dtype=torch.float32) + offset + y = x + torch.randn(n, dtype=torch.float32) * 0.1 # highly correlated + + # Feed in small batches to stress accumulation + batch_size = 10 + for i in range(0, n, batch_size): + m.update((x[i : i + batch_size], y[i : i + batch_size])) + + result = m.compute() + expected = pearsonr(x.numpy(), y.numpy()).statistic + + assert result == pytest.approx(expected, rel=1e-5), ( + f"Numerical instability detected: got {result}, expected {expected}" + ) + assert result > 0.99, f"Correlation should be near 1.0 for highly correlated data, got {result}" + + @pytest.fixture(params=list(range(2))) def test_case(request): # correlated sample @@ -148,11 +176,11 @@ def test_accumulator_detached(available_device): assert all( (not accumulator.requires_grad) for accumulator in ( - corr._sum_of_products, - corr._sum_of_y_pred_squares, - corr._sum_of_y_preds, - corr._sum_of_y_squares, - corr._sum_of_ys, + corr._mean_x, + corr._mean_y, + corr._m2_x, + corr._m2_y, + corr._cxy, ) ) @@ -240,11 +268,11 @@ def test_accumulator_device(self): devices = ( corr._device, - corr._sum_of_products.device, - corr._sum_of_y_pred_squares.device, - corr._sum_of_y_preds.device, - corr._sum_of_y_squares.device, - 
corr._sum_of_ys.device, + corr._mean_x.device, + corr._mean_y.device, + corr._m2_x.device, + corr._m2_y.device, + corr._cxy.device, ) for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" @@ -255,11 +283,11 @@ def test_accumulator_device(self): devices = ( corr._device, - corr._sum_of_products.device, - corr._sum_of_y_pred_squares.device, - corr._sum_of_y_preds.device, - corr._sum_of_y_squares.device, - corr._sum_of_ys.device, + corr._mean_x.device, + corr._mean_y.device, + corr._m2_x.device, + corr._m2_y.device, + corr._cxy.device, ) for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" From df1fde0213f289628b4c8517b01f3c8f82a394c7 Mon Sep 17 00:00:00 2001 From: Joe Munene Date: Sun, 19 Apr 2026 03:59:03 +0300 Subject: [PATCH 2/2] docs(metrics): annotate Welford state + merge formulas in PearsonCorrelation Per review feedback, the internal accumulator names (_m2_x, _cxy, etc.) follow Welford's original notation but aren't self-explanatory to readers who haven't encountered the algorithm before. - Block comment above state declaration maps each attribute to its role (running mean, sum of squared deviations, sum of paired deviations). - Inline comments in _update call out the 'accumulated vs batch vs combined' naming convention and explain the delta-correction term. - Inline comment in the distributed merge block explains why a plain sum across ranks is wrong (M2/Cxy are measured about different per-rank means). --- .../metrics/regression/pearson_correlation.py | 30 ++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/ignite/metrics/regression/pearson_correlation.py b/ignite/metrics/regression/pearson_correlation.py index 73cc05ed9f64..e5d03868c8ae 100644 --- a/ignite/metrics/regression/pearson_correlation.py +++ b/ignite/metrics/regression/pearson_correlation.py @@ -76,6 +76,17 @@ def __init__( self.eps = eps + # Welford state kept between updates. `x` denotes `y_pred`, `y` denotes the + # ground truth. Names follow the convention used in Welford (1962) and on + # the Wikipedia page linked in the class docstring: + # _num_examples -- running count of samples seen (n) + # _mean_x, _mean_y -- running means of x and y + # _m2_x, _m2_y -- running sums of squared deviations from the mean: + # M2_x = Σ (x_i - mean_x)^2 + # (dividing by n at the end gives the variance) + # _cxy -- running sum of paired deviations: + # C_xy = Σ (x_i - mean_x)(y_i - mean_y) + # (dividing by n at the end gives the covariance) _state_dict_all_req_keys = ( "_num_examples", "_mean_x", @@ -95,6 +106,9 @@ def reset(self) -> None: self._cxy = torch.tensor(0.0, dtype=torch.float64, device=self._device) def _update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: + # Parallel Welford: compute the current batch's Welford state in one + # shot, then merge it into the running state. "_a" suffix = prior + # (accumulated) state, "_b" suffix = this batch, "_ab" = combined. 
y_pred, y = output[0].detach(), output[1].detach() y_pred = y_pred.to(dtype=torch.float64) y = y.to(dtype=torch.float64) @@ -103,10 +117,9 @@ def _update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: n_a = self._num_examples n_ab = n_a + n_b + # ---- Batch stats (mean + M2/Cxy computed about the batch mean) ---- mean_x_b = y_pred.mean().to(self._device) mean_y_b = y.mean().to(self._device) - - # Within-batch second moments dx_b = y_pred - mean_x_b dy_b = y - mean_y_b m2_x_b = dx_b.square().sum().to(self._device) @@ -114,12 +127,17 @@ def _update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: cxy_b = (dx_b * dy_b).sum().to(self._device) if n_a == 0: + # First batch: running state is just the batch state. self._mean_x = mean_x_b self._mean_y = mean_y_b self._m2_x = m2_x_b self._m2_y = m2_y_b self._cxy = cxy_b else: + # ---- Welford merge: combine running state (a) with batch (b) ---- + # delta = mean_b - mean_a is the shift between the two group means; + # the correction term n_a*n_b/n_ab * delta^2 accounts for the fact + # that M2_a and M2_b are measured about *different* centres. delta_x = mean_x_b - self._mean_x delta_y = mean_y_b - self._mean_y @@ -145,8 +163,11 @@ def compute(self) -> float: ws = idist.get_world_size() if ws > 1: - # Gather Welford state from all processes and merge using - # the parallel variant of Welford's algorithm. + # Distributed reduce: each rank already holds its local Welford + # state (n, mean_x, mean_y, M2_x, M2_y, C_xy). We can't sum them + # directly — M2/Cxy on each rank are measured about *that rank's* + # local mean. Instead we all_gather the per-rank states and fold + # them in one by one using the same pairwise merge as _update. state = torch.stack([ torch.tensor(float(n), dtype=torch.float64, device=self._device), mean_x, mean_y, m2_x, m2_y, cxy, @@ -154,6 +175,7 @@ def compute(self) -> float: all_states = idist.all_gather(state) all_states = all_states.reshape(ws, 6) + # Seed the accumulator with rank 0's state, then merge ranks 1..ws-1. n = all_states[0, 0].item() mean_x = all_states[0, 1] mean_y = all_states[0, 2]
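
Note for reviewers, kept outside the patches: a standalone sketch of the
failure mode patch 1 fixes, assuming only torch. It shows the naive
E[x^2] - E[x]^2 formula collapsing in float32 at mean=1e6, std=1, the
regime named in the commit message.

    import torch

    torch.manual_seed(0)
    x = torch.randn(10_000) + 1.0e6  # float32, mean ~1e6, std ~1

    # Naive formula: both terms are ~1e12, where one float32 ulp is ~1e5,
    # so the true variance (~1.0) is swamped by cancellation error.
    naive_var = (x.square().mean() - x.mean() ** 2).item()

    # float64 reference, the precision the patched metric accumulates in.
    ref_var = x.double().var(unbiased=False).item()

    print(f"naive float32 variance: {naive_var:+.3f}")  # garbage, can be negative
    print(f"float64 reference:      {ref_var:+.3f}")    # ~1.0

Because r divides the covariance by the product of the standard deviations,
variance errors of this size are how the 10%+ correlation errors described
in the commit message arise.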
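
A second sketch isolates the merge arithmetic that _update and the
distributed branch of compute() share. WelfordState, merge and batch_state
are hypothetical names for illustration; they do not appear in the patch.

    from dataclasses import dataclass

    @dataclass
    class WelfordState:
        n: int = 0
        mean_x: float = 0.0  # running mean of x
        mean_y: float = 0.0  # running mean of y
        m2_x: float = 0.0    # sum of (x_i - mean_x)^2
        m2_y: float = 0.0    # sum of (y_i - mean_y)^2
        cxy: float = 0.0     # sum of (x_i - mean_x) * (y_i - mean_y)

    def merge(a: WelfordState, b: WelfordState) -> WelfordState:
        # Chan et al.'s parallel update. b's second moments are measured
        # about b's own mean, so the delta^2 * n_a * n_b / n_ab term
        # re-centres them; this is exactly why compute() cannot simply
        # sum M2/Cxy across ranks.
        if a.n == 0:
            return b
        n = a.n + b.n
        dx = b.mean_x - a.mean_x
        dy = b.mean_y - a.mean_y
        k = a.n * b.n / n
        return WelfordState(
            n=n,
            mean_x=a.mean_x + dx * b.n / n,
            mean_y=a.mean_y + dy * b.n / n,
            m2_x=a.m2_x + b.m2_x + dx * dx * k,
            m2_y=a.m2_y + b.m2_y + dy * dy * k,
            cxy=a.cxy + b.cxy + dx * dy * k,
        )

    def batch_state(xs: list[float], ys: list[float]) -> WelfordState:
        # One-shot batch statistics, as _update computes before merging.
        n = len(xs)
        mx, my = sum(xs) / n, sum(ys) / n
        return WelfordState(
            n, mx, my,
            sum((x - mx) ** 2 for x in xs),
            sum((y - my) ** 2 for y in ys),
            sum((x - mx) * (y - my) for x, y in zip(xs, ys)),
        )

    # Merging per-batch states reproduces the one-shot statistics.
    xs, ys = [1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8]
    merged = merge(batch_state(xs[:2], ys[:2]), batch_state(xs[2:], ys[2:]))
    whole = batch_state(xs, ys)
    assert abs(merged.m2_x - whole.m2_x) < 1e-12
    assert abs(merged.cxy - whole.cxy) < 1e-12

A side note on the final step: r = (cxy / n) / sqrt((m2_x / n) * (m2_y / n))
simplifies to cxy / sqrt(m2_x * m2_y), so the 1/n factors in compute()
cancel and only the eps clamp acts on the normalised (per-sample) values.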