Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,12 @@
ToTensord,
ToTensorD,
ToTensorDict,
TransformPointsImageToWorldd,
TransformPointsImageToWorldD,
TransformPointsImageToWorldDict,
TransformPointsWorldToImaged,
TransformPointsWorldToImageD,
TransformPointsWorldToImageDict,
Transposed,
TransposeD,
TransposeDict,
Expand Down
143 changes: 133 additions & 10 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from monai.transforms.traits import LazyTrait, MultiSampleTrait
from monai.transforms.transform import LazyTransform, MapTransform, Randomizable
from monai.transforms.utils import is_positive
from monai.utils import MAX_SEED, Method, PytorchPadMode, ensure_tuple_rep
from monai.utils import MAX_SEED, Method, PytorchPadMode, TraceKeys, ensure_tuple_rep

__all__ = [
"Padd",
Expand Down Expand Up @@ -431,17 +431,33 @@ class SpatialCropd(Cropd):
- a spatial center and size
- the start and end coordinates of the ROI

ROI parameters (``roi_center``, ``roi_size``, ``roi_start``, ``roi_end``) can also be specified as
string dictionary keys. When a string is provided, the actual coordinate values are read from the
data dictionary at call time. This enables pipelines where coordinates are computed by earlier
transforms (e.g., :py:class:`monai.transforms.TransformPointsWorldToImaged`) and stored in the
data dictionary under the given key.

Example::

from monai.transforms import Compose, TransformPointsWorldToImaged, SpatialCropd

pipeline = Compose([
TransformPointsWorldToImaged(keys="roi_start", refer_keys="image"),
TransformPointsWorldToImaged(keys="roi_end", refer_keys="image"),
SpatialCropd(keys="image", roi_start="roi_start", roi_end="roi_end"),
])

This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
for more information.
"""

def __init__(
self,
keys: KeysCollection,
roi_center: Sequence[int] | int | None = None,
roi_size: Sequence[int] | int | None = None,
roi_start: Sequence[int] | int | None = None,
roi_end: Sequence[int] | int | None = None,
roi_center: Sequence[int] | int | str | None = None,
roi_size: Sequence[int] | int | str | None = None,
roi_start: Sequence[int] | int | str | None = None,
roi_end: Sequence[int] | int | str | None = None,
roi_slices: Sequence[slice] | None = None,
allow_missing_keys: bool = False,
lazy: bool = False,
Expand All @@ -450,19 +466,126 @@ def __init__(
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
roi_center: voxel coordinates for center of the crop ROI.
roi_center: voxel coordinates for center of the crop ROI, or a string key to look up
the coordinates from the data dictionary.
roi_size: size of the crop ROI, if a dimension of ROI size is larger than image size,
will not crop that dimension of the image.
roi_start: voxel coordinates for start of the crop ROI.
will not crop that dimension of the image. Can also be a string key.
roi_start: voxel coordinates for start of the crop ROI, or a string key to look up
the coordinates from the data dictionary.
roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image,
use the end coordinate of image.
use the end coordinate of image. Can also be a string key.
roi_slices: list of slices for each of the spatial dimensions.
allow_missing_keys: don't raise exception if key is missing.
lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.
"""
cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices, lazy=lazy)
self._roi_center = roi_center
self._roi_size = roi_size
self._roi_start = roi_start
self._roi_end = roi_end
self._roi_slices = roi_slices
self._has_str_roi = any(isinstance(v, str) for v in [roi_center, roi_size, roi_start, roi_end])

if not self._has_str_roi:
cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices, lazy=lazy)
else:
# Placeholder cropper for the string-key path. Replaced on self.cropper at
# __call__ time once string keys are resolved from the data dictionary.
cropper = SpatialCrop(roi_start=[0], roi_end=[1], lazy=lazy)
super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)

@staticmethod
def _resolve_roi_param(val, d):
    """Look up an ROI parameter in ``d`` when it is given as a string key.

    Args:
        val: the ROI parameter. A string is treated as a key into ``d``;
            any non-string value is returned untouched.
        d: the data dictionary.

    Returns:
        The resolved ROI parameter. Tensor/ndarray values are flattened to 1-D
        and rounded to int64 (``torch.round``, half-to-even, avoiding systematic
        bias in spatial coordinates) so ``Crop.compute_slices`` can consume them.

    Raises:
        KeyError: when ``val`` is a string that is not a key of ``d``.
    """
    if not isinstance(val, str):
        return val
    if val not in d:
        raise KeyError(f"ROI key '{val}' not found in the data dictionary.")
    value = d[val]
    # Point transforms (e.g. ApplyTransformToPoints) emit tensors shaped
    # (C, N, dims) — a single coordinate arrives as shape (1, 1, 3) — so
    # collapse to a flat 1-D integer vector before handing it to the cropper.
    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value)
    if isinstance(value, torch.Tensor):
        value = torch.round(value.flatten()).to(torch.int64)
    return value

@property
def requires_current_data(self):
    """bool: whether any ROI parameter was given as a string key and must be
    resolved from the data dictionary at call time."""
    return self._has_str_roi

