Skip to content

Commit 7cd0607

Browse files
committed
fix(spatial): extend Blackwell fallback to Resample class in array.py
The same USE_COMPILED guard that was fixed in spatial_resample (functional.py) was also present in Resample.__call__ (array.py), used by Affine, RandAffine and related transforms. Apply the same _compiled_unsupported() check so that grid_pull is not called on sm_120 (Blackwell) devices when monai._C lacks sm_120 support, preventing garbage output in test_affine, test_affined, test_rand_affine and test_rand_affined on RTX 50-series GPUs.
1 parent 09c2cd9 commit 7cd0607

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

monai/transforms/spatial/array.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import torch
2525

2626
from monai.config import USE_COMPILED, DtypeLike
27+
from monai.transforms.spatial.functional import _compiled_unsupported
2728
from monai.config.type_definitions import NdarrayOrTensor
2829
from monai.data.box_utils import BoxMode, StandardMode
2930
from monai.data.meta_obj import get_track_meta, set_track_meta
@@ -2062,14 +2063,15 @@ def __call__(
20622063
_align_corners = self.align_corners if align_corners is None else align_corners
20632064
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype, device=_device)
20642065
sr = min(len(img_t.peek_pending_shape() if isinstance(img_t, MetaTensor) else img_t.shape[1:]), 3)
2066+
_use_compiled = USE_COMPILED and not _compiled_unsupported(img_t.device)
20652067
backend, _interp_mode, _padding_mode, _ = resolves_modes(
20662068
self.mode if mode is None else mode,
20672069
self.padding_mode if padding_mode is None else padding_mode,
20682070
backend=None,
2069-
use_compiled=USE_COMPILED,
2071+
use_compiled=_use_compiled,
20702072
)
20712073

2072-
if USE_COMPILED or backend == TransformBackends.NUMPY:
2074+
if _use_compiled or backend == TransformBackends.NUMPY:
20732075
grid_t, *_ = convert_to_dst_type(grid[:sr], img_t, dtype=grid.dtype, wrap_sequence=True)
20742076
if isinstance(grid, torch.Tensor) and grid_t.data_ptr() == grid.data_ptr():
20752077
grid_t = grid_t.clone(memory_format=torch.contiguous_format)
@@ -2080,7 +2082,7 @@ def __call__(
20802082
grid_t[i] = ((_dim - 1) / _dim) * grid_t[i] + t if _align_corners else grid_t[i] + t
20812083
elif _align_corners:
20822084
grid_t[i] = ((_dim - 1) / _dim) * (grid_t[i] + 0.5)
2083-
if USE_COMPILED and backend == TransformBackends.TORCH: # compiled is using torch backend param name
2085+
if _use_compiled and backend == TransformBackends.TORCH: # compiled is using torch backend param name
20842086
grid_t = moveaxis(grid_t, 0, -1) # type: ignore
20852087
out = grid_pull(
20862088
img_t.unsqueeze(0),

0 commit comments

Comments
 (0)