Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
ba56a6d
Move test_image_filter.py
garciadias Feb 5, 2025
5216b7a
Merge remote-tracking branch 'origin/dev' into dev
garciadias Feb 12, 2025
eacd783
Merge remote-tracking branch 'origin/dev' into dev
garciadias Feb 14, 2025
c64825f
Merge remote-tracking branch 'upstream/dev' into dev
garciadias Feb 19, 2025
66b6c17
Merge remote-tracking branch 'origin/dev' into dev
garciadias Feb 27, 2025
19cab57
Merge remote-tracking branch 'upstream/dev' into dev
garciadias Apr 11, 2025
09c2cd9
fix(spatial): fall back to PyTorch path on Blackwell (sm_120) GPUs wh…
garciadias Mar 8, 2026
7cd0607
fix(spatial): extend Blackwell fallback to Resample class in array.py
garciadias Mar 8, 2026
3fd7546
lint
garciadias Mar 9, 2026
4f6df07
lint
garciadias Mar 9, 2026
36e2623
lint
garciadias Mar 9, 2026
356956a
Merge remote-tracking branch 'upstream/dev' into 8587-test-erros-on-p…
garciadias Mar 9, 2026
17b9910
autofix
garciadias Mar 17, 2026
8de64af
Add device compatibility check to Warp.forward()
garciadias Mar 20, 2026
7c2ddb6
Fix grid coordinate conversion in Resample fallback path
garciadias Mar 20, 2026
cfe5524
Improve _compiled_unsupported() docstring with Google-style format
garciadias Mar 20, 2026
1dec216
Fix string formatting: add missing spaces in error messages
garciadias Mar 20, 2026
ec8bf1f
Add tests for GPU support detection
garciadias Mar 20, 2026
4b5bf1e
autofix
garciadias Mar 20, 2026
23f0290
fix: lint
garciadias Mar 12, 2026
057ff4d
DCO Remediation Commit for R. Garcia-Dias <rafaelagd@gmail.com>
garciadias Mar 20, 2026
2b5b367
fix(warp): fix mode type mismatch in Warp fallback path on Blackwell …
claude Mar 30, 2026
80124e6
Merge branch 'claude/fix-blackwell-monai-tests-Apjwl' into 8587-test-…
garciadias Mar 30, 2026
1b5ac46
feat(runner): add per-test timeout via --timeout flag
claude Apr 2, 2026
0a90770
Merge remote-tracking branch 'origin/claude/fix-blackwell-monai-tests…
garciadias Apr 2, 2026
b5b1eff
fix(img2tensorboard): correct GIF header byte sequence and add error …
garciadias Apr 2, 2026
4a54359
fix(dataset): ensure proper handling of LMDB read environment and pre…
garciadias Apr 2, 2026
f07dcb1
fix(dataset): ensure proper handling of LMDB read environment and pre…
garciadias Apr 2, 2026
2feca34
DCO Remediation Commit for Claude <noreply@anthropic.com>
garciadias Apr 2, 2026
0b887a3
DCO Remediation Commit for R. Garcia-Dias <rafaelagd@gmail.com>
claude Apr 2, 2026
e57f64e
Merge remote-tracking branch 'origin/claude/fix-blackwell-monai-tests…
garciadias Apr 2, 2026
e2b99da
Merge remote-tracking branch 'upstream/dev' into 8587-test-erros-on-p…
garciadias Apr 2, 2026
256b671
DCO Remediation Commit for Claude <noreply@anthropic.com>
claude Apr 2, 2026
f3e612c
Merge remote-tracking branch 'origin/claude/fix-blackwell-monai-tests…
garciadias Apr 2, 2026
43dd636
style: reformat with black 25.1.0 to match CI
garciadias Apr 2, 2026
f284c48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2026
c00aaf8
DCO Remediation Commit for R. Garcia-Dias <rafaelagd@gmail.com>
garciadias Apr 2, 2026
0cd053a
remove log file
garciadias Apr 2, 2026
d30be44
Merge remote-tracking branch 'upstream/dev' into 8587-test-erros-on-p…
garciadias Apr 9, 2026
08a1cce
Merge remote-tracking branch 'upstream/dev' into 8587-test-erros-on-p…
garciadias Apr 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,7 @@ def _run_cmd(self, cmd: str, devices_info: str = "") -> subprocess.CompletedProc
look_up_option(self.device_setting["MN_START_METHOD"], ["bcprun"])
except ValueError as err:
raise NotImplementedError(
f"{self.device_setting['MN_START_METHOD']} is not supported yet."
"Try modify BundleAlgo._run_cmd for your cluster."
f"{self.device_setting['MN_START_METHOD']} is not supported yet.Try modify BundleAlgo._run_cmd for your cluster."
) from err

