diff --git a/ignite/metrics/regression/pearson_correlation.py b/ignite/metrics/regression/pearson_correlation.py index 1b4e8fa5d04b..e5d03868c8ae 100644 --- a/ignite/metrics/regression/pearson_correlation.py +++ b/ignite/metrics/regression/pearson_correlation.py @@ -2,8 +2,9 @@ import torch +import ignite.distributed as idist from ignite.exceptions import NotComputableError -from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce +from ignite.metrics.metric import reinit__is_reduced from ignite.metrics.regression._base import _BaseRegression @@ -19,6 +20,11 @@ class PearsonCorrelation(_BaseRegression): where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. + Internally uses `Welford's online algorithm `_ for numerically + stable computation, avoiding catastrophic cancellation that can occur with the + naive sum-of-squares formula in float32. + - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`. @@ -54,7 +60,10 @@ class PearsonCorrelation(_BaseRegression): .. testoutput:: - 0.9768688678741455 + 0.9768687504744322 + + .. versionchanged:: 0.6.0 + Uses Welford's online algorithm for improved numerical stability. """ def __init__( @@ -67,57 +76,129 @@ def __init__( self.eps = eps + # Welford state kept between updates. `x` denotes `y_pred`, `y` denotes the + # ground truth. Names follow the convention used in Welford (1962) and on + # the Wikipedia page linked in the class docstring: + # _num_examples -- running count of samples seen (n) + # _mean_x, _mean_y -- running means of x and y + # _m2_x, _m2_y -- running sums of squared deviations from the mean: + # M2_x = Σ (x_i - mean_x)^2 + # (dividing by n at the end gives the variance) + # _cxy -- running sum of paired deviations: + # C_xy = Σ (x_i - mean_x)(y_i - mean_y) + # (dividing by n at the end gives the covariance) _state_dict_all_req_keys = ( - "_sum_of_y_preds", - "_sum_of_ys", - "_sum_of_y_pred_squares", - "_sum_of_y_squares", - "_sum_of_products", "_num_examples", + "_mean_x", + "_mean_y", + "_m2_x", + "_m2_y", + "_cxy", ) @reinit__is_reduced def reset(self) -> None: - self._sum_of_y_preds = torch.tensor(0.0, device=self._device) - self._sum_of_ys = torch.tensor(0.0, device=self._device) - self._sum_of_y_pred_squares = torch.tensor(0.0, device=self._device) - self._sum_of_y_squares = torch.tensor(0.0, device=self._device) - self._sum_of_products = torch.tensor(0.0, device=self._device) self._num_examples = 0 + self._mean_x = torch.tensor(0.0, dtype=torch.float64, device=self._device) + self._mean_y = torch.tensor(0.0, dtype=torch.float64, device=self._device) + self._m2_x = torch.tensor(0.0, dtype=torch.float64, device=self._device) + self._m2_y = torch.tensor(0.0, dtype=torch.float64, device=self._device) + self._cxy = torch.tensor(0.0, dtype=torch.float64, device=self._device) def _update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: + # Parallel Welford: compute the current batch's Welford state in one + # shot, then merge it into the running state. "_a" suffix = prior + # (accumulated) state, "_b" suffix = this batch, "_ab" = combined. y_pred, y = output[0].detach(), output[1].detach() - self._sum_of_y_preds += y_pred.sum().to(self._device) - self._sum_of_ys += y.sum().to(self._device) - self._sum_of_y_pred_squares += y_pred.square().sum().to(self._device) - self._sum_of_y_squares += y.square().sum().to(self._device) - self._sum_of_products += (y_pred * y).sum().to(self._device) - self._num_examples += y.shape[0] - - @sync_all_reduce( - "_sum_of_y_preds", - "_sum_of_ys", - "_sum_of_y_pred_squares", - "_sum_of_y_squares", - "_sum_of_products", - "_num_examples", - ) + y_pred = y_pred.to(dtype=torch.float64) + y = y.to(dtype=torch.float64) + + n_b = y.shape[0] + n_a = self._num_examples + n_ab = n_a + n_b + + # ---- Batch stats (mean + M2/Cxy computed about the batch mean) ---- + mean_x_b = y_pred.mean().to(self._device) + mean_y_b = y.mean().to(self._device) + dx_b = y_pred - mean_x_b + dy_b = y - mean_y_b + m2_x_b = dx_b.square().sum().to(self._device) + m2_y_b = dy_b.square().sum().to(self._device) + cxy_b = (dx_b * dy_b).sum().to(self._device) + + if n_a == 0: + # First batch: running state is just the batch state. + self._mean_x = mean_x_b + self._mean_y = mean_y_b + self._m2_x = m2_x_b + self._m2_y = m2_y_b + self._cxy = cxy_b + else: + # ---- Welford merge: combine running state (a) with batch (b) ---- + # delta = mean_b - mean_a is the shift between the two group means; + # the correction term n_a*n_b/n_ab * delta^2 accounts for the fact + # that M2_a and M2_b are measured about *different* centres. + delta_x = mean_x_b - self._mean_x + delta_y = mean_y_b - self._mean_y + + self._mean_x += delta_x * n_b / n_ab + self._mean_y += delta_y * n_b / n_ab + + self._m2_x += m2_x_b + delta_x * delta_x * n_a * n_b / n_ab + self._m2_y += m2_y_b + delta_y * delta_y * n_a * n_b / n_ab + self._cxy += cxy_b + delta_x * delta_y * n_a * n_b / n_ab + + self._num_examples = n_ab + def compute(self) -> float: - n = self._num_examples - if n == 0: + if self._num_examples == 0: raise NotComputableError("PearsonCorrelation must have at least one example before it can be computed.") - # cov = E[xy] - E[x]*E[y] - cov = self._sum_of_products / n - self._sum_of_y_preds * self._sum_of_ys / (n * n) - - # var = E[x^2] - E[x]^2 - y_pred_mean = self._sum_of_y_preds / n - y_pred_var = self._sum_of_y_pred_squares / n - y_pred_mean * y_pred_mean - y_pred_var = torch.clamp(y_pred_var, min=0.0) - - y_mean = self._sum_of_ys / n - y_var = self._sum_of_y_squares / n - y_mean * y_mean - y_var = torch.clamp(y_var, min=0.0) - - r = cov / torch.clamp(torch.sqrt(y_pred_var * y_var), min=self.eps) + n = self._num_examples + mean_x = self._mean_x + mean_y = self._mean_y + m2_x = self._m2_x + m2_y = self._m2_y + cxy = self._cxy + + ws = idist.get_world_size() + if ws > 1: + # Distributed reduce: each rank already holds its local Welford + # state (n, mean_x, mean_y, M2_x, M2_y, C_xy). We can't sum them + # directly — M2/Cxy on each rank are measured about *that rank's* + # local mean. Instead we all_gather the per-rank states and fold + # them in one by one using the same pairwise merge as _update. + state = torch.stack([ + torch.tensor(float(n), dtype=torch.float64, device=self._device), + mean_x, mean_y, m2_x, m2_y, cxy, + ]) + all_states = idist.all_gather(state) + all_states = all_states.reshape(ws, 6) + + # Seed the accumulator with rank 0's state, then merge ranks 1..ws-1. + n = all_states[0, 0].item() + mean_x = all_states[0, 1] + mean_y = all_states[0, 2] + m2_x = all_states[0, 3] + m2_y = all_states[0, 4] + cxy = all_states[0, 5] + + for i in range(1, ws): + n_i = all_states[i, 0].item() + n_combined = n + n_i + dx = all_states[i, 1] - mean_x + dy = all_states[i, 2] - mean_y + + mean_x = mean_x + dx * n_i / n_combined + mean_y = mean_y + dy * n_i / n_combined + m2_x = m2_x + all_states[i, 3] + dx * dx * n * n_i / n_combined + m2_y = m2_y + all_states[i, 4] + dy * dy * n * n_i / n_combined + cxy = cxy + all_states[i, 5] + dx * dy * n * n_i / n_combined + n = n_combined + + var_x = torch.clamp(m2_x / n, min=0.0) + var_y = torch.clamp(m2_y / n, min=0.0) + cov = cxy / n + + r = cov / torch.clamp(torch.sqrt(var_x * var_y), min=self.eps) return float(r.item()) diff --git a/tests/ignite/metrics/regression/test_pearson_correlation.py b/tests/ignite/metrics/regression/test_pearson_correlation.py index 599e6fae2033..7472434f557b 100644 --- a/tests/ignite/metrics/regression/test_pearson_correlation.py +++ b/tests/ignite/metrics/regression/test_pearson_correlation.py @@ -44,7 +44,6 @@ def test_wrong_input_shapes(): def test_degenerated_sample(available_device): if available_device == "mps": pytest.skip(reason="PearsonCorrelation.compute returns nan on mps") - # r = cov / torch.clamp(torch.sqrt(y_pred_var * y_var), min=self.eps) # one sample m = PearsonCorrelation(device=available_device) @@ -98,6 +97,35 @@ def test_pearson_correlation(available_device): assert m.compute() == pytest.approx(expected, rel=1e-4) +def test_numerical_stability(available_device): + """Test that Welford's algorithm handles large-mean, small-variance data + that would cause catastrophic cancellation with the naive formula.""" + if available_device == "mps": + pytest.skip(reason="float64 not supported on mps") + + m = PearsonCorrelation(device=available_device) + + torch.manual_seed(42) + # Large offset with small variance: naive E[x^2] - E[x]^2 would fail in float32 + offset = 1e6 + n = 1000 + x = torch.randn(n, dtype=torch.float32) + offset + y = x + torch.randn(n, dtype=torch.float32) * 0.1 # highly correlated + + # Feed in small batches to stress accumulation + batch_size = 10 + for i in range(0, n, batch_size): + m.update((x[i : i + batch_size], y[i : i + batch_size])) + + result = m.compute() + expected = pearsonr(x.numpy(), y.numpy()).statistic + + assert result == pytest.approx(expected, rel=1e-5), ( + f"Numerical instability detected: got {result}, expected {expected}" + ) + assert result > 0.99, f"Correlation should be near 1.0 for highly correlated data, got {result}" + + @pytest.fixture(params=list(range(2))) def test_case(request): # correlated sample @@ -148,11 +176,11 @@ def test_accumulator_detached(available_device): assert all( (not accumulator.requires_grad) for accumulator in ( - corr._sum_of_products, - corr._sum_of_y_pred_squares, - corr._sum_of_y_preds, - corr._sum_of_y_squares, - corr._sum_of_ys, + corr._mean_x, + corr._mean_y, + corr._m2_x, + corr._m2_y, + corr._cxy, ) ) @@ -240,11 +268,11 @@ def test_accumulator_device(self): devices = ( corr._device, - corr._sum_of_products.device, - corr._sum_of_y_pred_squares.device, - corr._sum_of_y_preds.device, - corr._sum_of_y_squares.device, - corr._sum_of_ys.device, + corr._mean_x.device, + corr._mean_y.device, + corr._m2_x.device, + corr._m2_y.device, + corr._cxy.device, ) for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" @@ -255,11 +283,11 @@ def test_accumulator_device(self): devices = ( corr._device, - corr._sum_of_products.device, - corr._sum_of_y_pred_squares.device, - corr._sum_of_y_preds.device, - corr._sum_of_y_squares.device, - corr._sum_of_ys.device, + corr._mean_x.device, + corr._mean_y.device, + corr._m2_x.device, + corr._m2_y.device, + corr._cxy.device, ) for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"