Bugfix for GeneralizedDiceScore to yield NaN for missing classes #3251

VijayVignesh1 wants to merge 20 commits.

Conversation
Codecov Report

❌ Patch coverage is …

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #3251      +/-  ##
=========================================
- Coverage      69%      32%     -37%
=========================================
  Files         364      349      -15
  Lines       20096    19905     -191
=========================================
- Hits        13790     6377    -7413
- Misses       6306    13528    +7222
=========================================
```
Head branch was pushed to by a user without write access.
Done. Aggregating the numerator and denominator in the `update` function and finally dividing them in the `compute` function brings the range back to [0, 1].
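To illustrate the point (with hypothetical numbers, not the PR's actual tensors): summing per-batch scores can exceed 1, while aggregating numerator and denominator across batches and dividing once in `compute` keeps each class score in [0, 1].

```python
import torch

# Hypothetical per-batch numerators/denominators for two classes.
num = torch.tensor([[2.0, 1.0], [4.0, 1.0]])   # [num_batches, num_classes]
den = torch.tensor([[4.0, 2.0], [4.0, 2.0]])

# Summing per-batch scores (old behavior) has range [0, num_batches]:
summed_scores = (num / den).sum(dim=0)          # tensor([1.5000, 1.0000])

# Aggregating numerator/denominator and dividing once stays in [0, 1]:
score = num.sum(dim=0) / den.sum(dim=0)         # tensor([0.7500, 0.5000])
```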
Pull Request Overview
This PR fixes a bug in GeneralizedDiceScore where missing classes were not properly handled when per_class=True, causing incorrect behavior instead of returning NaN for those classes.
- Refactored the metric's state tracking to properly identify missing classes and return `NaN` for them
- Updated the compute logic to handle both per-class and aggregate scoring correctly
- Added comprehensive test coverage for missing class scenarios
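The "return NaN for missing classes" behavior described above can be sketched with plain tensors (hypothetical values; this is not the PR's actual code): a class that never appears in either `preds` or `target` ends up with a 0/0 score, which should surface as `NaN` rather than 0.

```python
import torch

# Hypothetical per-class sums after all updates; class 2 never appears
# in either preds or target, so its numerator and denominator are both zero.
numerator = torch.tensor([6.0, 2.0, 0.0])
denominator = torch.tensor([8.0, 4.0, 0.0])

# Desired behavior: a 0/0 class yields NaN instead of 0 or an error.
score = torch.where(
    denominator > 0,
    numerator / denominator,
    torch.full_like(denominator, float("nan")),
)
# tensor([0.7500, 0.5000, nan])
```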
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/torchmetrics/segmentation/generalized_dice.py | Refactored state management and compute logic to track class presence and return NaN for missing classes |
| src/torchmetrics/functional/segmentation/generalized_dice.py | Added NaN handling for non-per-class computation when denominator is zero |
| tests/unittests/segmentation/test_generalized_dice_score.py | Added test cases for missing classes and zero denominator scenarios |
| CHANGELOG.md | Updated changelog to document the bug fix |
@VijayVignesh1 don't worry, Copilot is not always right, but it may spot some overlooked issues.
```python
)
self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0)
self.samples += preds.shape[0]
self.numerator.append(numerator)
```
This still changes the memory format: the state grows without bound if you append to a list rather than adding into a fixed-size tensor.
Can you please elaborate on this a bit more?
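A minimal sketch of the concern (hypothetical shapes, not the PR's code): appending per-update tensors to a list retains every batch's statistics, so state memory grows linearly with the number of `update()` calls, while in-place accumulation into a fixed-size tensor keeps memory constant.

```python
import torch

num_classes = 3
list_state = []                        # grows by one tensor per update (list-style state)
sum_state = torch.zeros(num_classes)   # fixed-size accumulator (tensor-style state)

for _ in range(100):                   # simulate 100 update() calls
    batch_numerator = torch.rand(8, num_classes)  # hypothetical [N, C] batch stats
    list_state.append(batch_numerator)            # retains all 100 tensors
    sum_state += batch_numerator.sum(dim=0)       # constant memory, same running total
```

Both states hold the same total, but only the list version keeps every intermediate tensor alive.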
Hi, I was investigating #2846 and took a look at why this PR stalled. I'd like to help move it forward.

On the memory-format concern @justusschock raised: I checked this out locally and measured state memory directly. Master holds ~44 bytes of state regardless of how many `update()` calls are made.

That's fine for a single val sweep, but the regression is unbounded in the number of `update()` calls.

To preserve master's memory profile: the per-class numerator and denominator are sum-reducible across batches (so they become two fixed-size tensor sums). Master's state layout would become:

```python
self.add_state("numerator_sum", default=torch.zeros(n), dist_reduce_fx="sum")
self.add_state("denominator_sum", default=torch.zeros(n), dist_reduce_fx="sum")
# per_class=False: sample-wise mean, excluding degenerate (all-zero target) samples
self.add_state("score_sum", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state("valid_samples", default=torch.zeros(1), dist_reduce_fx="sum")

def update(self, preds, target):
    num, den = _generalized_dice_update(...)  # [N, C]
    self.numerator_sum += num.sum(dim=0)
    self.denominator_sum += den.sum(dim=0)
    per_sample = _safe_divide(num.sum(dim=1), den.sum(dim=1), "nan")
    self.score_sum += torch.nansum(per_sample)
    self.valid_samples += (~torch.isnan(per_sample)).sum()

def compute(self):
    if self.per_class:
        return _safe_divide(self.numerator_sum, self.denominator_sum, "nan")
    return _safe_divide(self.score_sum, self.valid_samples, "nan").squeeze()
```

I ran this against the #2846 repro on current master.

One thing worth flagging in the current PR, regardless of which structure ends up merged: …

I'm happy to drive this to merge. Two options, depending on the PR author's availability: …
@Borda @justusschock, does this direction work? Happy to start once I have a thumbs-up.
Sure, we can work on the same PR!
What does this PR do?
Fixes GeneralizedDiceScore to yield NaN for missing classes when per_class=True.
Added test cases to verify the behavior.
Fixes #2846
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃
📚 Documentation preview 📚: https://torchmetrics--3251.org.readthedocs.build/en/3251/