Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 107 additions & 1 deletion monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@

from __future__ import annotations

import numpy as np
import torch

from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction, deprecated_arg
from monai.utils.module import optional_import

from .metric import CumulativeIterationMetric

distance_transform_edt, has_ndimage = optional_import("scipy.ndimage", name="distance_transform_edt")
generate_binary_structure, _ = optional_import("scipy.ndimage", name="generate_binary_structure")
sn_label, _ = optional_import("scipy.ndimage", name="label")

__all__ = ["DiceMetric", "compute_dice", "DiceHelper"]


Expand Down Expand Up @@ -95,6 +101,9 @@ class DiceMetric(CumulativeIterationMetric):
If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
the index begins at "0", otherwise at "1". It can also take a list of label names.
The outcome will then be returned as a dictionary.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

"""

Expand All @@ -106,6 +115,7 @@ def __init__(
ignore_empty: bool = True,
num_classes: int | None = None,
return_with_label: bool | list[str] = False,
per_component: bool = False,
) -> None:
super().__init__()
self.include_background = include_background
Expand All @@ -114,13 +124,15 @@ def __init__(
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.return_with_label = return_with_label
self.per_component = per_component
self.dice_helper = DiceHelper(
include_background=self.include_background,
reduction=MetricReduction.NONE,
get_not_nans=False,
apply_argmax=False,
ignore_empty=self.ignore_empty,
num_classes=self.num_classes,
per_component=self.per_component,
)

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
Expand Down Expand Up @@ -175,6 +187,7 @@ def compute_dice(
include_background: bool = True,
ignore_empty: bool = True,
num_classes: int | None = None,
per_component: bool = False,
) -> torch.Tensor:
"""
Computes Dice score metric for a batch of predictions. This performs the same computation as
Expand All @@ -192,6 +205,9 @@ def compute_dice(
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

Returns:
Dice scores per batch and per class, (shape: [batch_size, num_classes]).
Expand All @@ -204,6 +220,7 @@ def compute_dice(
apply_argmax=False,
ignore_empty=ignore_empty,
num_classes=num_classes,
per_component=per_component,
)(y_pred=y_pred, y=y)


Expand Down Expand Up @@ -246,6 +263,9 @@ class DiceHelper:
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
"""

@deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
Expand All @@ -262,6 +282,7 @@ def __init__(
num_classes: int | None = None,
sigmoid: bool | None = None,
softmax: bool | None = None,
per_component: bool = False,
) -> None:
# handling deprecated arguments
if sigmoid is not None:
Expand All @@ -277,6 +298,81 @@ def __init__(
self.activate = activate
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.per_component = per_component

def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
    """
    Voronoi assignment of every voxel to its nearest connected component (CPU, single EDT).

    The foreground (``labels > 0``) is split into connected components; a single Euclidean
    distance transform then yields, for each voxel, the index of its closest foreground voxel,
    whose component ID becomes the voxel's Voronoi region ID.

    Args:
        labels (np.ndarray | torch.Tensor): label map where values > 0 are seeds. Torch tensors
            (including CUDA tensors) are moved to host memory automatically.
        connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26; any other value
            falls back to 26.
        sampling (tuple[float, ...] | None): voxel spacing for anisotropic distances.

    Returns:
        torch.Tensor: Voronoi region IDs (int32) on CPU; all zeros when there is no foreground.

    Raises:
        RuntimeError: when ``scipy.ndimage`` is not available.
    """
    if not has_ndimage:
        raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")
    # scipy works on host numpy arrays; `np.asarray` would raise on a CUDA tensor,
    # so detach/move torch inputs explicitly.
    if isinstance(labels, torch.Tensor):
        x = labels.detach().cpu().numpy()
    else:
        x = np.asarray(labels)
    conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
    structure = generate_binary_structure(rank=3, connectivity=conn_rank)
    cc, num = sn_label(x > 0, structure=structure)
    if num == 0:
        # no seeds: every voxel belongs to the (empty) background region 0
        return torch.zeros(x.shape, dtype=torch.int32)
    # EDT over the background returns, per voxel, the index of the nearest seed voxel
    edt_input = np.ones(cc.shape, dtype=np.uint8)
    edt_input[cc > 0] = 0
    indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
    # look up the component ID at the nearest seed; cast so the documented int32 dtype holds
    voronoi = cc[tuple(indices)].astype(np.int32, copy=False)
    return torch.from_numpy(voronoi)

def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute per-connected-component Dice for a single batch item.

    The ground-truth foreground is split into connected components; every voxel of the volume
    is assigned to its nearest component (a Voronoi partition), Dice is evaluated inside each
    component's region, and the mean over components is returned.

    Args:
        y_pred: predictions with shape (1, 2, D, H, W) (background + foreground channels).
        y: ground truth with shape (1, 2, D, H, W).

    Returns:
        torch.Tensor: shape-(1,) tensor with the mean Dice over connected components.
        When the ground truth is empty the value follows the same convention as
        ``compute_channel``: ``nan`` if ``ignore_empty`` (so reductions can exclude it),
        otherwise 1.0 when the prediction is also empty and 0.0 when it is not.
    """
    if y_pred.ndim == y.ndim:
        # one-hot (background/foreground) inputs: collapse to binary index maps
        y_pred_idx = torch.argmax(y_pred, dim=1)
        y_idx = torch.argmax(y, dim=1)
    else:
        y_pred_idx = y_pred
        y_idx = y
    if y_idx[0].sum() == 0:
        # empty ground truth: do NOT sanitize the nan away — `do_metric_reduction`
        # relies on nan to drop empty items when ignore_empty is set.
        if self.ignore_empty:
            return torch.tensor([float("nan")], device=y_idx.device)
        if y_pred_idx.sum() == 0:
            return torch.tensor([1.0], device=y_idx.device)
        return torch.tensor([0.0], device=y_idx.device)
    # Voronoi labelling runs on CPU (scipy); bring the result back to the input device
    # so the bitwise combination below does not mix devices.
    cc_assignment = self.compute_voronoi_regions_fast(y_idx[0].detach().cpu()).to(y_idx.device)
    uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
    nof_components = uniq.numel()
    # 2-bit confusion code per voxel: (gt << 1) | pred -> 0:tn, 1:fp, 2:fn, 3:tp
    code = (y_idx.view(-1).long() << 1) | y_pred_idx.view(-1).long()
    # 4 histogram slots per component: component index in the high bits, code in the low bits
    idx = (inv << 2) | code
    hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
    fp, fn, tp = hist[:, 1], hist[:, 2], hist[:, 3]
    denom = 2 * tp + fp + fn
    # denom > 0 always holds inside a component's region (tp + fn >= component size),
    # but guard anyway so the division can never produce nan/inf.
    dice_scores = torch.where(denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device))
    return dice_scores.mean().reshape(1)

