
Bugfix for GeneralizedDiceScore to yield NaN for missing classes#3251

Open
VijayVignesh1 wants to merge 20 commits into Lightning-AI:master from VijayVignesh1:bugfix/generalized_dice_ignore_empty_classes

Conversation

Contributor

@VijayVignesh1 VijayVignesh1 commented Sep 4, 2025

What does this PR do?

Fixes GeneralizedDiceScore to yield NaN for missing classes when per_class=True.
Adds test cases to verify the fix.

Fixes #2846

Before submitting
  • Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


📚 Documentation preview 📚: https://torchmetrics--3251.org.readthedocs.build/en/3251/

@VijayVignesh1 VijayVignesh1 marked this pull request as ready for review September 4, 2025 11:20
Borda
Borda previously approved these changes Sep 4, 2025

codecov Bot commented Sep 4, 2025

Codecov Report

❌ Patch coverage is 35.71429% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 32%. Comparing base (e08e009) to head (ef5dade).
⚠️ Report is 1 commit behind head on master.

❗ There is a different number of reports uploaded between BASE (e08e009) and HEAD (ef5dade). Click for more details.

HEAD has 478 fewer uploads than BASE
| Flag | BASE (e08e009) | HEAD (ef5dade) |
|---|---|---|
| torch2.0.1+cpu | 24 | 3 |
| python3.10 | 96 | 12 |
| Windows | 16 | 2 |
| cpu | 136 | 17 |
| macOS | 24 | 3 |
| torch2.0.1 | 16 | 2 |
| python3.12 | 32 | 4 |
| torch2.8.0+cpu | 24 | 3 |
| Linux | 96 | 12 |
| torch2.7.1+cpu | 16 | 2 |
| torch2.8.0 | 8 | 1 |
| torch2.2.2+cpu | 8 | 1 |
| torch2.3.1+cpu | 8 | 1 |
| torch2.4.1+cpu | 8 | 1 |
| torch2.1.2+cpu | 8 | 1 |
| torch2.6.0+cpu | 8 | 1 |
| torch2.5.1+cpu | 8 | 1 |
| python3.9 | 8 | 1 |
| gpu | 1 | 0 |
| unittest | 1 | 0 |
Additional details and impacted files
```
@@           Coverage Diff            @@
##           master   #3251     +/-   ##
========================================
- Coverage      69%     32%    -37%
========================================
  Files         364     349     -15
  Lines       20096   19905    -191
========================================
- Hits        13790    6377   -7413
- Misses       6306   13528   +7222
```

@mergify mergify Bot added the ready label Sep 4, 2025
@Borda Borda changed the title Bugfix for GeneralizedDiceScore to yield NaN for missing classes Sep 5, 2025
@Borda Borda enabled auto-merge (squash) September 5, 2025 07:02
Comment thread src/torchmetrics/segmentation/generalized_dice.py Outdated
Comment thread src/torchmetrics/segmentation/generalized_dice.py Outdated
@Borda Borda dismissed their stale review September 6, 2025 11:27

Let's revisit the range

@mergify mergify Bot removed the ready label Sep 6, 2025
auto-merge was automatically disabled September 8, 2025 18:49

Head branch was pushed to by a user without write access

@VijayVignesh1
Contributor Author

Done. Aggregating the numerator and denominator in the update function and finally dividing them in the compute function brings the range back to [0,1].
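For intuition, a pure-Python sketch (with made-up batch values, not the actual metric code) of why aggregating the numerator and denominator before dividing keeps the score in [0, 1], while summing per-batch ratios does not:

```python
# Hypothetical per-batch (numerator, denominator) pairs; 0 <= num <= den
# holds for each batch by construction of the generalized Dice terms.
batches = [(3.0, 4.0), (1.0, 2.0)]

# Summing per-batch ratios can leave the [0, 1] range as batches accumulate:
ratio_sum = sum(n / d for n, d in batches)  # 0.75 + 0.5 = 1.25

# Aggregating numerators and denominators first, then dividing once,
# keeps the result bounded: sum(num) / sum(den) <= 1 whenever num <= den.
num_total = sum(n for n, _ in batches)
den_total = sum(d for _, d in batches)
aggregated = num_total / den_total  # 4.0 / 6.0 ≈ 0.667
```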

@mergify mergify Bot added the ready label Sep 12, 2025
Comment thread tests/unittests/segmentation/test_generalized_dice_score.py Outdated
@Borda Borda requested a review from Copilot September 19, 2025 19:49

Copilot AI left a comment


Pull Request Overview

This PR fixes a bug in GeneralizedDiceScore where missing classes were not properly handled when per_class=True, causing incorrect behavior instead of returning NaN for those classes.

  • Refactored the metric's state tracking to properly identify missing classes and return NaN for them
  • Updated the compute logic to handle both per-class and aggregate scoring correctly
  • Added comprehensive test coverage for missing class scenarios

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

| File | Description |
|---|---|
| src/torchmetrics/segmentation/generalized_dice.py | Refactored state management and compute logic to track class presence and return NaN for missing classes |
| src/torchmetrics/functional/segmentation/generalized_dice.py | Added NaN handling for non-per-class computation when denominator is zero |
| tests/unittests/segmentation/test_generalized_dice_score.py | Added test cases for missing classes and zero denominator scenarios |
| CHANGELOG.md | Updated changelog to document the bug fix |


Comment thread src/torchmetrics/segmentation/generalized_dice.py Outdated
Comment thread src/torchmetrics/segmentation/generalized_dice.py Outdated
Collaborator

Borda commented Sep 28, 2025

@VijayVignesh1 don't worry, Copilot is not always right, but it may spot some overlooked issues.

@mergify mergify Bot removed the ready label Sep 28, 2025
Comment thread src/torchmetrics/segmentation/generalized_dice.py Outdated
Comment thread src/torchmetrics/functional/segmentation/generalized_dice.py Outdated
@mergify mergify Bot requested a review from a team September 29, 2025 19:02
@mergify mergify Bot added the ready label Sep 30, 2025
Comment thread src/torchmetrics/functional/segmentation/generalized_dice.py Outdated
Comment thread src/torchmetrics/functional/segmentation/generalized_dice.py Outdated
@mergify mergify Bot requested a review from a team September 30, 2025 10:18
@Borda Borda self-requested a review September 30, 2025 10:20
```python
)
self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0)
self.samples += preds.shape[0]
self.numerator.append(numerator)
```
Member


this still changes the memory format if you add to a list rather than adding to a tensor

Contributor Author


Can you please elaborate on this a bit more?

@rittik9 rittik9 requested a review from lantiga as a code owner January 28, 2026 13:47
@RecreationalMath

RecreationalMath commented Apr 20, 2026

Hi, I was investigating #2846 and took a look at why this PR stalled. I'd like to help move it forward.

On the memory-format concern @justusschock raised: I checked this out locally and measured state memory directly. Master holds ~44 bytes of state regardless of how many update() calls you make. This PR's List[Tensor] grows linearly:

| updates | PR state size |
|---|---|
| 1,000 | 1.2 MB |
| 10,000 | 12 MB |
| 62,500 | 76 MB |

That's fine for a single val sweep, but the regression is unbounded in the number of update() calls (vs master's O(1)), plus 125k+ individual tensor objects at the high end, and compute() transiently doubles that footprint via dim_zero_cat. For a metric that logically only needs running sums, that's concerning.
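The linear growth is easy to reproduce arithmetically. A back-of-the-envelope sketch (the batch size and class count below are illustrative assumptions, not the actual benchmark configuration; the real state also carries per-tensor object overhead, which accounts for the extra ~1 MB at 62,500 updates):

```python
def list_state_bytes(updates, batch=50, num_classes=3, itemsize=4):
    # Each update() buffers two float32 tensors of shape [batch, num_classes]
    # (numerator and denominator), so the list state grows linearly in updates.
    return updates * 2 * batch * num_classes * itemsize

for n in (1_000, 10_000, 62_500):
    print(f"{n:>6} updates -> {list_state_bytes(n) / 1e6:.1f} MB")
```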

To preserve master's memory profile: The per-class numerator and denominator are sum-reducible across batches (so they become two fixed-size tensor sums). Master's per_class=False path averages per-sample ratios rather than doing a global sum/sum, so to preserve that semantic we can track two scalars, a running sum of valid per-sample scores and a count of valid samples. All four states are fixed-size, O(C) total:

```python
self.add_state("numerator_sum",   default=torch.zeros(n), dist_reduce_fx="sum")
self.add_state("denominator_sum", default=torch.zeros(n), dist_reduce_fx="sum")
# per_class=False: sample-wise mean, excluding degenerate (all-zero target) samples
self.add_state("score_sum",     default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state("valid_samples", default=torch.zeros(1), dist_reduce_fx="sum")

def update(self, preds, target):
    num, den = _generalized_dice_update(...)  # [N, C]
    self.numerator_sum   += num.sum(dim=0)
    self.denominator_sum += den.sum(dim=0)
    per_sample = _safe_divide(num.sum(dim=1), den.sum(dim=1), "nan")
    self.score_sum     += torch.nansum(per_sample)
    self.valid_samples += (~torch.isnan(per_sample)).sum()

def compute(self):
    if self.per_class:
        return _safe_divide(self.numerator_sum, self.denominator_sum, "nan")
    return _safe_divide(self.score_sum, self.valid_samples, "nan").squeeze()
```

I ran this against the #2846 repro on current master:

```
per_class=True : tensor([1., 1., nan])   # matches the expected output
per_class=False: tensor(1.)              # two valid samples, both 1.0
multi-update   : tensor([1., 1., nan])   # same under split update() calls
all-empty      : nan / [nan, nan, nan]   # degenerate case still signals nan
```

For per_class=True this converges to the same "global GDS" aggregation the current PR produces (sum num / sum den); it just does it incrementally instead of buffering per-sample tensors.

One thing worth flagging in the current PR regardless of which structure ends up merged: the per_class=False path computes score.mean(dim=0) over per-sample ratios after _safe_divide(..., "nan"), so a single empty-target sample mixed with real ones turns the whole result into NaN (mean([1.0, nan, 1.0, nan]) = nan). Existing tests only cover the all-empty case, so this slips through. Either .nanmean() or a nansum/valid-count pattern would avoid it.
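The pitfall is easy to demonstrate in isolation. A minimal pure-Python sketch of the plain-mean failure and the nansum/valid-count alternative (illustrative only, not the metric's actual code path):

```python
import math

def naive_mean(scores):
    # A single NaN (e.g. an empty-target sample) poisons the plain mean.
    return sum(scores) / len(scores)

def nan_safe_mean(scores):
    # nansum / valid-count pattern: NaN entries are excluded entirely.
    valid = [s for s in scores if not math.isnan(s)]
    return sum(valid) / len(valid) if valid else float("nan")

scores = [1.0, float("nan"), 1.0, float("nan")]
print(naive_mean(scores))     # nan
print(nan_safe_mean(scores))  # 1.0
```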

I'm happy to drive this to merge. Two options depending on the PR author's availability:

* If @VijayVignesh1 is still active on this, I'm glad to contribute the changes directly to PR #3251.
* If not, I can open a follow-up PR that builds on #3251 and credits @VijayVignesh1 as co-author, whichever closes out #2846 fastest.

@Borda @justusschock, does this direction work? Happy to start once I have a thumbs up.

@VijayVignesh1
Contributor Author

> (quoting @RecreationalMath's comment above in full)

Sure, we can work on the same PR!



Development

Successfully merging this pull request may close these issues.

GeneralizedDiceScore yields 0 scores when using per_class=True for samples where class is not present

6 participants