Skip to content

Commit d30be44

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into 8587-test-erros-on-pytorch-release-2508-on-series-50
2 parents 0cd053a + 8d39519 commit d30be44

6 files changed

Lines changed: 20 additions & 12 deletions

File tree

monai/apps/auto3dseg/auto_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def __init__(
229229
input = os.path.join(os.path.abspath(work_dir), "input.yaml")
230230
logger.info(f"Input config is not provided, using the default {input}")
231231

232-
self.data_src_cfg = dict()
232+
self.data_src_cfg = {}
233233
if isinstance(input, dict):
234234
self.data_src_cfg = input
235235
elif isinstance(input, str) and os.path.isfile(input):

monai/auto3dseg/analyzer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from abc import ABC, abstractmethod
1616
from collections.abc import Hashable, Mapping
1717
from copy import deepcopy
18-
from typing import Any
18+
from typing import Any, cast
1919

2020
import numpy as np
2121
import torch
@@ -470,6 +470,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
470470
start = time.time()
471471
image_tensor = d[self.image_key]
472472
label_tensor = d[self.label_key]
473+
# Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
473474
using_cuda = any(
474475
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
475476
)
@@ -480,7 +481,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
480481
label_tensor, (MetaTensor, torch.Tensor)
481482
):
482483
if label_tensor.device != image_tensor.device:
483-
label_tensor = label_tensor.to(image_tensor.device) # type: ignore
484+
if using_cuda:
485+
# Move both tensors to CUDA when mixing devices
486+
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
487+
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
488+
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
489+
else:
490+
label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device))
484491

485492
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
486493
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)

monai/engines/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ def __call__(
219219
`kwargs` supports other args for `Tensor.to()` API.
220220
"""
221221
image, label = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
222-
args_ = list()
223-
kwargs_ = dict()
222+
args_ = []
223+
kwargs_ = {}
224224

225225
def _get_data(key: str) -> torch.Tensor:
226226
data = batchdata[key]

monai/transforms/inverse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,8 @@ def track_transform_meta(
282282
msg += f" for key {key}"
283283

284284
pend = out_obj.pending_operations[-1]
285-
statuses = pend.get(TraceKeys.STATUSES, dict())
286-
messages = statuses.get(TraceStatusKeys.PENDING_DURING_APPLY, list())
285+
statuses = pend.get(TraceKeys.STATUSES, {})
286+
messages = statuses.get(TraceStatusKeys.PENDING_DURING_APPLY, [])
287287
messages.append(msg)
288288
statuses[TraceStatusKeys.PENDING_DURING_APPLY] = messages
289289
info[TraceKeys.STATUSES] = statuses

tests/apps/test_auto3dseg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
393393
result = analyzer({"image": image_tensor, "label": label_tensor})
394394
report = result["label_stats"]
395395

396+
# Verify report format and computation succeeded despite mixed/unified devices
396397
assert verify_report_format(report, analyzer.get_report_format())
397398
assert report[LabelStatsKeys.LABEL_UID] == [0, 1]
398399

tests/metrics/test_ssim_metric.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
class TestSSIMMetric(unittest.TestCase):
2323

24-
def test2d_gaussian(self):
24+
def test_2d_gaussian(self):
2525
set_determinism(0)
2626
preds = torch.abs(torch.randn(2, 3, 16, 16))
2727
target = torch.abs(torch.randn(2, 3, 16, 16))
@@ -32,9 +32,9 @@ def test2d_gaussian(self):
3232
metric(preds, target)
3333
result = metric.aggregate()
3434
expected_value = 0.045415
35-
self.assertTrue(expected_value - result.item() < 0.000001)
35+
self.assertTrue(abs(expected_value - result.item()) < 0.000001)
3636

37-
def test2d_uniform(self):
37+
def test_2d_uniform(self):
3838
set_determinism(0)
3939
preds = torch.abs(torch.randn(2, 3, 16, 16))
4040
target = torch.abs(torch.randn(2, 3, 16, 16))
@@ -45,9 +45,9 @@ def test2d_uniform(self):
4545
metric(preds, target)
4646
result = metric.aggregate()
4747
expected_value = 0.050103
48-
self.assertTrue(expected_value - result.item() < 0.000001)
48+
self.assertTrue(abs(expected_value - result.item()) < 0.000001)
4949

50-
def test3d_gaussian(self):
50+
def test_3d_gaussian(self):
5151
set_determinism(0)
5252
preds = torch.abs(torch.randn(2, 3, 16, 16, 16))
5353
target = torch.abs(torch.randn(2, 3, 16, 16, 16))

0 commit comments

Comments
 (0)