Skip to content

Commit f1873ec

Browse files
authored
Merge branch 'dev' into fix-8107-nibabel-load
2 parents e7412d5 + 5a2d0a7 commit f1873ec

11 files changed

Lines changed: 275 additions & 19 deletions

File tree

.github/workflows/pythonapp.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ jobs:
8282
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
8383
- name: Install the dependencies
8484
run: |
85-
python -m pip install --user --upgrade pip wheel
85+
python -m pip install --user --upgrade pip wheel pybind11
8686
python -m pip install torch==2.5.1 torchvision==0.20.1
8787
cat "requirements-dev.txt"
8888
python -m pip install --no-build-isolation -r requirements-dev.txt

monai/apps/auto3dseg/auto_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def __init__(
229229
input = os.path.join(os.path.abspath(work_dir), "input.yaml")
230230
logger.info(f"Input config is not provided, using the default {input}")
231231

232-
self.data_src_cfg = dict()
232+
self.data_src_cfg = {}
233233
if isinstance(input, dict):
234234
self.data_src_cfg = input
235235
elif isinstance(input, str) and os.path.isfile(input):

monai/data/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def compute_shape_offset(
881881
Default is False, using option 1 to compute the shape and offset.
882882
883883
"""
884-
shape = np.array(spatial_shape, copy=True, dtype=float)
884+
shape = np.array(tuple(spatial_shape), copy=True, dtype=float)
885885
sr = len(shape)
886886
in_affine_ = convert_data_type(to_affine_nd(sr, in_affine), np.ndarray)[0]
887887
out_affine_ = convert_data_type(to_affine_nd(sr, out_affine), np.ndarray)[0]

monai/engines/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ def __call__(
219219
`kwargs` supports other args for `Tensor.to()` API.
220220
"""
221221
image, label = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
222-
args_ = list()
223-
kwargs_ = dict()
222+
args_ = []
223+
kwargs_ = {}
224224

225225
def _get_data(key: str) -> torch.Tensor:
226226
data = batchdata[key]

monai/losses/image_dissimilarity.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torch.nn import functional as F
1616
from torch.nn.modules.loss import _Loss
1717

18-
from monai.networks.layers import gaussian_1d, separable_filtering
18+
from monai.networks.layers import separable_filtering
1919
from monai.utils import LossReduction
2020
from monai.utils.module import look_up_option
2121

@@ -34,11 +34,11 @@ def make_triangular_kernel(kernel_size: int) -> torch.Tensor:
3434

3535

3636
def make_gaussian_kernel(kernel_size: int) -> torch.Tensor:
37-
sigma = torch.tensor(kernel_size / 3.0)
38-
kernel = gaussian_1d(sigma=sigma, truncated=kernel_size // 2, approx="sampled", normalize=False) * (
39-
2.5066282 * sigma
40-
)
41-
return kernel[:kernel_size]
37+
sigma = kernel_size / 3.0
38+
half = kernel_size // 2
39+
x = torch.arange(-half, half + 1, dtype=torch.float)
40+
kernel = torch.exp(-0.5 / (sigma * sigma) * x**2)
41+
return kernel
4242

4343

4444
kernel_dict = {

monai/networks/nets/autoencoderkl.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
680680
681681
Args:
682682
old_state_dict: state dict from the old AutoencoderKL model.
683+
verbose: if True, print diagnostic information about key mismatches.
683684
"""
684685

685686
new_state_dict = self.state_dict()
@@ -715,13 +716,39 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
715716
new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
716717
new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
717718

718-
# old version did not have a projection so set these to the identity
719-
new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
720-
new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
721-
)
722-
new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros(
723-
new_state_dict[f"{block}.attn.out_proj.bias"].shape
724-
)
719+
out_w = f"{block}.attn.out_proj.weight"
720+
out_b = f"{block}.attn.out_proj.bias"
721+
proj_w = f"{block}.proj_attn.weight"
722+
proj_b = f"{block}.proj_attn.bias"
723+
724+
if out_w in new_state_dict:
725+
if proj_w in old_state_dict:
726+
new_state_dict[out_w] = old_state_dict.pop(proj_w)
727+
if proj_b in old_state_dict:
728+
new_state_dict[out_b] = old_state_dict.pop(proj_b)
729+
else:
730+
new_state_dict[out_b] = torch.zeros(
731+
new_state_dict[out_b].shape,
732+
dtype=new_state_dict[out_b].dtype,
733+
device=new_state_dict[out_b].device,
734+
)
735+
else:
736+
# No legacy proj_attn - initialize out_proj to identity/zero
737+
new_state_dict[out_w] = torch.eye(
738+
new_state_dict[out_w].shape[0],
739+
dtype=new_state_dict[out_w].dtype,
740+
device=new_state_dict[out_w].device,
741+
)
742+
new_state_dict[out_b] = torch.zeros(
743+
new_state_dict[out_b].shape,
744+
dtype=new_state_dict[out_b].dtype,
745+
device=new_state_dict[out_b].device,
746+
)
747+
elif proj_w in old_state_dict:
748+
# new model has no out_proj at all - discard the legacy keys so they
749+
# don't surface as "unexpected keys" during load_state_dict
750+
old_state_dict.pop(proj_w)
751+
old_state_dict.pop(proj_b, None)
725752

726753
# fix the upsample conv blocks which were renamed postconv
727754
for k in new_state_dict:

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ black==25.1.0
1818
isort>=5.1, <6, !=6.0.0
1919
ruff>=0.14.11,<0.15
2020
pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows"
21+
pybind11
2122
types-setuptools
2223
mypy>=1.5.0, <1.12.0
2324
ninja
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import numpy as np
17+
import torch
18+
19+
from monai.data.utils import compute_shape_offset
20+
21+
22+
class TestComputeShapeOffset(unittest.TestCase):
23+
"""Unit tests for :func:`monai.data.utils.compute_shape_offset`."""
24+
25+
def test_pytorch_size_input(self):
26+
"""Validate `torch.Size` input produces expected shape and offset.
27+
28+
Returns:
29+
None.
30+
31+
Raises:
32+
AssertionError: If computed shape/offset are not as expected.
33+
"""
34+
# 1. Create a PyTorch Size object (which triggered the original bug)
35+
spatial_shape = torch.Size([10, 10, 10])
36+
in_affine = np.eye(4)
37+
out_affine = np.eye(4)
38+
39+
# 2. Feed it into the function
40+
shape, offset = compute_shape_offset(spatial_shape, in_affine, out_affine)
41+
42+
# 3. Prove it successfully processed the shape by checking its length
43+
self.assertEqual(len(shape), 3)
44+
45+
def setUp(self):
46+
"""Set up a 4x4 identity affine used across all test cases."""
47+
self.affine = np.eye(4)
48+
49+
def test_numpy_array_input(self):
50+
"""Verify compute_shape_offset accepts a numpy array as spatial_shape."""
51+
shape = np.array([64, 64, 64])
52+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
53+
self.assertEqual(len(out_shape), 3)
54+
55+
def test_list_input(self):
56+
"""Verify compute_shape_offset accepts a plain list as spatial_shape."""
57+
shape = [64, 64, 64]
58+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
59+
self.assertEqual(len(out_shape), 3)
60+
61+
def test_torch_tensor_input(self):
62+
"""Verify compute_shape_offset accepts a torch.Tensor as spatial_shape.
63+
64+
This path broke in PyTorch >= 2.9 because np.array() relied on the
65+
non-tuple sequence indexing protocol that PyTorch removed. Wrapping with
66+
tuple() fixes it.
67+
"""
68+
shape = torch.tensor([64, 64, 64])
69+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
70+
self.assertEqual(len(out_shape), 3)
71+
72+
def test_identity_affines_preserve_shape(self):
73+
"""Verify that identity in/out affines produce an output shape matching the input."""
74+
shape = torch.tensor([32, 48, 16])
75+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
76+
np.testing.assert_allclose(np.array(out_shape, dtype=float), shape.numpy().astype(float), atol=1e-5)
77+
78+
79+
if __name__ == "__main__":
80+
unittest.main()

tests/integration/test_reg_loss_integration.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
[LocalNormalizedCrossCorrelationLoss, {"kernel_size": 7, "kernel_type": "rectangular"}, ["pred", "target"]],
2727
[LocalNormalizedCrossCorrelationLoss, {"kernel_size": 5, "kernel_type": "triangular"}, ["pred", "target"]],
2828
[LocalNormalizedCrossCorrelationLoss, {"kernel_size": 3, "kernel_type": "gaussian"}, ["pred", "target"]],
29+
[LocalNormalizedCrossCorrelationLoss, {"kernel_size": 7, "kernel_type": "gaussian"}, ["pred", "target"]],
2930
[GlobalMutualInformationLoss, {"num_bins": 10}, ["pred", "target"]],
3031
[GlobalMutualInformationLoss, {"kernel_type": "b-spline", "num_bins": 10}, ["pred", "target"]],
3132
]
@@ -98,6 +99,24 @@ def forward(self, x):
9899
optimizer.step()
99100
self.assertGreater(init_loss, loss_val, "loss did not decrease")
100101

102+
def test_lncc_gaussian_kernel_gt3_identical_images(self):
103+
"""
104+
Regression test for make_gaussian_kernel truncated parameter bug.
105+
LNCC on identical inputs must be close to -1.0 for gaussian kernel_size > 3.
106+
"""
107+
for kernel_size in [5, 7]:
108+
with self.subTest(kernel_size=kernel_size):
109+
loss_fn = LocalNormalizedCrossCorrelationLoss(
110+
spatial_dims=2, kernel_size=kernel_size, kernel_type="gaussian"
111+
).to(self.device)
112+
x = torch.rand(2, 1, 32, 32, device=self.device)
113+
y = x.clone()
114+
loss = loss_fn(x, y)
115+
self.assertTrue(
116+
torch.allclose(loss, torch.tensor(-1.0, device=self.device, dtype=loss.dtype), atol=1e-3),
117+
f"LNCC of identical images should be -1.0, got {loss.item():.6f} (kernel_size={kernel_size})",
118+
)
119+
101120

102121
if __name__ == "__main__":
103122
unittest.main()

tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch
1818
from parameterized import parameterized
1919

20-
from monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss
20+
from monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss, make_gaussian_kernel
2121

2222
device = "cuda" if torch.cuda.is_available() else "cpu"
2323

@@ -113,6 +113,25 @@
113113
},
114114
-0.95406944,
115115
],
116+
# Regression tests for gh-8780: gaussian kernel_size > 3 was broken due to
117+
# truncated parameter being passed as pixel radius instead of sigma multiplier.
118+
# Identical images must yield loss == -1.0 for any kernel size.
119+
[
120+
{"spatial_dims": 1, "kernel_type": "gaussian", "kernel_size": 5},
121+
{
122+
"pred": torch.arange(0, 5).reshape(1, 1, -1).to(dtype=torch.float, device=device),
123+
"target": torch.arange(0, 5).reshape(1, 1, -1).to(dtype=torch.float, device=device),
124+
},
125+
-1.0,
126+
],
127+
[
128+
{"spatial_dims": 1, "kernel_type": "gaussian", "kernel_size": 9},
129+
{
130+
"pred": torch.arange(0, 9).reshape(1, 1, -1).to(dtype=torch.float, device=device),
131+
"target": torch.arange(0, 9).reshape(1, 1, -1).to(dtype=torch.float, device=device),
132+
},
133+
-1.0,
134+
],
116135
]
117136

118137

@@ -138,6 +157,15 @@ def test_ill_shape(self):
138157
torch.ones((1, 3, 4, 4, 4), dtype=torch.float, device=device),
139158
)
140159

160+
def test_gaussian_kernel_shape_and_symmetry(self):
161+
# gh-8780: kernel must have correct length, be symmetric, and peak at center
162+
for kernel_size in [3, 5, 7, 9, 11, 15]:
163+
k = make_gaussian_kernel(kernel_size)
164+
self.assertEqual(len(k), kernel_size)
165+
self.assertTrue(torch.allclose(k, k.flip(0)), f"kernel_size={kernel_size} not symmetric")
166+
self.assertEqual(k.argmax().item(), kernel_size // 2)
167+
np.testing.assert_allclose(k.max().item(), 1.0, rtol=1e-6)
168+
141169
def test_ill_opts(self):
142170
pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float)
143171
target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float)

0 commit comments

Comments
 (0)