-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add parameter to DiceMetric and DiceHelper classes #8774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 10 commits
b5107da
ccca77a
c110e2a
41e52c1
34a6817
8d412a1
cb433a8
ba2e0b3
d9bfb5d
6f2155c
ba05438
4e6def7
925e431
28a2944
24a17e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,13 +11,19 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from monai.metrics.utils import do_metric_reduction | ||
| from monai.utils import MetricReduction, deprecated_arg | ||
| from monai.utils.module import optional_import | ||
|
|
||
| from .metric import CumulativeIterationMetric | ||
|
|
||
| distance_transform_edt, has_ndimage = optional_import("scipy.ndimage", name="distance_transform_edt") | ||
| generate_binary_structure, _ = optional_import("scipy.ndimage", name="generate_binary_structure") | ||
| sn_label, _ = optional_import("scipy.ndimage", name="label") | ||
|
|
||
| __all__ = ["DiceMetric", "compute_dice", "DiceHelper"] | ||
|
|
||
|
|
||
|
|
@@ -41,6 +47,9 @@ class DiceMetric(CumulativeIterationMetric): | |
| image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction | ||
| and ground truth is BCHW[D]. | ||
|
|
||
| The ``per_component`` parameter can be set to `True` to compute the Dice metric per connected component in the ground truth | ||
| , and then average. This requires binary segmentations with 2 channels (background + foreground) as input. | ||
|
|
||
| The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. | ||
|
|
||
| Further information can be found in the official | ||
|
|
@@ -95,6 +104,9 @@ class DiceMetric(CumulativeIterationMetric): | |
| If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True, | ||
| the index begins at "0", otherwise at "1". It can also take a list of label names. | ||
| The outcome will then be returned as a dictionary. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
|
|
||
| """ | ||
|
|
||
|
|
@@ -106,6 +118,7 @@ def __init__( | |
| ignore_empty: bool = True, | ||
| num_classes: int | None = None, | ||
| return_with_label: bool | list[str] = False, | ||
| per_component: bool = False, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.include_background = include_background | ||
|
|
@@ -114,13 +127,15 @@ def __init__( | |
| self.ignore_empty = ignore_empty | ||
| self.num_classes = num_classes | ||
| self.return_with_label = return_with_label | ||
| self.per_component = per_component | ||
| self.dice_helper = DiceHelper( | ||
| include_background=self.include_background, | ||
| reduction=MetricReduction.NONE, | ||
| get_not_nans=False, | ||
| apply_argmax=False, | ||
| ignore_empty=self.ignore_empty, | ||
| num_classes=self.num_classes, | ||
| per_component=self.per_component, | ||
| ) | ||
|
|
||
| def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] | ||
|
|
@@ -175,6 +190,7 @@ def compute_dice( | |
| include_background: bool = True, | ||
| ignore_empty: bool = True, | ||
| num_classes: int | None = None, | ||
| per_component: bool = False, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Computes Dice score metric for a batch of predictions. This performs the same computation as | ||
|
|
@@ -192,6 +208,9 @@ def compute_dice( | |
| num_classes: number of input channels (always including the background). When this is ``None``, | ||
| ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are | ||
| single-channel class indices and the number of classes is not automatically inferred from data. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
|
|
||
| Returns: | ||
| Dice scores per batch and per class, (shape: [batch_size, num_classes]). | ||
|
|
@@ -204,6 +223,7 @@ def compute_dice( | |
| apply_argmax=False, | ||
| ignore_empty=ignore_empty, | ||
| num_classes=num_classes, | ||
| per_component=per_component, | ||
| )(y_pred=y_pred, y=y) | ||
|
|
||
|
|
||
|
|
@@ -246,6 +266,9 @@ class DiceHelper: | |
| num_classes: number of input channels (always including the background). When this is ``None``, | ||
| ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are | ||
| single-channel class indices and the number of classes is not automatically inferred from data. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
| """ | ||
|
|
||
| @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax") | ||
|
|
@@ -262,6 +285,7 @@ def __init__( | |
| num_classes: int | None = None, | ||
| sigmoid: bool | None = None, | ||
| softmax: bool | None = None, | ||
| per_component: bool = False, | ||
| ) -> None: | ||
| # handling deprecated arguments | ||
| if sigmoid is not None: | ||
|
|
@@ -277,6 +301,88 @@ def __init__( | |
| self.activate = activate | ||
| self.ignore_empty = ignore_empty | ||
| self.num_classes = num_classes | ||
| self.per_component = per_component | ||
|
|
||
| def compute_voronoi_regions_fast(self, labels): | ||
| """ | ||
| Voronoi assignment to connected components (CPU, single EDT) without cc3d. | ||
| Returns the ID of the nearest component for each voxel. | ||
|
|
||
| Args: | ||
| labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds. | ||
|
|
||
| Returns: | ||
| torch.Tensor: Voronoi region IDs (int32) on CPU. | ||
| """ | ||
| if not has_ndimage: | ||
| raise RuntimeError("scipy.ndimage is required for per_component Dice computation.") | ||
| x = np.asarray(labels) | ||
| rank = x.ndim | ||
| if rank == 3: | ||
| conn_map = {6: 1, 18: 2, 26: 3} | ||
| connectivity = 26 | ||
| elif rank == 2: | ||
| conn_map = {4: 1, 8: 2} | ||
| connectivity = 8 | ||
| else: | ||
| raise ValueError("Only 2D or 3D inputs supported") | ||
| conn_rank = conn_map.get(connectivity, max(conn_map.values())) | ||
|
||
| structure = generate_binary_structure(rank=rank, connectivity=conn_rank) | ||
| cc, num = sn_label(x > 0, structure=structure) | ||
| if num == 0: | ||
| return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32) | ||
| edt_input = np.ones(cc.shape, dtype=np.uint8) | ||
| edt_input[cc > 0] = 0 | ||
| indices = distance_transform_edt(edt_input, sampling=None, return_distances=False, return_indices=True) | ||
| voronoi = cc[tuple(indices)] | ||
| return torch.from_numpy(voronoi) | ||
|
|
||
| def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the above,
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that the compute_channel function, which is similar to compute_cc_dice, resides in the same class. To maintain consistency and encapsulate related functionality, it would make sense to keep both functions within the same class, right? |
||
| """ | ||
| Compute per-component Dice for a single batch item. | ||
|
|
||
| Args: | ||
| y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W) or (1, 2, H, W). | ||
| y (torch.Tensor): Ground truth with shape (1, 2, D, H, W) or (1, 2, H, W). | ||
|
|
||
| Returns: | ||
| torch.Tensor: Mean Dice over connected components. | ||
| """ | ||
VijayVignesh1 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| data = [] | ||
|
||
| if y_pred.ndim == y.ndim: | ||
| y_pred_idx = torch.argmax(y_pred, dim=1) | ||
| y_idx = torch.argmax(y, dim=1) | ||
| else: | ||
| y_pred_idx = y_pred | ||
| y_idx = y | ||
| if y_idx[0].sum() == 0: | ||
| if self.ignore_empty: | ||
| data.append(torch.tensor(float("nan"), device=y_idx.device)) | ||
| elif y_pred_idx.sum() == 0: | ||
| data.append(torch.tensor(1.0, device=y_idx.device)) | ||
| else: | ||
| data.append(torch.tensor(0.0, device=y_idx.device)) | ||
| else: | ||
| cc_assignment = self.compute_voronoi_regions_fast(y_idx[0]) | ||
| uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True) | ||
| nof_components = uniq.numel() | ||
| code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1) | ||
| idx = (inv << 2) | code | ||
| hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4) | ||
| _, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3] | ||
| denom = 2 * tp + fp + fn | ||
| dice_scores = torch.where( | ||
| denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device) | ||
| ) | ||
| data.append(dice_scores.unsqueeze(-1)) | ||
| data = [ | ||
| torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data | ||
| ] | ||
| data = [ | ||
| torch.where(torch.isnan(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data | ||
| ] | ||
VijayVignesh1 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| data = [x.reshape(-1, 1) for x in data] | ||
| return torch.stack([x.mean() for x in data]) | ||
|
|
||
| def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
|
|
@@ -322,15 +428,25 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl | |
| y_pred = torch.sigmoid(y_pred) | ||
| y_pred = y_pred > 0.5 | ||
|
|
||
| first_ch = 0 if self.include_background else 1 | ||
| if self.per_component: | ||
| if y_pred.ndim not in (4, 5) or y.ndim not in (4, 5) or y_pred.shape[1] != 2 or y.shape[1] != 2: | ||
| raise ValueError( | ||
| "per_component requires both y_pred and y to be 4D or 5D binary segmentations " | ||
| f"with 2 channels. Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}." | ||
| ) | ||
coderabbitai[bot] marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| first_ch = 0 if self.include_background and not self.per_component else 1 | ||
VijayVignesh1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| data = [] | ||
| for b in range(y_pred.shape[0]): | ||
| c_list = [] | ||
| for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]: | ||
| x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool() | ||
| x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c] | ||
| c_list.append(self.compute_channel(x_pred, x)) | ||
| if self.per_component: | ||
| c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))] | ||
| data.append(torch.stack(c_list)) | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| data = torch.stack(data, dim=0).contiguous() # type: ignore | ||
|
|
||
| f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.