Commit 09c2cd9
fix(spatial): fall back to PyTorch path on Blackwell (sm_120) GPUs when USE_COMPILED=True
monai._C (grid_pull) was not compiled with sm_120 (Blackwell) architecture support, causing spatial_resample to produce incorrect results on RTX 50-series GPUs when USE_COMPILED=True.

Add _compiled_unsupported() to detect compute capability major >= 12 at runtime and transparently fall back to the PyTorch-native affine_grid + grid_sample path, which is verified correct on sm_120.

Fixes test_flips_inverse_124 in tests.transforms.spatial.test_spatial_resampled on NVIDIA GeForce RTX 5090 (Blackwell, sm_120).
1 parent 19cab57 commit 09c2cd9

1 file changed: monai/transforms/spatial/functional.py (23 additions, 1 deletion)
```diff
@@ -54,6 +54,27 @@
 __all__ = ["spatial_resample", "orientation", "flip", "resize", "rotate", "zoom", "rotate90", "affine_func"]
 
 
+def _compiled_unsupported(device: torch.device) -> bool:
+    """
+    Return True if ``monai._C`` (the compiled C extension providing ``grid_pull``) is not
+    compiled with support for the given CUDA device's compute capability.
+
+    ``monai._C`` is built at install time against a fixed set of CUDA architectures.
+    NVIDIA Blackwell GPUs (sm_120, compute capability 12.x) and newer were not included in
+    the default ``TORCH_CUDA_ARCH_LIST`` when the MONAI slim image was originally built,
+    so executing ``grid_pull`` on those devices produces incorrect results. Falling back to
+    the PyTorch-native ``affine_grid`` + ``grid_sample`` path (``USE_COMPILED=False``) gives
+    correct output on all architectures.
+
+    The threshold (``major >= 12``) matches the first architecture family (Blackwell, sm_120)
+    that shipped after the highest sm supported in the current default build list (sm_90,
+    Hopper). Adjust this constant when ``monai._C`` is rebuilt with sm_120+ support.
+    """
+    if device.type != "cuda":
+        return False
+    return torch.cuda.get_device_properties(device).major >= 12
+
+
 def _maybe_new_metatensor(img, dtype=None, device=None):
     """create a metatensor with fresh metadata if track_meta is True otherwise convert img into a torch tensor"""
     return convert_to_tensor(
```
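The decision logic this helper feeds into can be exercised without CUDA hardware by stubbing the device properties. The sketch below is plain Python; the stand-ins `FAKE_PROPS` and the `SimpleNamespace` devices are illustrative assumptions, not MONAI or PyTorch API, but the `major >= 12` check mirrors the one added in the diff:

```python
from types import SimpleNamespace

# Illustrative stand-ins for torch.device and
# torch.cuda.get_device_properties (assumptions for this sketch only).
FAKE_PROPS = {
    "rtx5090": SimpleNamespace(major=12, minor=0),  # Blackwell, sm_120
    "h100": SimpleNamespace(major=9, minor=0),      # Hopper, sm_90
}

def compiled_unsupported(device) -> bool:
    # Non-CUDA devices never reach the compiled grid_pull kernel.
    if device.type != "cuda":
        return False
    # Compute capability 12.x (Blackwell) and newer are absent from the
    # build's TORCH_CUDA_ARCH_LIST, so the compiled path is unreliable there.
    return FAKE_PROPS[device.name].major >= 12

USE_COMPILED = True
blackwell = SimpleNamespace(type="cuda", name="rtx5090")
hopper = SimpleNamespace(type="cuda", name="h100")
cpu = SimpleNamespace(type="cpu", name="cpu")

# sm_120 falls back to the PyTorch-native path; sm_90 and CPU keep the
# originally requested behavior.
assert (USE_COMPILED and not compiled_unsupported(blackwell)) is False
assert (USE_COMPILED and not compiled_unsupported(hopper)) is True
assert (USE_COMPILED and not compiled_unsupported(cpu)) is True
```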
```diff
@@ -158,7 +179,8 @@ def spatial_resample(
     xform_shape = [-1] + in_sp_size
     img = img.reshape(xform_shape)
     img = img.to(dtype_pt)
-    if isinstance(mode, int) or USE_COMPILED:
+    _use_compiled = USE_COMPILED and not _compiled_unsupported(img.device)
+    if isinstance(mode, int) or _use_compiled:
         dst_xform = create_translate(spatial_rank, [float(d - 1) / 2 for d in spatial_size])
         xform = xform @ convert_to_dst_type(dst_xform, xform)[0]
         affine_xform = monai.transforms.Affine(
```
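The fallback path the commit selects is built on `torch.nn.functional.affine_grid` plus `grid_sample`. The minimal sketch below shows that path on a toy image with an identity transform; it is not MONAI's actual resampling code (which additionally tracks spatial metadata), just the underlying PyTorch primitives that work on any architecture:

```python
import torch
import torch.nn.functional as F

# Toy 4x4 single-channel image with a batch dimension: shape (N, C, H, W).
img = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# Identity affine theta of shape (N, 2, 3): the sampling grid equals the
# input grid, so resampling should reproduce the image.
theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]])

grid = F.affine_grid(theta, size=list(img.shape), align_corners=True)
out = F.grid_sample(img, grid, mode="bilinear", align_corners=True)

# With an identity transform the output matches the input up to floating
# point; no compiled monai._C kernel is involved.
assert torch.allclose(out, img, atol=1e-5)
```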
