Skip to content
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-



- Fixed `GeneralizedDiceScore` to yield `NaN` if there are missing classes ([#2846](https://github.com/Lightning-AI/torchmetrics/issues/2846))

---

## [1.8.2] - 2025-09-03
Expand Down
15 changes: 6 additions & 9 deletions src/torchmetrics/functional/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class:
if not per_class:
numerator = torch.sum(numerator, 1)
denominator = torch.sum(denominator, 1)
return _safe_divide(numerator, denominator)
else:
numerator = torch.sum(numerator, 0, keepdim=True)
denominator = torch.sum(denominator, 0, keepdim=True)
return _safe_divide(numerator, denominator, "nan")


def generalized_dice_score(
Expand Down Expand Up @@ -126,10 +129,7 @@ def generalized_dice_score(
>>> generalized_dice_score(preds, target, num_classes=5)
tensor([0.4830, 0.4935, 0.5044, 0.4880])
>>> generalized_dice_score(preds, target, num_classes=5, per_class=True)
tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500],
[0.4571, 0.4980, 0.5191, 0.4380, 0.5649],
[0.5428, 0.4904, 0.5358, 0.4830, 0.4724],
[0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])
tensor([[0.4845, 0.4997, 0.4993, 0.4864, 0.4912]])
Comment thread
Borda marked this conversation as resolved.
Outdated

Example (with index tensors):
>>> from torch import randint
Expand All @@ -139,10 +139,7 @@ def generalized_dice_score(
>>> generalized_dice_score(preds, target, num_classes=5, input_format="index")
tensor([0.1991, 0.1971, 0.2350, 0.2216])
>>> generalized_dice_score(preds, target, num_classes=5, per_class=True, input_format="index")
tensor([[0.1714, 0.2500, 0.1304, 0.2524, 0.2069],
[0.1837, 0.2162, 0.0962, 0.2692, 0.1895],
[0.3866, 0.1348, 0.2526, 0.2301, 0.2083],
[0.1978, 0.2804, 0.1714, 0.1915, 0.2783]])
tensor([[0.2234, 0.2170, 0.1597, 0.2399, 0.2204]])
Comment thread
Borda marked this conversation as resolved.
Outdated

"""
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format)
Expand Down
24 changes: 13 additions & 11 deletions src/torchmetrics/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any, Optional, Union
from typing import Any, List, Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

Expand All @@ -24,6 +23,7 @@
_generalized_dice_validate_args,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

Expand Down Expand Up @@ -100,15 +100,16 @@ class GeneralizedDiceScore(Metric):
tensor(0.4992)
>>> gds = GeneralizedDiceScore(num_classes=3, per_class=True)
>>> gds(preds, target)
tensor([0.5001, 0.4993, 0.4982])
tensor([0.5000, 0.4993, 0.4983])
>>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False)
>>> gds(preds, target)
tensor([0.4993, 0.4982])
tensor([0.4993, 0.4983])

"""

score: Tensor
samples: Tensor
class_present: Tensor
numerator: List[Tensor]
denominator: List[Tensor]
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
Expand All @@ -133,20 +134,21 @@ def __init__(
self.input_format = input_format

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state("numerator", default=[], dist_reduce_fx="cat")
self.add_state("denominator", default=[], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with new data."""
numerator, denominator = _generalized_dice_update(
preds, target, self.num_classes, self.include_background, self.weight_type, self.input_format
)
self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0)
self.samples += preds.shape[0]
self.numerator.append(numerator)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this still changes the memory format if you add to a list rather than adding to a tensor

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please elaborate on this a bit more?

self.denominator.append(denominator)

def compute(self) -> Tensor:
"""Compute the final generalized dice score."""
return self.score / self.samples
score = _generalized_dice_compute(dim_zero_cat(self.numerator), dim_zero_cat(self.denominator), self.per_class)
return score.mean(dim=0)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand Down
51 changes: 51 additions & 0 deletions tests/unittests/segmentation/test_generalized_dice_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def _reference_generalized_dice(
class TestGeneralizedDiceScore(MetricTester):
"""Test class for `GeneralizedDiceScore` metric."""

atol = 2e-3

@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp):
"""Test class implementation of metric."""
Expand Down Expand Up @@ -122,3 +124,52 @@ def test_generalized_dice_functional(self, preds, target, input_format, include_
"input_format": input_format,
},
)


@pytest.mark.parametrize("per_class", [True, False])
@pytest.mark.parametrize("include_background", [True, False])
def test_samples_with_missing_classes(per_class, include_background):
    """Check the score when most classes never occur in the batch.

    The batch is all zeros except for one matching pixel of class 0 in sample 0 and one matching pixel of class 1
    in sample 2. Per-class output must be exactly ``1.0`` for every retained class that has at least one target
    pixel and ``NaN`` for every other class; the aggregated (non per-class) output must be ``NaN``.
    """
    shape = (4, NUM_CLASSES, 128, 128)
    target = torch.zeros(shape, dtype=torch.int8)
    preds = torch.zeros(shape, dtype=torch.int8)

    # (sample, class) pairs that receive a single, perfectly predicted pixel
    for sample_idx, class_idx in ((0, 0), (2, 1)):
        target[sample_idx, class_idx, 0, 0] = 1
        preds[sample_idx, class_idx, 0, 0] = 1

    metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background)
    score = metric(preds, target)

    if not per_class:
        assert score.isnan()
        return

    # Dropping the background removes channel 0 from both the target and the expected output length.
    if include_background:
        expected_classes, kept_target = NUM_CLASSES, target
    else:
        expected_classes, kept_target = NUM_CLASSES - 1, target[:, 1:]

    assert len(score) == expected_classes
    for class_idx in range(expected_classes):
        if kept_target[:, class_idx].sum() > 0:
            assert score[class_idx] == 1.0
        else:
            assert torch.isnan(score[class_idx])


@pytest.mark.parametrize("per_class", [True, False])
@pytest.mark.parametrize("include_background", [True, False])
def test_generalized_dice_zero_denominator(per_class, include_background):
    """Verify that all-zero predictions and targets produce ``NaN`` scores.

    With ``per_class=True`` the result must have one ``NaN`` entry per retained class (``NUM_CLASSES``, or one fewer
    when the background is excluded); otherwise the single aggregated value must be ``NaN``.
    """
    shape = (4, NUM_CLASSES, 128, 128)
    target = torch.zeros(shape, dtype=torch.int8)
    preds = torch.zeros(shape, dtype=torch.int8)

    metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background)
    score = metric(preds, target)

    if not per_class:
        # Expect scalar NaN
        assert score.isnan()
        return

    expected_len = NUM_CLASSES if include_background else NUM_CLASSES - 1
    assert len(score) == expected_len
    assert all(entry.isnan() for entry in score)
Loading