def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:
    """Crop each item in ``data`` selected by ``self.keys``.

    Args:
        data: dictionary of data items to be transformed.
        lazy: whether to execute lazily; ``None`` falls back to the instance default.

    Returns:
        Dictionary with cropped data for each key.
    """
    # Static ROI parameters: the cropper built in __init__ is already correct.
    if not self._has_str_roi:
        return super().__call__(data, lazy=lazy)

    d = dict(data)
    lazy_ = lazy if lazy is not None else self.lazy
    # Resolve any string-key ROI parameters against the current data dictionary
    # and rebuild the cropper; ``inverse`` compensates for the changing cropper id.
    self.cropper = SpatialCrop(
        roi_center=self._resolve_roi_param(self._roi_center, d),
        roi_size=self._resolve_roi_param(self._roi_size, d),
        roi_start=self._resolve_roi_param(self._roi_start, d),
        roi_end=self._resolve_roi_param(self._roi_end, d),
        roi_slices=self._roi_slices,
        lazy=lazy_,
    )
    for key in self.key_iterator(d):
        d[key] = self.cropper(d[key], lazy=lazy_)
    return d

def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]:
    """Undo the crop by padding each item back to its original spatial shape.

    In the string-key path ``self.cropper`` is rebuilt on every ``__call__``,
    so the ``id()`` recorded in the MetaTensor's transform trace never matches
    the current cropper. The trace entry is therefore popped with
    ``check=False`` and the stored crop extents are inverted directly with
    :py:class:`BorderPad`.

    Args:
        data: dictionary of cropped ``MetaTensor`` items.

    Returns:
        Dictionary with inverse-transformed (padded) data for each key.
    """
    if not self._has_str_roi:
        return super().inverse(data)
    d = dict(data)
    for key in self.key_iterator(d):
        trace = self.cropper.pop_transform(d[key], check=False)
        pad = BorderPad(trace[TraceKeys.EXTRA_INFO]["cropped"])
        # padding restores shape only; do not record it in the trace stack.
        with pad.trace_transform(False):
            d[key] = pad(d[key])
    return d


class CenterSpatialCropd(Cropd):
"""
Expand Down
88 changes: 88 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@
"ApplyTransformToPointsd",
"ApplyTransformToPointsD",
"ApplyTransformToPointsDict",
"TransformPointsWorldToImaged",
"TransformPointsWorldToImageD",
"TransformPointsWorldToImageDict",
"TransformPointsImageToWorldd",
"TransformPointsImageToWorldD",
"TransformPointsImageToWorldDict",
"FlattenSequenced",
"FlattenSequenceD",
"FlattenSequenceDict",
Expand Down Expand Up @@ -1910,6 +1916,86 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
return d


class TransformPointsWorldToImaged(ApplyTransformToPointsd):
    """
    Dictionary-based transform mapping points from world space into image space.

    A thin convenience wrapper around :py:class:`monai.transforms.ApplyTransformToPointsd`
    that fixes ``invert_affine=True``: the reference image's affine is inverted and applied
    to the points, converting world-space coordinates to voxel/image coordinates.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        refer_keys: The key of the reference image used to derive the affine transformation.
            This is required because the affine must come from a reference image.
            It can also be a sequence of keys, in which case each refers to the affine applied
            to the matching points in ``keys``.
        dtype: The desired data type for the output.
        affine_lps_to_ras: Defaults to ``False``. Set to ``True`` if your point data is in the RAS
            coordinate system or you're using ``ITKReader`` with ``affine_lps_to_ras=True``.
        allow_missing_keys: Don't raise exception if key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        refer_keys: KeysCollection,
        dtype: DtypeLike | torch.dtype = torch.float64,
        affine_lps_to_ras: bool = False,
        allow_missing_keys: bool = False,
    ):
        super().__init__(
            keys=keys,
            refer_keys=refer_keys,
            dtype=dtype,
            affine=None,
            # world -> image requires the inverse of the image affine
            invert_affine=True,
            affine_lps_to_ras=affine_lps_to_ras,
            allow_missing_keys=allow_missing_keys,
        )


class TransformPointsImageToWorldd(ApplyTransformToPointsd):
    """
    Dictionary-based transform mapping points from image space into world space.

    A thin convenience wrapper around :py:class:`monai.transforms.ApplyTransformToPointsd`
    that fixes ``invert_affine=False``: the reference image's affine is applied directly
    to the points, converting voxel/image coordinates to world-space coordinates.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        refer_keys: The key of the reference image used to derive the affine transformation.
            This is required because the affine must come from a reference image.
            It can also be a sequence of keys, in which case each refers to the affine applied
            to the matching points in ``keys``.
        dtype: The desired data type for the output.
        affine_lps_to_ras: Defaults to ``False``. Set to ``True`` if your point data is in the RAS
            coordinate system or you're using ``ITKReader`` with ``affine_lps_to_ras=True``.
        allow_missing_keys: Don't raise exception if key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        refer_keys: KeysCollection,
        dtype: DtypeLike | torch.dtype = torch.float64,
        affine_lps_to_ras: bool = False,
        allow_missing_keys: bool = False,
    ):
        super().__init__(
            keys=keys,
            refer_keys=refer_keys,
            dtype=dtype,
            affine=None,
            # image -> world applies the image affine as-is
            invert_affine=False,
            affine_lps_to_ras=affine_lps_to_ras,
            allow_missing_keys=allow_missing_keys,
        )


class FlattenSequenced(MapTransform, ReduceTrait):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.FlattenSequence`.
Expand Down Expand Up @@ -1975,4 +2061,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
# CamelCase ("D") and "Dict"-suffixed aliases for the dictionary-based transforms above.
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd
TransformPointsWorldToImageD = TransformPointsWorldToImageDict = TransformPointsWorldToImaged
TransformPointsImageToWorldD = TransformPointsImageToWorldDict = TransformPointsImageToWorldd
FlattenSequenceD = FlattenSequenceDict = FlattenSequenced
Loading
Loading