feat(metrics): add WelfordVariance and WelfordCovariance helpers (PR 1 of #3748)#3750

Open
joemunene-by wants to merge 3 commits into pytorch:master from joemunene-by:feat/welford-running-stats-helper

Conversation

@joemunene-by

Summary

PR 1 of the plan in #3748: introduces ignite/metrics/_running_stats.py with two numerically stable running-statistics primitives that variance- and covariance-bearing metrics can share, instead of each one rolling its own naive Σx² − (Σx)²/n implementation.

  • WelfordVariance: running mean, variance, std for a single variable.
  • WelfordCovariance: running variance_x, variance_y, covariance, and Pearson correlation() for a paired (x, y) stream.

Both classes:

  • keep internal state in float64 regardless of input dtype, so the classic E[X²] − E[X]² cancellation does not bite at large means (the failure mode from [Bug] Numerical instability in PearsonCorrelation due to naive variance formula #3662);
  • update incrementally via Welford's online algorithm;
  • expose merge() implementing the Chan / Welford parallel formula, suitable for cross-rank distributed reduction or any other case where two accumulators need to be combined without re-iterating the raw data.

No existing metric is touched in this PR. The helper lands first so the two consumer PRs can each diff against a stable shared API.

Follow-ups (already scoped in #3748)

API

class WelfordVariance:
    n_samples: int
    mean: torch.Tensor                       # running mean
    sum_sq_dev_from_mean: torch.Tensor       # M2 = Σ (x_i − mean)^2

    def update(self, batch: torch.Tensor) -> None: ...
    def merge(self, other: WelfordVariance) -> None: ...

    @property
    def variance(self) -> torch.Tensor: ...  # M2 / n
    @property
    def std(self) -> torch.Tensor: ...       # sqrt(variance)


class WelfordCovariance:
    n_samples: int
    mean_x: torch.Tensor
    mean_y: torch.Tensor
    sum_sq_dev_x: torch.Tensor
    sum_sq_dev_y: torch.Tensor
    sum_product_of_devs: torch.Tensor        # Σ (x_i − mean_x)(y_i − mean_y)

    def update(self, batch_x: torch.Tensor, batch_y: torch.Tensor) -> None: ...
    def merge(self, other: WelfordCovariance) -> None: ...

    @property
    def variance_x(self) -> torch.Tensor: ...
    @property
    def variance_y(self) -> torch.Tensor: ...
    @property
    def covariance(self) -> torch.Tensor: ...
    def correlation(self, eps: float = 1e-8) -> torch.Tensor: ...  # Pearson r
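To make the covariance half concrete, here is a self-contained stand-in that mirrors the listed API (the class name `TinyCovariance` and all internals are illustrative, not the PR's implementation; state is held as Python floats for brevity):

```python
import torch


class TinyCovariance:
    """Minimal stand-in mirroring the proposed WelfordCovariance API (sketch)."""

    def __init__(self) -> None:
        self.n_samples = 0
        self.mean_x = self.mean_y = 0.0
        self.sum_sq_dev_x = self.sum_sq_dev_y = 0.0
        self.sum_product_of_devs = 0.0

    def update(self, batch_x: torch.Tensor, batch_y: torch.Tensor) -> None:
        bx = batch_x.detach().double().flatten()
        by = batch_y.detach().double().flatten()
        if bx.numel() != by.numel():
            raise ValueError("batch_x and batch_y must have the same size")
        n_b = bx.numel()
        if n_b == 0:
            return
        mean_x_b, mean_y_b = bx.mean().item(), by.mean().item()
        m2_x_b = ((bx - mean_x_b) ** 2).sum().item()
        m2_y_b = ((by - mean_y_b) ** 2).sum().item()
        cxy_b = ((bx - mean_x_b) * (by - mean_y_b)).sum().item()
        # Chan merge of the batch statistics into the running state;
        # when n_samples == 0 this reduces to plain absorption.
        n_a, n_ab = self.n_samples, self.n_samples + n_b
        cross = n_a * n_b / n_ab
        delta_x = mean_x_b - self.mean_x
        delta_y = mean_y_b - self.mean_y
        self.mean_x += delta_x * n_b / n_ab
        self.mean_y += delta_y * n_b / n_ab
        self.sum_sq_dev_x += m2_x_b + delta_x * delta_x * cross
        self.sum_sq_dev_y += m2_y_b + delta_y * delta_y * cross
        self.sum_product_of_devs += cxy_b + delta_x * delta_y * cross
        self.n_samples = n_ab

    def correlation(self, eps: float = 1e-8) -> float:
        # r = M_xy / sqrt(M2_x * M2_y); the n factors cancel.
        denom = (self.sum_sq_dev_x * self.sum_sq_dev_y) ** 0.5
        return self.sum_product_of_devs / max(denom, eps)
```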

Tests

20 unit tests in tests/ignite/metrics/test_running_stats.py, all passing locally with pytest tests/ignite/metrics/test_running_stats.py:

WelfordVariance

  • empty accumulator returns zero variance / std without crashing
  • single-batch update matches numpy.mean and numpy.var to 1e-12
  • multi-batch update matches single-batch on the same data
  • merge of two accumulators matches update on the concatenated data
  • merge with an empty accumulator on either side is a no-op or absorbs the other side
  • numerical-stability regression: 10,000 samples at mean=1e6, std=1 in float32; Welford in float64 recovers the true variance, and the same data through the naive Σx² − (Σx)²/n formula in float32 is asserted to fail by ≥ 0.1, so the test documents the failure mode it is protecting against
  • single-sample case (variance is 0, not undefined)
  • empty-batch update is a no-op
  • reset clears state
  • integer input dtypes upcast to float64
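The stability regression above can be reproduced in a few lines (a standalone sketch of the failure mode, not the test file itself):

```python
import torch

torch.manual_seed(0)
# 10,000 samples at mean 1e6, std 1, stored in float32.
x = (torch.randn(10_000, dtype=torch.float64) + 1e6).to(torch.float32)

# Naive formula in float32: E[X^2] and E[X]^2 agree to ~12 significant
# digits, but float32 carries only ~7, so the subtraction is garbage.
naive = ((x * x).mean() - x.mean() ** 2).item()

# Welford in float64 over the same float32 samples.
n, mean, m2 = 0, 0.0, 0.0
for v in x.double().tolist():
    n += 1
    delta = v - mean
    mean += delta / n
    m2 += delta * (v - mean)
welford = m2 / n

assert abs(naive - 1.0) > 0.1     # the naive float32 formula fails
assert abs(welford - 1.0) < 0.1   # float64 Welford recovers the truth
```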

WelfordCovariance

  • empty accumulator returns zero everywhere
  • single-batch update matches numpy.cov(..., bias=True) and numpy.corrcoef to tolerances between 1e-12 and 1e-10
  • multi-batch update matches single-batch
  • merge matches concatenated update
  • numerical-stability regression: 10,000 paired samples at mean=1e6, true correlation > 0.99; Welford recovers r within 1e-4 of the float64 ground truth
  • shape mismatch raises ValueError
  • empty-batch update is a no-op
  • constant-y series returns r = 0.0 (not NaN) via the eps clamp
  • reset clears state

Cross-class sanity

  • WelfordCovariance.variance_x matches WelfordVariance.variance fed the same x. Catches future drift between the two implementations.

Conventions

  • ruff format clean, ruff check clean.
  • Internal state, return values, and properties are all torch.Tensor on the user-supplied device, mirroring the dtype / device convention used elsewhere in ignite.metrics.
  • File is named _running_stats.py (leading underscore) and not re-exported from ignite/metrics/__init__.py, so it is internal to the metrics module. Public access stays through individual metric classes; the helper can graduate to a public API later if there is demand.

Test plan

  • pytest tests/ignite/metrics/test_running_stats.py: 20 / 20 passing locally
  • ruff format --check, ruff check
  • CI on this PR

cc @vfdev-5. Opening PR 2 (R2Score port) and PR 3 (PearsonCorrelation refactor on top of #3741) once this lands.

Introduces ignite/metrics/_running_stats.py with two numerically
stable running-statistics primitives that variance- and
covariance-bearing metrics can share, instead of each one rolling
its own naive Σx² − (Σx)²/n implementation.

  WelfordVariance     mean, variance, std for a single variable.
  WelfordCovariance   variance_x, variance_y, covariance, and
                      Pearson correlation for a paired (x, y) stream.

Both classes:
  - keep internal state in float64 regardless of input dtype, so the
    classic E[X²] − E[X]² cancellation does not bite at large means;
  - update incrementally via Welford's online algorithm;
  - expose merge() implementing the Chan / Welford parallel formula,
    suitable for cross-rank distributed reduction or any other case
    where two accumulators need to be combined without re-iterating
    the raw data.

This is PR 1 of the plan in pytorch#3748. Follow-ups:
  PR 2 will port R2Score (pytorch#3662-style regression test attached).
  PR 3 will refactor pytorch#3741 to consume WelfordCovariance instead of
    its current inline Welford state.

Tests (20 total, all passing):
  - per-class correctness vs numpy mean / var / cov / corrcoef
  - multi-batch update matches single-batch update
  - merge matches concatenated update
  - merge with empty accumulators on either side
  - numerical-stability regression (mean=1e6 in float32) for both
    classes, with an assertion that the naive float32 formula
    actually does fail on the same data so the test documents what
    we're protecting against
  - shape-mismatch raises ValueError
  - empty-batch update is a no-op
  - reset clears state
  - input dtypes (int32) upcast to float64 correctly
  - cross-class sanity: WelfordCovariance.variance_x matches
    WelfordVariance.variance fed the same x
Comment thread ignite/metrics/_running_stats.py Outdated
mean: torch.Tensor
sum_sq_dev_from_mean: torch.Tensor

def __init__(self, device: Union[str, torch.device] = "cpu") -> None:
Collaborator


This class should not handle device, neither dtype

import torch


class WelfordVariance:
Collaborator


Let's make it as a dataclass?

Comment thread ignite/metrics/_running_stats.py Outdated
"""
if batch.numel() == 0:
return
batch64 = batch.detach().to(dtype=torch.float64).flatten()
Collaborator


I do not think we should flatten it and for the mean computation we may need to specify the axis (to confirm)

Addressing @vfdev-5's inline review on pytorch#3750:

  - Drop the device and dtype constructor args. The helper now leaves
    placement and precision to the caller; state takes the dtype and
    device of the first batch passed to update(). PearsonCorrelation
    and R2Score already do their own float64 upcast before handing
    inputs to the helper, so this is a no-op for the planned
    consumers.

  - Switch both classes to @DataClass with field(default_factory=...)
    for the tensor fields. Drops the manual __init__ / reset()
    plumbing; "reset" is now reconstruction (m.welford =
    WelfordVariance()), which is the natural fit for how the consumer
    Metric.reset() methods already work.

  - Drop the explicit .flatten() on update inputs. batch.mean() and
    batch.numel() both reduce over the full tensor regardless of
    shape, so behavior for the current scalar-reduction consumers is
    unchanged, and the code reads more naturally for any shape.

Tests adjusted accordingly:
  - test_reset replaced by test_fresh_instance_has_zero_state, which
    documents the dataclass default-factory behavior.
  - test_input_dtype_upcast_to_float64 replaced by
    test_state_dtype_follows_first_batch, which verifies dtype is
    preserved (the design change).
  - Stability tests upcast inputs caller-side before handing to the
    helper, matching how the metric classes will use it.
  - test_multi_batch_matches_single_batch switched to float64 inputs
    so it exercises the algorithm rather than float32 noise.

All 20 tests still pass, ruff format / check clean. The question
about axis-aware reduction is deferred to the review thread; I'll
follow it once @vfdev-5 confirms whether it lands here or as a
follow-up.
@joemunene-by
Author

Thanks for the review. Pushed 1b82442 addressing all three points.

1. No device / dtype in the helper. Dropped the constructor args and all internal .to(device) / .to(torch.float64) calls. State now takes the dtype and device of the first batch passed to update. Callers (R2Score, PearsonCorrelation) are responsible for the float64 upcast, which they were already doing in their own update methods anyway, so this is a no-op for the planned consumers and a cleaner contract for any future caller.

2. Dataclasses. Both classes are now @dataclass with field(default_factory=...) for the tensor fields. Dropped the manual __init__ and the reset() method; "reset" becomes reconstruction (self.welford = WelfordVariance()), which fits cleanly into how Metric.reset() is normally written in the consumer classes. Tests updated accordingly: test_reset is now test_fresh_instance_has_zero_state and just asserts the default-factory state.
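The pattern in point 2 looks roughly like this (a sketch; the `Sketch` suffix marks it as illustrative, not the PR's code):

```python
from dataclasses import dataclass, field

import torch


@dataclass
class WelfordVarianceSketch:
    n_samples: int = 0
    # default_factory gives every instance its own zero tensor, so two
    # accumulators never alias the same default.
    mean: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0))
    sum_sq_dev_from_mean: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0))


# "Reset" is reconstruction rather than mutation:
acc = WelfordVarianceSketch()
acc = WelfordVarianceSketch()  # fresh state, no reset() method needed
```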

3. Flatten. Removed the explicit .flatten() call. batch.mean() and batch.numel() already reduce over the full tensor regardless of shape, so the current scalar-reduction behavior is preserved and the code reads cleaner for any input shape.

On axis-aware reduction: the current consumers (R2Score, PearsonCorrelation) both produce a single scalar variance / covariance, so they only need the full-reduction behavior. If we want axis support for future per-channel use cases (running variance over (N, C, H, W) with reduction over (0, 2, 3)), the natural shape of the change is an optional dim parameter on update:

def update(self, batch: torch.Tensor, dim: Optional[Union[int, tuple[int, ...]]] = None) -> None:
    ...
    n_b = batch.numel() if dim is None else _samples_along(batch.shape, dim)
    mean_b = batch.mean(dim=dim, keepdim=False) if dim is not None else batch.mean()
    ...

The state on first update would take the shape of mean_b, and n_samples would become an int count of samples-per-position (still scalar because the reduction is the same at every position).

Two ways we can take it:

  • (a) Add the dim parameter in this PR so the helper ships fully-featured, even though no consumer uses it yet.
  • (b) Leave it scalar for now; add dim as a follow-up the first time a metric actually needs it (avoids speculative API surface).

I lean toward (b): the consumers don't need it, axis-aware running stats have a few subtleties (Bessel correction conventions, mixed-shape merge, axis reordering) that are easier to design alongside a concrete user. Happy to go either way; let me know which you prefer.

@aaishwarymishra
Collaborator

@joemunene-by hi, I was trying to review, but I find the code quite confusing. I think we can try to make it more readable :)

Comment thread ignite/metrics/_running_stats.py Outdated
Comment on lines +34 to +35
def _zero() -> torch.Tensor:
return torch.tensor(0.0)
Collaborator


Why are we writing a function for code that is literally one line?

Comment thread ignite/metrics/_running_stats.py Outdated
Comment on lines +81 to +92
if self.n_samples == 0:
self.mean = mean_b
self.sum_sq_dev_from_mean = m2_b
self.n_samples = n_b
return

n_a = self.n_samples
n_ab = n_a + n_b
delta = mean_b - self.mean
self.mean = self.mean + delta * n_b / n_ab
self.sum_sq_dev_from_mean = self.sum_sq_dev_from_mean + m2_b + delta * delta * n_a * n_b / n_ab
self.n_samples = n_ab
Collaborator


This part is redundant, as merge is already doing it; we can just use that.

self.sum_sq_dev_from_mean = self.sum_sq_dev_from_mean + m2_b + delta * delta * n_a * n_b / n_ab
self.n_samples = n_ab

def merge(self, other: "WelfordVariance") -> None:
Collaborator


Why do we need merge?

Comment thread ignite/metrics/_running_stats.py Outdated
Comment on lines +173 to +192
if self.n_samples == 0:
self.mean_x = mean_x_b
self.mean_y = mean_y_b
self.sum_sq_dev_x = m2_x_b
self.sum_sq_dev_y = m2_y_b
self.sum_product_of_devs = cxy_b
self.n_samples = n_b
return

n_a = self.n_samples
n_ab = n_a + n_b
cross = n_a * n_b / n_ab
delta_x = mean_x_b - self.mean_x
delta_y = mean_y_b - self.mean_y

self.mean_x = self.mean_x + delta_x * n_b / n_ab
self.mean_y = self.mean_y + delta_y * n_b / n_ab
self.sum_sq_dev_x = self.sum_sq_dev_x + m2_x_b + delta_x * delta_x * cross
self.sum_sq_dev_y = self.sum_sq_dev_y + m2_y_b + delta_y * delta_y * cross
self.sum_product_of_devs = self.sum_product_of_devs + cxy_b + delta_x * delta_y * cross
Collaborator


same here

Three readability changes responding to @aaishwarymishra's inline review:

1. _zero() helper removed; each tensor field uses
   field(default_factory=lambda: torch.tensor(0.0)) directly.

2. update() is now the degenerate case of merge() where "other" is a
   freshly-built single-batch accumulator. The Chan / Welford parallel
   formula lives in exactly one place. Same refactor applied to
   WelfordCovariance.update.

3. merge() docstring grew a paragraph explaining the distributed-
   reduction motivation -- without an explicit merge, cross-rank
   reduction has to re-iterate the raw data, which defeats the point
   of an online algorithm.

Behavior is bit-equivalent: the only delta is an extra detach/clone
on the first-batch path (via merge's first-time-absorb branch), which
is a no-op for correctness.
@joemunene-by
Author

@aaishwarymishra thanks, your inline points are all fair. Pushed e153f57 addressing each.

1. _zero() for a one-liner. Removed. Each tensor field now uses field(default_factory=lambda: torch.tensor(0.0)) directly. Six lines of one-purpose helper become six explicit defaults.

2. update's body duplicates merge. Refactored. update now builds a single-batch WelfordVariance from the input and calls self.merge(batch_acc). The Chan / Welford parallel formula now lives in exactly one place. Same refactor applied to WelfordCovariance.update (your "same here" on line 192).

The mental model that fell out of this is worth stating: update is the degenerate case of merge where the right-hand side has just been built from one batch. There's one formula, not two. I lifted that into the module docstring and the method docstrings so the next reader sees it immediately.

3. "Why do we need merge?" Distributed reduction. In multi-rank training, each rank accumulates its own WelfordVariance / WelfordCovariance over its local samples; at eval time the ranks fold their accumulators together rank-by-rank to produce the population statistic. Without an explicit merge, that cross-rank reduction has to re-iterate the raw data, which defeats the whole point of an online algorithm. I made this rationale explicit in merge's docstring (previously just one line, now a paragraph).

The bonus is that the two-path design (update for the local-batch case, merge for the cross-state case) collapses to a single formula at the implementation level, which I think also lands the "more readable" feedback from your top-level comment. Let me know if anything's still murky.

cc @vfdev-5 — points from your earlier review are still addressed in 1b82442; this commit is a follow-up to @aaishwarymishra's readability pass. Ready for another look when you have a moment.

Labels

module: metrics Metrics module
