-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add parameter to DiceMetric and DiceHelper classes #8774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 12 commits
b5107da
ccca77a
c110e2a
41e52c1
34a6817
8d412a1
cb433a8
ba2e0b3
d9bfb5d
6f2155c
ba05438
4e6def7
925e431
28a2944
24a17e9
a74f2cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,13 +11,22 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import warnings | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from monai.metrics.utils import do_metric_reduction | ||
| from monai.utils import MetricReduction, deprecated_arg | ||
| from monai.utils.module import optional_import | ||
|
|
||
| from .metric import CumulativeIterationMetric | ||
|
|
||
| scipy_ndimage, has_scipy_ndimage = optional_import("scipy.ndimage") | ||
| cupy, has_cupy = optional_import("cupy") | ||
| cupy_ndimage, has_cupy_ndimage = optional_import("cupyx.scipy.ndimage") | ||
|
|
||
|
|
||
| __all__ = ["DiceMetric", "compute_dice", "DiceHelper"] | ||
|
|
||
|
|
||
|
|
@@ -41,6 +50,9 @@ class DiceMetric(CumulativeIterationMetric): | |
| image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction | ||
| and ground truth is BCHW[D]. | ||
|
|
||
| The ``per_component`` parameter can be set to `True` to compute the Dice metric per connected component in the ground | ||
| truth and then average the results. This requires binary segmentations with 2 channels (background + foreground) as input. | ||
|
|
||
| The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. | ||
|
|
||
| Further information can be found in the official | ||
|
|
@@ -95,6 +107,9 @@ class DiceMetric(CumulativeIterationMetric): | |
| If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True, | ||
| the index begins at "0", otherwise at "1". It can also take a list of label names. | ||
| The outcome will then be returned as a dictionary. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
|
|
||
| """ | ||
|
|
||
|
|
@@ -106,6 +121,7 @@ def __init__( | |
| ignore_empty: bool = True, | ||
| num_classes: int | None = None, | ||
| return_with_label: bool | list[str] = False, | ||
| per_component: bool = False, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.include_background = include_background | ||
|
|
@@ -114,13 +130,15 @@ def __init__( | |
| self.ignore_empty = ignore_empty | ||
| self.num_classes = num_classes | ||
| self.return_with_label = return_with_label | ||
| self.per_component = per_component | ||
| self.dice_helper = DiceHelper( | ||
| include_background=self.include_background, | ||
| reduction=MetricReduction.NONE, | ||
| get_not_nans=False, | ||
| apply_argmax=False, | ||
| ignore_empty=self.ignore_empty, | ||
| num_classes=self.num_classes, | ||
| per_component=self.per_component, | ||
| ) | ||
|
|
||
| def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] | ||
|
|
@@ -175,6 +193,7 @@ def compute_dice( | |
| include_background: bool = True, | ||
| ignore_empty: bool = True, | ||
| num_classes: int | None = None, | ||
| per_component: bool = False, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Computes Dice score metric for a batch of predictions. This performs the same computation as | ||
|
|
@@ -192,6 +211,9 @@ def compute_dice( | |
| num_classes: number of input channels (always including the background). When this is ``None``, | ||
| ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are | ||
| single-channel class indices and the number of classes is not automatically inferred from data. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
|
|
||
| Returns: | ||
| Dice scores per batch and per class, (shape: [batch_size, num_classes]). | ||
|
|
@@ -204,6 +226,7 @@ def compute_dice( | |
| apply_argmax=False, | ||
| ignore_empty=ignore_empty, | ||
| num_classes=num_classes, | ||
| per_component=per_component, | ||
| )(y_pred=y_pred, y=y) | ||
|
|
||
|
|
||
|
|
@@ -246,6 +269,9 @@ class DiceHelper: | |
| num_classes: number of input channels (always including the background). When this is ``None``, | ||
| ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are | ||
| single-channel class indices and the number of classes is not automatically inferred from data. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
| """ | ||
|
|
||
| @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax") | ||
|
|
@@ -262,6 +288,7 @@ def __init__( | |
| num_classes: int | None = None, | ||
| sigmoid: bool | None = None, | ||
| softmax: bool | None = None, | ||
| per_component: bool = False, | ||
| ) -> None: | ||
| # handling deprecated arguments | ||
| if sigmoid is not None: | ||
|
|
@@ -277,6 +304,117 @@ def __init__( | |
| self.activate = activate | ||
| self.ignore_empty = ignore_empty | ||
| self.num_classes = num_classes | ||
| self.per_component = per_component | ||
|
|
||
| def compute_voronoi_regions_fast(self, labels): | ||
| """ | ||
| Voronoi assignment to connected components (CPU, single EDT) without cc3d. | ||
| Returns the ID of the nearest component for each voxel. | ||
|
|
||
| Args: | ||
| labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds. | ||
|
|
||
| Raises: | ||
| RuntimeError: when `scipy.ndimage` is not available. | ||
| ValueError: when `labels` has fewer than two dimensions. | ||
|
|
||
| Returns: | ||
| torch.Tensor: Voronoi region IDs (int32) on CPU. | ||
| """ | ||
| if isinstance(labels, torch.Tensor) and labels.is_cuda and has_cupy and has_cupy_ndimage: | ||
| xp = cupy | ||
| nd_distance_transform_edt = cupy_ndimage.distance_transform_edt | ||
| nd_generate_binary_structure = cupy_ndimage.generate_binary_structure | ||
| nd_label = cupy_ndimage.label | ||
| x = cupy.asarray(labels.detach()) | ||
| else: | ||
| xp = np | ||
| nd_distance_transform_edt = scipy_ndimage.distance_transform_edt | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that we have a distance_transform_edt function in MONAI; we should probably update this later to use cupyx if cucim isn't present.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MONAI's distance_transform_edt gives different results than SciPy's and CuPy's, as noted in the docstring. |
||
| nd_generate_binary_structure = scipy_ndimage.generate_binary_structure | ||
| nd_label = scipy_ndimage.label | ||
|
|
||
| if not has_scipy_ndimage: | ||
| raise RuntimeError("scipy.ndimage is required for per_component Dice computation.") | ||
|
|
||
| if isinstance(labels, torch.Tensor): | ||
| warnings.warn( | ||
| "Voronoi computation is running on CPU. " | ||
| "To accelerate, move the input tensor to GPU and ensure 'cupy' with 'cupyx.scipy.ndimage' is installed." | ||
| ) | ||
| x = labels.cpu().numpy() | ||
| else: | ||
| x = np.asarray(labels) | ||
| rank = x.ndim | ||
| if rank == 3: | ||
| conn_map = {6: 1, 18: 2, 26: 3} | ||
| connectivity = 26 | ||
| elif rank == 2: | ||
| conn_map = {4: 1, 8: 2} | ||
| connectivity = 8 | ||
| else: | ||
| raise ValueError("Only 2D or 3D inputs supported") | ||
| conn_rank = conn_map.get(connectivity, max(conn_map.values())) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not following the logic in this block. The value for
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, the connectivity can be 1-dimensional, 2-dimensional, or 3-dimensional, as given in conn_map. The default is 26 for 3D input and 8 for 2D input. I could add it as a default parameter to the function, for the future. Would that work? |
||
| structure = nd_generate_binary_structure(rank=rank, connectivity=conn_rank) | ||
| cc, num = nd_label(x > 0, structure=structure) | ||
| if num == 0: | ||
| return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32) | ||
| edt_input = xp.ones(cc.shape, dtype=xp.uint8) | ||
| edt_input[cc > 0] = 0 | ||
| indices = nd_distance_transform_edt(edt_input, sampling=None, return_distances=False, return_indices=True) | ||
| voronoi = cc[tuple(indices)] | ||
| if xp is cupy: | ||
| return torch.as_tensor(cupy.asnumpy(voronoi), dtype=torch.int32) | ||
| else: | ||
| return torch.as_tensor(voronoi, dtype=torch.int32) | ||
|
|
||
| def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the above,
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that the compute_channel function, which is similar to compute_cc_dice, resides in the same class. To maintain consistency and encapsulate related functionality, it would make sense to keep both functions within the same class, right? |
||
| """ | ||
| Compute per-component Dice for a single batch item. | ||
|
|
||
| Args: | ||
| y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W) or (1, 2, H, W). | ||
| y (torch.Tensor): Ground truth with shape (1, 2, D, H, W) or (1, 2, H, W). | ||
|
|
||
| Returns: | ||
| torch.Tensor: Mean Dice over connected components. | ||
| """ | ||
| data = [] | ||
|
VijayVignesh1 marked this conversation as resolved.
Outdated
|
||
| if y_pred.ndim == y.ndim: | ||
| y_pred_idx = torch.argmax(y_pred, dim=1) | ||
| y_idx = torch.argmax(y, dim=1) | ||
| else: | ||
| y_pred_idx = y_pred | ||
| y_idx = y | ||
| if y_idx[0].sum() == 0: | ||
| if self.ignore_empty: | ||
| data.append(torch.tensor(float("nan"), device=y_idx.device)) | ||
| elif y_pred_idx.sum() == 0: | ||
| data.append(torch.tensor(1.0, device=y_idx.device)) | ||
| else: | ||
| data.append(torch.tensor(0.0, device=y_idx.device)) | ||
| else: | ||
| cc_assignment = self.compute_voronoi_regions_fast(y_idx[0]) | ||
| if cc_assignment.device != y_idx.device: | ||
| cc_assignment = cc_assignment.to(y_idx.device) | ||
| uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True) | ||
| nof_components = uniq.numel() | ||
| code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1) | ||
| idx = (inv << 2) | code | ||
| hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4) | ||
| _, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3] | ||
| denom = 2 * tp + fp + fn | ||
| dice_scores = torch.where( | ||
| denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device) | ||
| ) | ||
| data.append(dice_scores.unsqueeze(-1)) | ||
| data = [ | ||
| torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data | ||
| ] | ||
| data = [ | ||
| torch.where(torch.isnan(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data | ||
| ] | ||
|
VijayVignesh1 marked this conversation as resolved.
Outdated
|
||
| data = [x.reshape(-1, 1) for x in data] | ||
| return torch.stack([x.mean() for x in data]) | ||
|
|
||
| def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
|
|
@@ -305,6 +443,9 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl | |
| y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...). | ||
| the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``. | ||
| y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...). | ||
|
|
||
| Raises: | ||
| ValueError: when the shapes of `y_pred` and `y` are not compatible for the per-component computation. | ||
| """ | ||
| _apply_argmax, _threshold = self.apply_argmax, self.threshold | ||
| if self.num_classes is None: | ||
|
|
@@ -322,15 +463,33 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl | |
| y_pred = torch.sigmoid(y_pred) | ||
| y_pred = y_pred > 0.5 | ||
|
|
||
| first_ch = 0 if self.include_background else 1 | ||
| if self.per_component: | ||
| if y_pred.ndim not in (4, 5) or y.ndim not in (4, 5) or y_pred.shape[1] != 2 or y.shape[1] != 2: | ||
| same_rank = y_pred.ndim == y.ndim and y_pred.ndim in (4, 5) | ||
| binary_channels = y_pred.shape[1] == 2 and y.shape[1] == 2 | ||
| same_shape = y_pred.shape == y.shape | ||
| if not (same_rank and binary_channels and same_shape): | ||
| raise ValueError( | ||
| "per_component requires matching 4D/5D binary tensors " | ||
| "(B, 2, H, W) or (B, 2, D, H, W). " | ||
| f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}." | ||
| ) | ||
|
|
||
| first_ch = 0 if self.include_background and not self.per_component else 1 | ||
|
VijayVignesh1 marked this conversation as resolved.
|
||
| data = [] | ||
| for b in range(y_pred.shape[0]): | ||
| if self.per_component: | ||
| data.append(self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0)).reshape(-1)) | ||
| continue | ||
| c_list = [] | ||
| for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]: | ||
| x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool() | ||
| x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c] | ||
| c_list.append(self.compute_channel(x_pred, x)) | ||
| # if self.per_component: | ||
| # c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))] | ||
|
VijayVignesh1 marked this conversation as resolved.
Outdated
|
||
| data.append(torch.stack(c_list)) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| data = torch.stack(data, dim=0).contiguous() # type: ignore | ||
|
|
||
| f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.