return _run_cmd_bcprun(cmd, n=self.device_setting["NUM_NODES"], p=self.device_setting["n_devices"])
Expand Down Expand Up @@ -396,7 +395,7 @@ def _download_algos_url(url: str, at_path: str) -> dict[str, dict[str, str]]:
try:
download_and_extract(url=url, filepath=algo_compressed_file, output_dir=os.path.dirname(at_path))
except Exception as e:
msg = f"Download and extract of {url} failed, attempt {i+1}/{download_attempts}."
msg = f"Download and extract of {url} failed, attempt {i + 1}/{download_attempts}."
if i < download_attempts - 1:
warnings.warn(msg)
time.sleep(i)
Expand Down
7 changes: 3 additions & 4 deletions monai/apps/detection/networks/retinanet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,7 @@ def set_regular_matcher(
"""
if fg_iou_thresh < bg_iou_thresh:
raise ValueError(
"Require fg_iou_thresh >= bg_iou_thresh. "
f"Got fg_iou_thresh={fg_iou_thresh}, bg_iou_thresh={bg_iou_thresh}."
f"Require fg_iou_thresh >= bg_iou_thresh. Got fg_iou_thresh={fg_iou_thresh}, bg_iou_thresh={bg_iou_thresh}."
)
self.proposal_matcher = Matcher(
fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=allow_low_quality_matches
Expand Down Expand Up @@ -519,7 +518,7 @@ def forward(
else:
if self.inferer is None:
raise ValueError(
"`self.inferer` is not defined." "Please refer to function self.set_sliding_window_inferer(*)."
"`self.inferer` is not defined.Please refer to function self.set_sliding_window_inferer(*)."
)
head_outputs = predict_with_inferer(
images, self.network, keys=[self.cls_key, self.box_reg_key], inferer=self.inferer
Expand Down Expand Up @@ -787,7 +786,7 @@ def compute_anchor_matched_idxs(
)

if self.debug:
print(f"Max box overlap between anchors and gt boxes: {torch.max(match_quality_matrix,dim=1)[0]}.")
print(f"Max box overlap between anchors and gt boxes: {torch.max(match_quality_matrix, dim=1)[0]}.")

if torch.max(matched_idxs_per_image) < 0:
warnings.warn(
Expand Down
4 changes: 2 additions & 2 deletions monai/apps/detection/utils/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ def generate_anchors(
if (self.spatial_dims >= 3) and (len(aspect_ratios_t.shape) != 2):
raise ValueError(
f"In {self.spatial_dims}-D image, aspect_ratios for each level should be \
{len(aspect_ratios_t.shape)-1}-D. But got aspect_ratios with shape {aspect_ratios_t.shape}."
{len(aspect_ratios_t.shape) - 1}-D. But got aspect_ratios with shape {aspect_ratios_t.shape}."
)

if (self.spatial_dims >= 3) and (aspect_ratios_t.shape[1] != self.spatial_dims - 1):
raise ValueError(
f"In {self.spatial_dims}-D image, aspect_ratios for each level should has \
shape (_,{self.spatial_dims-1}). But got aspect_ratios with shape {aspect_ratios_t.shape}."
shape (_,{self.spatial_dims - 1}). But got aspect_ratios with shape {aspect_ratios_t.shape}."
)

# if 2d, w:h = 1:aspect_ratios
Expand Down
4 changes: 2 additions & 2 deletions monai/apps/detection/utils/detector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ def check_training_targets(
if boxes.numel() == 0:
warnings.warn(
f"Warning: Given target boxes has shape of {boxes.shape}. "
f"The detector reshaped it with boxes = torch.reshape(boxes, [0, {2* spatial_dims}])."
f"The detector reshaped it with boxes = torch.reshape(boxes, [0, {2 * spatial_dims}])."
)
else:
raise ValueError(
f"Expected target boxes to be a tensor of shape [N, {2* spatial_dims}], got {boxes.shape}.)."
f"Expected target boxes to be a tensor of shape [N, {2 * spatial_dims}], got {boxes.shape}.)."
)
if not torch.is_floating_point(boxes):
raise ValueError(f"Expected target boxes to be a float tensor, got {boxes.dtype}.")
Expand Down
15 changes: 8 additions & 7 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def __call__(self, data):
d[self.stats_name] = report

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get image stats spent {time.time()-start}")
logger.debug(f"Get image stats spent {time.time() - start}")
return d


Expand Down Expand Up @@ -350,7 +350,7 @@ def __call__(self, data: Mapping) -> dict:
d[self.stats_name] = report

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get foreground image stats spent {time.time()-start}")
logger.debug(f"Get foreground image stats spent {time.time() - start}")
return d


Expand Down Expand Up @@ -519,7 +519,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
d[self.stats_name] = report # type: ignore[assignment]

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get label stats spent {time.time()-start}")
logger.debug(f"Get label stats spent {time.time() - start}")
return d # type: ignore[return-value]


Expand Down Expand Up @@ -897,9 +897,11 @@ def __init__(
for i, hist_params in enumerate(zip(self.hist_bins, self.hist_range)):
_hist_bins, _hist_range = hist_params
if not isinstance(_hist_bins, int) or _hist_bins < 0:
raise ValueError(f"Expected {i+1}. hist_bins value to be positive integer but got {_hist_bins}")
raise ValueError(f"Expected {i + 1}. hist_bins value to be positive integer but got {_hist_bins}")
if not isinstance(_hist_range, list) or len(_hist_range) != 2:
raise ValueError(f"Expected {i+1}. hist_range values to be list of length 2 but received {_hist_range}")
raise ValueError(
f"Expected {i + 1}. hist_range values to be list of length 2 but received {_hist_range}"
)

def __call__(self, data: dict) -> dict:
"""
Expand Down Expand Up @@ -934,8 +936,7 @@ def __call__(self, data: dict) -> dict:
self.hist_range = nr_channels * self.hist_range
if len(self.hist_range) != nr_channels:
raise ValueError(
f"There is a mismatch between the number of channels ({nr_channels}) "
f"and histogram ranges ({len(self.hist_range)})."
f"There is a mismatch between the number of channels ({nr_channels}) and histogram ranges ({len(self.hist_range)})."
)

# perform calculation
Expand Down
7 changes: 4 additions & 3 deletions monai/data/wsi_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def get_valid_level(
# Set the default value if no resolution parameter is provided.
level = 0
if level >= n_levels:
raise ValueError(f"The maximum level of this image is {n_levels-1} while level={level} is requested)!")
raise ValueError(
f"The maximum level of this image is {n_levels - 1} while level={level} is requested)!"
)

return level

Expand Down Expand Up @@ -414,8 +416,7 @@ def get_data(
# Check if there are three color channels for RGB
elif mode in "RGB" and patch.shape[self.channel_dim] != 3:
raise ValueError(
f"The image is expected to have three color channels in '{mode}' mode but has "
f"{patch.shape[self.channel_dim]}. "
f"The image is expected to have three color channels in '{mode}' mode but has {patch.shape[self.channel_dim]}. "
)
# Get patch-related metadata
metadata: dict = self._get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level)
Expand Down
2 changes: 1 addition & 1 deletion monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
y_true = one_hot(y_true, num_classes=self.num_classes)

if torch.max(y_true) != self.num_classes - 1:
raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")

n_pred_ch = y_pred.shape[1]
if self.to_onehot_y:
Expand Down
2 changes: 1 addition & 1 deletion monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def aggregate(
_f = {}
if isinstance(self.return_with_label, bool):
for i, v in enumerate(f):
_label_key = f"label_{i+1}" if not self.include_background else f"label_{i}"
_label_key = f"label_{i + 1}" if not self.include_background else f"label_{i}"
_f[_label_key] = round(v.item(), 4)
else:
for key, v in zip(self.return_with_label, f):
Expand Down
2 changes: 1 addition & 1 deletion monai/networks/blocks/patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(
chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)}
axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
self.patch_embeddings = nn.Sequential(
Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size)
)
Expand Down
2 changes: 1 addition & 1 deletion monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def add_factory_callable(self, name: str, func: Callable, desc: str | None = Non
self.add(name.upper(), description, func)
# append name to the docstring
assert self.__doc__ is not None
self.__doc__ += f"{', ' if len(self.names)>1 else ' '}``{name}``"
self.__doc__ += f"{', ' if len(self.names) > 1 else ' '}``{name}``"

def add_factory_class(self, name: str, cls: type, desc: str | None = None) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, in
else:
raise ValueError(
f"Unsupported spatial_border length: {len(spatial_border)}, available options are "
f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]."
f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2 * len(spatial_shape)}]."
)
return tuple([(0, 0)] + data_pad_width) # type: ignore

