Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-



- Fixed `GeneralizedDiceScore` to yield `NaN` if there are missing classes ([#2846](https://github.com/Lightning-AI/torchmetrics/issues/2846))

---

## [1.8.2] - 2025-09-03
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class:
if not per_class:
numerator = torch.sum(numerator, 1)
denominator = torch.sum(denominator, 1)
return _safe_divide(numerator, denominator, "nan")
Comment thread
Borda marked this conversation as resolved.
Outdated
return _safe_divide(numerator, denominator)


Expand Down
28 changes: 19 additions & 9 deletions src/torchmetrics/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,16 @@ class GeneralizedDiceScore(Metric):
tensor(0.4992)
>>> gds = GeneralizedDiceScore(num_classes=3, per_class=True)
>>> gds(preds, target)
tensor([0.5001, 0.4993, 0.4982])
tensor([0.5000, 0.4993, 0.4983])
>>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False)
>>> gds(preds, target)
tensor([0.4993, 0.4982])
tensor([0.4993, 0.4983])

"""

score: Tensor
samples: Tensor
class_present: Tensor
numerator: Tensor
denominator: Tensor
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
Expand All @@ -133,20 +134,29 @@ def __init__(
self.input_format = input_format

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state("class_present", default=torch.zeros(num_classes, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("numerator", default=torch.zeros((0, num_classes)), dist_reduce_fx="cat")
self.add_state("denominator", default=torch.zeros((0, num_classes)), dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with new data."""
numerator, denominator = _generalized_dice_update(
preds, target, self.num_classes, self.include_background, self.weight_type, self.input_format
)
self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0)
self.samples += preds.shape[0]
self.numerator = torch.cat([self.numerator, numerator], dim=0)
self.denominator = torch.cat([self.denominator, denominator], dim=0)
if self.per_class:
class_mask = target.sum(dim=(0, *range(2, target.ndim))) > 0
self.class_present += class_mask[1:] if not self.include_background else class_mask
Comment thread
Borda marked this conversation as resolved.
Outdated
self.numerator = torch.sum(self.numerator, dim=0, keepdim=True)
self.denominator = torch.sum(self.denominator, dim=0, keepdim=True)

def compute(self) -> Tensor:
"""Compute the final generalized dice score."""
return self.score / self.samples
score = _generalized_dice_compute(self.numerator, self.denominator, self.per_class)
Comment thread
Borda marked this conversation as resolved.
Outdated
if not self.per_class:
return score.mean()
return torch.where(self.class_present > 0, score, torch.tensor(float("nan"))).squeeze()

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand Down
51 changes: 51 additions & 0 deletions tests/unittests/segmentation/test_generalized_dice_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def _reference_generalized_dice(
class TestGeneralizedDiceScore(MetricTester):
"""Test class for `GeneralizedDiceScore` metric."""

atol = 2e-2
Comment thread
Borda marked this conversation as resolved.
Outdated

@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp):
"""Test class implementation of metric."""
Expand Down Expand Up @@ -122,3 +124,52 @@ def test_generalized_dice_functional(self, preds, target, input_format, include_
"input_format": input_format,
},
)


@pytest.mark.parametrize("per_class", [True, False])
@pytest.mark.parametrize("include_background", [True, False])
def test_samples_with_missing_classes(per_class, include_background):
    """Check ``GeneralizedDiceScore`` when some samples and classes contain no positive pixels.

    Only two of the four samples carry a (perfectly predicted) pixel: sample 0 for class 0 and
    sample 2 for class 1. With ``per_class=True`` a class that has positives in the evaluated
    target channels must score exactly ``1.0`` and every other class must be ``NaN``. With
    ``per_class=False`` the aggregated score is expected to be ``NaN``.
    """
    shape = (4, NUM_CLASSES, 128, 128)
    target = torch.zeros(shape, dtype=torch.int8)
    preds = torch.zeros(shape, dtype=torch.int8)

    # (sample index, class index) pairs that receive one matching pixel in preds and target.
    for sample_idx, class_idx in ((0, 0), (2, 1)):
        target[sample_idx, class_idx, 0, 0] = 1
        preds[sample_idx, class_idx, 0, 0] = 1

    metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background)
    score = metric(preds, target)

    if not per_class:
        assert score.isnan()
        return

    # Without the background, channel 0 is dropped and the output has one entry fewer.
    first_channel = 0 if include_background else 1
    expected_len = NUM_CLASSES - first_channel
    assert len(score) == expected_len

    for c in range(expected_len):
        if target[:, first_channel + c].sum() > 0:
            assert score[c] == 1.0
        else:
            assert torch.isnan(score[c])


@pytest.mark.parametrize("per_class", [True, False])
@pytest.mark.parametrize("include_background", [True, False])
def test_generalized_dice_zero_denominator(per_class, include_background):
    """Check that ``GeneralizedDiceScore`` is ``NaN`` when neither preds nor target contain any class.

    With ``per_class=True`` the result must have one ``NaN`` entry per evaluated class
    (``NUM_CLASSES``, or ``NUM_CLASSES - 1`` when the background is excluded). Otherwise the
    result must be a single ``NaN`` value.
    """
    shape = (4, NUM_CLASSES, 128, 128)
    target = torch.zeros(shape, dtype=torch.int8)
    preds = torch.zeros(shape, dtype=torch.int8)

    metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background)
    score = metric(preds, target)

    if per_class:
        expected_len = NUM_CLASSES if include_background else NUM_CLASSES - 1
        assert len(score) == expected_len
        assert all(value.isnan() for value in score)
    else:
        # Expect scalar NaN
        assert score.isnan()
Loading