def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -322,15 +418,25 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
y_pred = torch.sigmoid(y_pred)
y_pred = y_pred > 0.5

first_ch = 0 if self.include_background else 1
if self.per_component:
if len(y_pred.shape) != 5 or len(y.shape) != 5 or y_pred.shape[1] != 2 or y.shape[1] != 2:
raise ValueError(
"per_component requires both y_pred and y to be 5D binary segmentations "
f"with 2 channels. Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}."
)

first_ch = 0 if self.include_background and not self.per_component else 1
data = []
for b in range(y_pred.shape[0]):
c_list = []
for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
c_list.append(self.compute_channel(x_pred, x))
if self.per_component:
c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
data.append(torch.stack(c_list))

data = torch.stack(data, dim=0).contiguous() # type: ignore

f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore
Expand Down
35 changes: 35 additions & 0 deletions tests/metrics/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from parameterized import parameterized

from monai.metrics import DiceHelper, DiceMetric, compute_dice
from monai.utils.module import optional_import

_, has_ndimage = optional_import("scipy.ndimage")

_device = "cuda:0" if torch.cuda.is_available() else "cpu"
# keep background
Expand Down Expand Up @@ -250,6 +253,24 @@
{"label_1": 0.4000, "label_2": 0.6667},
]

# Test case for the per-component (connected components) DiceMetric:
# batch item 0 of the ground truth holds two disjoint 5x5x5 foreground cubes;
# the prediction shifts each cube by one voxel, so the per-component score is
# the average Dice of the two shifted cubes. Remaining batch items are empty.
y = torch.zeros((5, 2, 64, 64, 64))
y_hat = torch.zeros((5, 2, 64, 64, 64))

for vol, corners in ((y, [(20, 20, 20), (40, 40, 40)]), (y_hat, [(21, 21, 21), (41, 39, 41)])):
    for d0, d1, d2 in corners:
        vol[0, 1, d0 : d0 + 5, d1 : d1 + 5, d2 : d2 + 5] = 1
    vol[0, 0] = 1 - vol[0, 1]  # background channel is the complement of the foreground

TEST_CASE_16 = [
    {"per_component": True, "ignore_empty": False},
    {"y": y, "y_pred": y_hat},
    [[[0.5120]], [[1.0]], [[1.0]], [[1.0]], [[1.0]]],
]


class TestComputeMeanDice(unittest.TestCase):

Expand Down Expand Up @@ -301,6 +322,20 @@ def test_nans_class(self, params, input_data, expected_value):
else:
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

# per-component (connected components) DiceMetric tests
@parameterized.expand([TEST_CASE_16])
@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
def test_cc_dice_value(self, params, input_data, expected_value):
    """Per-component Dice aggregated without reduction matches the expected scores."""
    metric = DiceMetric(**params)
    metric(**input_data)
    aggregated = metric.aggregate(reduction="none")
    np.testing.assert_allclose(aggregated.cpu().numpy(), expected_value, atol=1e-4)

@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
def test_input_dimensions(self):
    """Non-5D / non-2-channel inputs are rejected when per_component is enabled."""
    metric = DiceMetric(per_component=True)
    with self.assertRaises(ValueError):
        metric(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))


if __name__ == "__main__":
unittest.main()
Loading