Expand Down
3 changes: 1 addition & 2 deletions monai/transforms/regularization/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


class Mixer(RandomizableTransform):

def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
"""
Mixer is a base class providing the basic logic for the mixup-class of
Expand All @@ -41,7 +40,7 @@ def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
"""
super().__init__()
if alpha <= 0:
raise ValueError(f"Expected positive number, but got {alpha = }")
raise ValueError(f"Expected positive number, but got {alpha=}")
self.alpha = alpha
self.batch_size = batch_size

Expand Down
8 changes: 5 additions & 3 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.functional import (
_compiled_unsupported,
affine_func,
convert_box_to_points,
convert_points_to_box,
Expand Down Expand Up @@ -2062,14 +2063,15 @@ def __call__(
_align_corners = self.align_corners if align_corners is None else align_corners
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype, device=_device)
sr = min(len(img_t.peek_pending_shape() if isinstance(img_t, MetaTensor) else img_t.shape[1:]), 3)
_use_compiled = USE_COMPILED and not _compiled_unsupported(img_t.device)
backend, _interp_mode, _padding_mode, _ = resolves_modes(
self.mode if mode is None else mode,
self.padding_mode if padding_mode is None else padding_mode,
backend=None,
use_compiled=USE_COMPILED,
use_compiled=_use_compiled,
)

if USE_COMPILED or backend == TransformBackends.NUMPY:
if _use_compiled or backend == TransformBackends.NUMPY:
grid_t, *_ = convert_to_dst_type(grid[:sr], img_t, dtype=grid.dtype, wrap_sequence=True)
if isinstance(grid, torch.Tensor) and grid_t.data_ptr() == grid.data_ptr():
grid_t = grid_t.clone(memory_format=torch.contiguous_format)
Expand All @@ -2080,7 +2082,7 @@ def __call__(
grid_t[i] = ((_dim - 1) / _dim) * grid_t[i] + t if _align_corners else grid_t[i] + t
elif _align_corners:
grid_t[i] = ((_dim - 1) / _dim) * (grid_t[i] + 0.5)
if USE_COMPILED and backend == TransformBackends.TORCH: # compiled is using torch backend param name
if _use_compiled and backend == TransformBackends.TORCH: # compiled is using torch backend param name
grid_t = moveaxis(grid_t, 0, -1) # type: ignore
out = grid_pull(
img_t.unsqueeze(0),
Expand Down
24 changes: 23 additions & 1 deletion monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,27 @@
__all__ = ["spatial_resample", "orientation", "flip", "resize", "rotate", "zoom", "rotate90", "affine_func"]


def _compiled_unsupported(device: torch.device) -> bool:
"""
Return True if ``monai._C`` (the compiled C extension providing ``grid_pull``) is not
compiled with support for the given CUDA device's compute capability.

``monai._C`` is built at install time against a fixed set of CUDA architectures.
NVIDIA Blackwell GPUs (sm_120, compute capability 12.x) and newer were not included in
the default ``TORCH_CUDA_ARCH_LIST`` when the MONAI slim image was originally built,
so executing ``grid_pull`` on those devices produces incorrect results. Falling back to
the PyTorch-native ``affine_grid`` + ``grid_sample`` path (``USE_COMPILED=False``) gives
correct output on all architectures.

The threshold (``major >= 12``) matches the first architecture family (Blackwell, sm_120)
that shipped after the highest sm supported in the current default build list (sm_90,
Hopper). Adjust this constant when ``monai._C`` is rebuilt with sm_120+ support.
"""
if device.type != "cuda":
return False
return torch.cuda.get_device_properties(device).major >= 12


def _maybe_new_metatensor(img, dtype=None, device=None):
"""create a metatensor with fresh metadata if track_meta is True otherwise convert img into a torch tensor"""
return convert_to_tensor(
Expand Down Expand Up @@ -158,7 +179,8 @@ def spatial_resample(
xform_shape = [-1] + in_sp_size
img = img.reshape(xform_shape)
img = img.to(dtype_pt)
if isinstance(mode, int) or USE_COMPILED:
_use_compiled = USE_COMPILED and not _compiled_unsupported(img.device)
if isinstance(mode, int) or _use_compiled:
dst_xform = create_translate(spatial_rank, [float(d - 1) / 2 for d in spatial_size])
xform = xform @ convert_to_dst_type(dst_xform, xform)[0]
affine_xform = monai.transforms.Affine(
Expand Down
Loading