feat(metrics): add WelfordVariance and WelfordCovariance helpers (PR 1 of #3748) #3750
joemunene-by wants to merge 3 commits
Conversation
Introduces ignite/metrics/_running_stats.py with two numerically
stable running-statistics primitives that variance- and
covariance-bearing metrics can share, instead of each one rolling
its own naive Σx² − (Σx)²/n implementation.
- WelfordVariance: running mean, variance, std for a single variable.
- WelfordCovariance: running variance_x, variance_y, covariance, and
  Pearson correlation for a paired (x, y) stream.
Both classes:
- keep internal state in float64 regardless of input dtype, so the
classic E[X²] − E[X]² cancellation does not bite at large means;
- update incrementally via Welford's online algorithm;
- expose merge() implementing the Chan / Welford parallel formula,
suitable for cross-rank distributed reduction or any other case
where two accumulators need to be combined without re-iterating
the raw data.
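To make the two mechanisms concrete, here is a minimal standalone sketch of a Welford accumulator with a Chan-style merge. The names (`RunningVar`, `m2`) are illustrative only, not the PR's actual API:

```python
import torch

class RunningVar:
    """Sketch of a Welford variance accumulator; not the PR's code."""

    def __init__(self) -> None:
        self.n = 0
        self.mean = torch.tensor(0.0, dtype=torch.float64)
        self.m2 = torch.tensor(0.0, dtype=torch.float64)  # sum of squared deviations

    def update(self, batch: torch.Tensor) -> None:
        # One-pass Welford update, element by element, in float64.
        for x in batch.detach().to(torch.float64).flatten():
            self.n += 1
            delta = x - self.mean
            self.mean = self.mean + delta / self.n
            self.m2 = self.m2 + delta * (x - self.mean)

    def merge(self, other: "RunningVar") -> None:
        # Chan et al. parallel combination: no raw data is re-iterated.
        n_ab = self.n + other.n
        if n_ab == 0:
            return
        delta = other.mean - self.mean
        self.m2 = self.m2 + other.m2 + delta * delta * self.n * other.n / n_ab
        self.mean = self.mean + delta * other.n / n_ab
        self.n = n_ab

    def variance(self) -> float:
        # Population (biased) variance.
        return (self.m2 / self.n).item()
```

Two accumulators fed disjoint halves of a stream and then merged give the same variance as one accumulator fed the whole stream, which is the property that makes cross-rank reduction possible without shipping raw data.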
This is PR 1 of the plan in pytorch#3748. Follow-ups:
PR 2 will port R2Score (pytorch#3662-style regression test attached).
PR 3 will refactor pytorch#3741 to consume WelfordCovariance instead of
its current inline Welford state.
Tests (20 total, all passing):
- per-class correctness vs numpy mean / var / cov / corrcoef
- multi-batch update matches single-batch update
- merge matches concatenated update
- merge with empty accumulators on either side
- numerical-stability regression (mean=1e6 in float32) for both
classes, with an assertion that the naive float32 formula
actually does fail on the same data so the test documents what
we're protecting against
- shape-mismatch raises ValueError
- empty-batch update is a no-op
- reset clears state
- input dtypes (int32) upcast to float64 correctly
- cross-class sanity: WelfordCovariance.variance_x matches
WelfordVariance.variance fed the same x
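The stability regression in that list can be illustrated with a standalone sketch (hypothetical data, not the PR's test code): with values of mean ~1e6 and std ~1 in float32, the naive formula's two terms are each ~1e12 while the answer is ~1, so nearly all significant digits cancel.

```python
import torch

torch.manual_seed(0)
# Hypothetical data matching the test's scenario: mean ~1e6, std ~1, float32.
x = (1e6 + torch.randn(10_000)).to(torch.float32)

# Naive E[X^2] - E[X]^2 evaluated in float32: catastrophic cancellation.
naive_var = ((x * x).mean() - x.mean() ** 2).item()

# Two-pass computation after upcasting to float64: stable.
x64 = x.to(torch.float64)
stable_var = ((x64 - x64.mean()) ** 2).mean().item()

true_var = x64.var(unbiased=False).item()
assert abs(stable_var - true_var) < 1e-6  # stable path recovers the variance
assert abs(naive_var - true_var) > 0.1    # naive float32 path demonstrably fails
```

The second assertion mirrors the test list's point: the regression test proves the naive formula actually fails on the same data, so the suite documents the failure mode it protects against.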
```python
mean: torch.Tensor
sum_sq_dev_from_mean: torch.Tensor


def __init__(self, device: Union[str, torch.device] = "cpu") -> None:
```
This class should not handle device, nor dtype
```python
import torch


class WelfordVariance:
```
Let's make it a dataclass?
```python
if batch.numel() == 0:
    return
batch64 = batch.detach().to(dtype=torch.float64).flatten()
```
I do not think we should flatten it, and for the mean computation we may need to specify the axis (to confirm)
Addressing @vfdev-5's inline review on pytorch#3750:

- Drop the device and dtype constructor args. The helper now leaves placement and precision to the caller; state takes the dtype and device of the first batch passed to update(). PearsonCorrelation and R2Score already do their own float64 upcast before handing inputs to the helper, so this is a no-op for the planned consumers.
- Switch both classes to @dataclass with field(default_factory=...) for the tensor fields. Drops the manual __init__ / reset() plumbing; "reset" is now reconstruction (m.welford = WelfordVariance()), which is the natural fit for how the consumer Metric.reset() methods already work.
- Drop the explicit .flatten() on update inputs. batch.mean() and batch.numel() both reduce over the full tensor regardless of shape, so behavior for the current scalar-reduction consumers is unchanged, and the code reads more naturally for any shape.

Tests adjusted accordingly:

- test_reset replaced by test_fresh_instance_has_zero_state, which documents the dataclass default-factory behavior.
- test_input_dtype_upcast_to_float64 replaced by test_state_dtype_follows_first_batch, which verifies dtype is preserved (the design change).
- Stability tests upcast inputs caller-side before handing to the helper, matching how the metric classes will use it.
- test_multi_batch_matches_single_batch switched to float64 inputs so it exercises the algorithm rather than float32 noise.

All 20 tests still pass; ruff format / check clean.

The question about axis-aware reduction is deferred to the review thread; I'll follow up once @vfdev-5 confirms whether it lands here or as a follow-up.
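The dataclass shape this commit describes can be sketched as follows (field names taken from the diff context above; a sketch, not the final code):

```python
from dataclasses import dataclass, field

import torch


@dataclass
class WelfordVariance:
    # default_factory gives each instance its own zero tensors,
    # so "reset" is simply constructing a fresh instance.
    n_samples: int = 0
    mean: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0))
    sum_sq_dev_from_mean: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0))
```

Reset-by-reconstruction then looks like `m.welford = WelfordVariance()`; no tensor is shared between instances, which is exactly the mutable-default bug a plain class-attribute default would introduce.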
Thanks for the review. Pushed 1b82442 addressing all three points.

1. No device / dtype in the helper. Dropped the constructor args; state now takes the dtype and device of the first batch passed to update().
2. Dataclasses. Both classes are now @dataclass with field(default_factory=...) tensor fields, so the manual __init__ / reset() plumbing is gone.
3. Flatten. Removed the explicit .flatten(); batch.mean() and batch.numel() already reduce over the full tensor.

On axis-aware reduction: the current consumers (R2Score, PearsonCorrelation) both produce a single scalar variance / covariance, so they only need the full-reduction behavior. If we want axis support for future per-channel use cases (running variance over some dim), the signature could grow a dim argument:

```python
def update(self, batch: torch.Tensor, dim: Optional[Union[int, tuple[int, ...]]] = None) -> None:
    ...
    n_b = batch.numel() if dim is None else _samples_along(batch.shape, dim)
    mean_b = batch.mean(dim=dim, keepdim=False) if dim is not None else batch.mean()
    ...
```

The state on first update would take the shape of mean_b. Two ways we can take it: (a) add dim support in this PR, or (b) defer it to a follow-up.

I lean toward (b): the consumers don't need it, and axis-aware running stats have a few subtleties (Bessel correction conventions, mixed-shape merge, axis reordering) that are easier to design alongside a concrete user. Happy to go either way; let me know which you prefer.
@joemunene-by hi, I was trying to review but I find the code quite confusing; I think we can try to make it more readable :)
```python
def _zero() -> torch.Tensor:
    return torch.tensor(0.0)
```
why are we writing a function for code that is literally 1 line?
```python
if self.n_samples == 0:
    self.mean = mean_b
    self.sum_sq_dev_from_mean = m2_b
    self.n_samples = n_b
    return

n_a = self.n_samples
n_ab = n_a + n_b
delta = mean_b - self.mean
self.mean = self.mean + delta * n_b / n_ab
self.sum_sq_dev_from_mean = self.sum_sq_dev_from_mean + m2_b + delta * delta * n_a * n_b / n_ab
self.n_samples = n_ab
```
This part is redundant as merge is already doing it; we can just use that.
```python
self.sum_sq_dev_from_mean = self.sum_sq_dev_from_mean + m2_b + delta * delta * n_a * n_b / n_ab
self.n_samples = n_ab


def merge(self, other: "WelfordVariance") -> None:
```
Why do we need merge?
```python
if self.n_samples == 0:
    self.mean_x = mean_x_b
    self.mean_y = mean_y_b
    self.sum_sq_dev_x = m2_x_b
    self.sum_sq_dev_y = m2_y_b
    self.sum_product_of_devs = cxy_b
    self.n_samples = n_b
    return

n_a = self.n_samples
n_ab = n_a + n_b
cross = n_a * n_b / n_ab
delta_x = mean_x_b - self.mean_x
delta_y = mean_y_b - self.mean_y

self.mean_x = self.mean_x + delta_x * n_b / n_ab
self.mean_y = self.mean_y + delta_y * n_b / n_ab
self.sum_sq_dev_x = self.sum_sq_dev_x + m2_x_b + delta_x * delta_x * cross
self.sum_sq_dev_y = self.sum_sq_dev_y + m2_y_b + delta_y * delta_y * cross
self.sum_product_of_devs = self.sum_product_of_devs + cxy_b + delta_x * delta_y * cross
```
Three readability changes responding to @aaishwarymishra's inline review:

1. _zero() helper removed; each tensor field uses field(default_factory=lambda: torch.tensor(0.0)) directly.
2. update() is now the degenerate case of merge() where "other" is a freshly-built single-batch accumulator. The Chan / Welford parallel formula lives in exactly one place. Same refactor applied to WelfordCovariance.update.
3. merge() docstring grew a paragraph explaining the distributed-reduction motivation: without an explicit merge, cross-rank reduction has to re-iterate the raw data, which defeats the point of an online algorithm.

Behavior is bit-equivalent: the only delta is an extra detach/clone on the first-batch path (via merge's first-time-absorb branch), which is a no-op for correctness.
@aaishwarymishra thanks, your inline points are all fair. Pushed e153f57 addressing each.

1. Removed the _zero() helper; the tensor fields use field(default_factory=lambda: torch.tensor(0.0)) directly.
2. update() is now merge() with a freshly-built single-batch accumulator, so the parallel-combine formula lives in exactly one place. The mental model that fell out of this is worth stating: an update is just a merge with a one-batch "other".
3. "Why do we need merge?" For distributed reduction: without it, cross-rank reduction would have to re-iterate the raw data, which defeats the point of an online algorithm. The merge() docstring now spells this out.

The bonus is that the two-path design (first-time absorb vs. running combine) is now exercised by both update and merge, so the tests cover it from both directions.

cc @vfdev-5: points from your earlier review are still addressed in 1b82442; this commit is a follow-up to @aaishwarymishra's readability pass. Ready for another look when you have a moment.
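The update-as-degenerate-merge shape from point 2 can be sketched like this (a minimal hypothetical class, not the PR's code):

```python
import torch

class Accumulator:
    """Sketch: update() delegates to merge(), so the combine formula exists once."""

    def __init__(self) -> None:
        self.n = 0
        self.mean = torch.tensor(0.0)
        self.m2 = torch.tensor(0.0)  # sum of squared deviations

    def merge(self, other: "Accumulator") -> None:
        if other.n == 0:
            return
        if self.n == 0:
            # First-time absorb: copy the other accumulator's state.
            self.n, self.mean, self.m2 = other.n, other.mean.clone(), other.m2.clone()
            return
        n_ab = self.n + other.n
        delta = other.mean - self.mean
        self.m2 = self.m2 + other.m2 + delta * delta * self.n * other.n / n_ab
        self.mean = self.mean + delta * other.n / n_ab
        self.n = n_ab

    def update(self, batch: torch.Tensor) -> None:
        # update() builds a single-batch accumulator and merges it in.
        if batch.numel() == 0:
            return
        b = Accumulator()
        b.n = batch.numel()
        b.mean = batch.mean()
        b.m2 = ((batch - b.mean) ** 2).sum()
        self.merge(b)
```

Feeding the whole stream in one update or in several produces the same state, which is the property the multi-batch-matches-single-batch test checks.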
Summary
PR 1 of the plan in #3748: introduces `ignite/metrics/_running_stats.py` with two numerically stable running-statistics primitives that variance- and covariance-bearing metrics can share, instead of each one rolling its own naive `Σx² − (Σx)²/n` implementation.

- `WelfordVariance`: running `mean`, `variance`, `std` for a single variable.
- `WelfordCovariance`: running `variance_x`, `variance_y`, `covariance`, and Pearson `correlation()` for a paired `(x, y)` stream.

Both classes:

- keep internal state in `float64` regardless of input dtype, so the classic `E[X²] − E[X]²` cancellation does not bite at large means (the failure mode from #3662, "[Bug] Numerical instability in `PearsonCorrelation` due to naive variance formula");
- expose `merge()` implementing the Chan / Welford parallel formula, suitable for cross-rank distributed reduction or any other case where two accumulators need to be combined without re-iterating the raw data.

No existing metric is touched in this PR. The helper lands first so the two consumer PRs can each diff against a stable shared API.
Follow-ups (already scoped in #3748)

- PR 2 ports `R2Score` to `WelfordVariance`, with a regression test mirroring #3662's `mean=1e6, std=1` failure case.
- PR 3 refactors #3741 to consume `WelfordCovariance` instead. Pure refactor on top of the existing review.

API
Tests
20 unit tests in `tests/ignite/metrics/test_running_stats.py`, all passing locally with `pytest tests/ignite/metrics/test_running_stats.py`:

WelfordVariance

- correctness vs `numpy.mean` and `numpy.var` to 1e-12
- `merge` of two accumulators matches `update` on the concatenated data
- `merge` with an empty accumulator on either side is a no-op or absorbs the other side
- stability regression: `mean=1e6, std=1` in float32; Welford in float64 recovers the true variance, and the same data through the naive `Σx² − (Σx)²/n` formula in float32 is asserted to fail by ≥ 0.1, so the test documents the failure mode it is protecting against
- empty-batch `update` is a no-op
- `reset` clears state

WelfordCovariance

- correctness vs `numpy.cov(..., bias=True)` and `numpy.corrcoef` to 1e-10 to 1e-12
- `merge` matches concatenated `update`
- stability regression: `mean=1e6`, true correlation `> 0.99`; Welford recovers `r` within 1e-4 of the float64 ground truth
- shape mismatch raises `ValueError`
- constant input returns `r = 0.0` (not NaN) via the `eps` clamp
- `reset` clears state

Cross-class sanity
`WelfordCovariance.variance_x` matches `WelfordVariance.variance` fed the same `x`. Catches future drift between the two implementations.

Conventions
- `ruff format` clean, `ruff check` clean.
- state is kept as `torch.Tensor` on the user-supplied device, mirroring the dtype / device convention used elsewhere in `ignite.metrics`.
- the module is `_running_stats.py` (leading underscore) and not re-exported from `ignite/metrics/__init__.py`, so it is internal to the metrics module. Public access stays through individual metric classes; the helper can graduate to a public API later if there is demand.

Test plan
- `pytest tests/ignite/metrics/test_running_stats.py`: 20 / 20 passing locally
- `ruff format --check`, `ruff check`

cc @vfdev-5. Opening PR 2 (R2Score port) and PR 3 (PearsonCorrelation refactor on top of #3741) once this lands.