-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add parameter to DiceMetric and DiceHelper classes #8774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 9 commits
b5107da
ccca77a
c110e2a
41e52c1
34a6817
8d412a1
cb433a8
ba2e0b3
d9bfb5d
6f2155c
ba05438
4e6def7
925e431
28a2944
24a17e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,13 +11,19 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from monai.metrics.utils import do_metric_reduction | ||
| from monai.utils import MetricReduction, deprecated_arg | ||
| from monai.utils.module import optional_import | ||
|
|
||
| from .metric import CumulativeIterationMetric | ||
|
|
||
| distance_transform_edt, has_ndimage = optional_import("scipy.ndimage", name="distance_transform_edt") | ||
| generate_binary_structure, _ = optional_import("scipy.ndimage", name="generate_binary_structure") | ||
| sn_label, _ = optional_import("scipy.ndimage", name="label") | ||
|
|
||
| __all__ = ["DiceMetric", "compute_dice", "DiceHelper"] | ||
|
|
||
|
|
||
|
|
@@ -95,6 +101,9 @@ class DiceMetric(CumulativeIterationMetric): | |
| If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True, | ||
| the index begins at "0", otherwise at "1". It can also take a list of label names. | ||
| The outcome will then be returned as a dictionary. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires 5D binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
|
|
||
| """ | ||
|
|
||
|
|
@@ -106,6 +115,7 @@ def __init__( | |
| ignore_empty: bool = True, | ||
| num_classes: int | None = None, | ||
| return_with_label: bool | list[str] = False, | ||
| per_component: bool = False, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.include_background = include_background | ||
|
|
@@ -114,13 +124,15 @@ def __init__( | |
| self.ignore_empty = ignore_empty | ||
| self.num_classes = num_classes | ||
| self.return_with_label = return_with_label | ||
| self.per_component = per_component | ||
| self.dice_helper = DiceHelper( | ||
| include_background=self.include_background, | ||
| reduction=MetricReduction.NONE, | ||
| get_not_nans=False, | ||
| apply_argmax=False, | ||
| ignore_empty=self.ignore_empty, | ||
| num_classes=self.num_classes, | ||
| per_component=self.per_component, | ||
| ) | ||
|
|
||
| def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] | ||
|
|
@@ -175,6 +187,7 @@ def compute_dice( | |
| include_background: bool = True, | ||
| ignore_empty: bool = True, | ||
| num_classes: int | None = None, | ||
| per_component: bool = False, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Computes Dice score metric for a batch of predictions. This performs the same computation as | ||
|
|
@@ -192,6 +205,9 @@ def compute_dice( | |
| num_classes: number of input channels (always including the background). When this is ``None``, | ||
| ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are | ||
| single-channel class indices and the number of classes is not automatically inferred from data. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires 5D binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
|
|
||
| Returns: | ||
| Dice scores per batch and per class, (shape: [batch_size, num_classes]). | ||
|
|
@@ -204,6 +220,7 @@ def compute_dice( | |
| apply_argmax=False, | ||
| ignore_empty=ignore_empty, | ||
| num_classes=num_classes, | ||
| per_component=per_component, | ||
| )(y_pred=y_pred, y=y) | ||
|
|
||
|
|
||
|
|
@@ -246,6 +263,9 @@ class DiceHelper: | |
| num_classes: number of input channels (always including the background). When this is ``None``, | ||
| ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are | ||
| single-channel class indices and the number of classes is not automatically inferred from data. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires 5D binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
| """ | ||
|
|
||
| @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax") | ||
|
|
@@ -262,6 +282,7 @@ def __init__( | |
| num_classes: int | None = None, | ||
| sigmoid: bool | None = None, | ||
| softmax: bool | None = None, | ||
| per_component: bool = False, | ||
| ) -> None: | ||
| # handling deprecated arguments | ||
| if sigmoid is not None: | ||
|
|
@@ -277,6 +298,81 @@ def __init__( | |
| self.activate = activate | ||
| self.ignore_empty = ignore_empty | ||
| self.num_classes = num_classes | ||
| self.per_component = per_component | ||
|
|
||
def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
    """
    Voronoi assignment of every voxel to its nearest connected component (CPU, single EDT).

    Connected components are extracted from ``labels > 0`` with ``scipy.ndimage.label``;
    a single Euclidean distance transform with ``return_indices=True`` then maps each
    voxel to the ID of the closest component, yielding a Voronoi partition of the volume.

    Args:
        labels (np.ndarray | torch.Tensor): label map where values > 0 are seeds.
            Torch tensors are detached and moved to CPU first, so CUDA inputs are supported.
        connectivity (int): 6, 18, or 26 for 3D connectivity; any other value falls back to 26.
        sampling (tuple[float, ...] | None): voxel spacing for anisotropic distances,
            forwarded to ``scipy.ndimage.distance_transform_edt``.

    Returns:
        torch.Tensor: Voronoi region IDs (int32) on CPU; all zeros when no component exists.

    Raises:
        RuntimeError: if ``scipy.ndimage`` is not available.
    """
    if not has_ndimage:
        raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")
    # np.asarray cannot read CUDA memory; route torch tensors through .detach().cpu().
    if isinstance(labels, torch.Tensor):
        x = labels.detach().cpu().numpy()
    else:
        x = np.asarray(labels)
    # map the public 6/18/26 face/edge/corner connectivity to scipy's rank-connectivity 1/2/3.
    conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
    structure = generate_binary_structure(rank=3, connectivity=conn_rank)
    cc, num = sn_label(x > 0, structure=structure)
    if num == 0:
        # no foreground seeds: every voxel belongs to region 0.
        return torch.zeros(x.shape, dtype=torch.int32)
    # EDT runs on the complement of the components; return_indices yields, per voxel,
    # the coordinates of the nearest seed voxel, whose component ID is looked up in cc.
    edt_input = np.ones(cc.shape, dtype=np.uint8)
    edt_input[cc > 0] = 0
    indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
    voronoi = cc[tuple(indices)]
    # pin the documented int32 dtype regardless of scipy's label output dtype.
    return torch.from_numpy(voronoi).to(torch.int32)
|
|
||
def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute per-connected-component Dice for a single batch item.

    The ground truth is partitioned into Voronoi regions around its connected components
    (via ``compute_voronoi_regions_fast``); confusion counts are histogrammed per region
    and a Dice score is computed for each component, then averaged.

    Args:
        y_pred (torch.Tensor): predictions with shape (1, 2, D, H, W) (one-hot
            background/foreground), or class-index encoded when ``y_pred.ndim != y.ndim``.
        y (torch.Tensor): ground truth with the same convention as ``y_pred``.

    Returns:
        torch.Tensor: mean Dice over connected components (1-element tensor).
    """
    data = []
    # when both inputs are one-hot (same rank), collapse the channel axis to class indices.
    if y_pred.ndim == y.ndim:
        y_pred_idx = torch.argmax(y_pred, dim=1)
        y_idx = torch.argmax(y, dim=1)
    else:
        y_pred_idx = y_pred
        y_idx = y
    if y_idx[0].sum() == 0:
        # empty ground truth: follow the class-level ignore_empty convention.
        if self.ignore_empty:
            # NOTE(review): this NaN is converted to 0.0 below, so the "ignore" intent
            # appears to be lost before do_metric_reduction sees it - confirm this is desired.
            data.append(torch.tensor(float("nan"), device=y_idx.device))
        elif y_pred_idx.sum() == 0:
            data.append(torch.tensor(1.0, device=y_idx.device))  # both empty: perfect score
        else:
            data.append(torch.tensor(0.0, device=y_idx.device))  # false positives only
    else:
        # the Voronoi helper works on CPU and returns a CPU tensor; feed it a CPU copy and
        # move the assignment back to the input device so the bitwise ops below do not mix
        # devices (previously crashed for CUDA inputs).
        cc_assignment = self.compute_voronoi_regions_fast(y_idx[0].cpu()).to(y_idx.device)
        uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
        nof_components = uniq.numel()
        # per-voxel 2-bit confusion code: bit1 = ground truth, bit0 = prediction
        # (0=TN, 1=FP, 2=FN, 3=TP).
        code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1)
        # interleave component index and confusion code into one bincount key.
        idx = (inv << 2) | code
        hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
        _, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3]
        denom = 2 * tp + fp + fn
        # Dice per component; an all-zero denominator defaults to a score of 1.0.
        dice_scores = torch.where(
            denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
        )
        data.append(dice_scores.unsqueeze(-1))
    # sanitize non-finite scores before averaging.
    data = [
        torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
    ]
    data = [
        torch.where(torch.isnan(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
    ]
    data = [x.reshape(-1, 1) for x in data]
    return torch.stack([x.mean() for x in data])
|
|
||
| def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
|
|
@@ -322,15 +418,25 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl | |
| y_pred = torch.sigmoid(y_pred) | ||
| y_pred = y_pred > 0.5 | ||
|
|
||
| first_ch = 0 if self.include_background else 1 | ||
| if self.per_component: | ||
| if len(y_pred.shape) != 5 or len(y.shape) != 5 or y_pred.shape[1] != 2 or y.shape[1] != 2: | ||
| raise ValueError( | ||
| "per_component requires both y_pred and y to be 5D binary segmentations " | ||
| f"with 2 channels. Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}." | ||
| ) | ||
|
|
||
| first_ch = 0 if self.include_background and not self.per_component else 1 | ||
VijayVignesh1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| data = [] | ||
| for b in range(y_pred.shape[0]): | ||
| c_list = [] | ||
| for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]: | ||
| x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool() | ||
| x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c] | ||
| c_list.append(self.compute_channel(x_pred, x)) | ||
| if self.per_component: | ||
| c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))] | ||
| data.append(torch.stack(c_list)) | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| data = torch.stack(data, dim=0).contiguous() # type: ignore | ||
|
|
||
| f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the above, `self` is used to access `ignore_empty`, which could be passed as an argument instead, with this method turned into a function external to this class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that the compute_channel function, which is similar to compute_cc_dice, resides in the same class. To maintain consistency and encapsulate related functionality, it would make sense to keep both functions within the same class, right?