Skip to content
80 changes: 38 additions & 42 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,22 +217,6 @@ def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS)
self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())

def __call__(self, data):
# Input Validation Addition
if not isinstance(data, dict):
raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
if self.image_key not in data:
raise KeyError(f"Key '{self.image_key}' not found in input data.")
image = data[self.image_key]
if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
raise TypeError(
f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
f"but got {type(image).__name__}."
)
if image.ndim < 3:
raise ValueError(
f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
)
# --- End of validation ---
"""
Callable to execute the pre-defined functions

Expand All @@ -242,7 +226,9 @@ def __call__(self, data):
has stats pre-defined by SampleOperations (max, min, ....).

Raises:
RuntimeError if the stats report generated is not consistent with the pre-
ValueError: if ``nda_croppeds`` is present in the input dict but is not a
list/tuple or has a different length than the number of image channels.
RuntimeError: if the stats report generated is not consistent with the pre-
defined report_format.

Note:
Expand All @@ -255,36 +241,46 @@ def __call__(self, data):
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
if "nda_croppeds" not in d:
nda_croppeds = [get_foreground_image(nda) for nda in ndas]

# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
report[ImageStatsKeys.CHANNELS] = len(ndas)
report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
report[ImageStatsKeys.SPACING] = (
affine_to_spacing(data[self.image_key].affine).tolist()
if isinstance(data[self.image_key], MetaTensor)
else [1.0] * min(3, data[self.image_key].ndim)
)
try:
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
if "nda_croppeds" not in d:
nda_croppeds = [get_foreground_image(nda) for nda in ndas]
else:
nda_croppeds = d["nda_croppeds"]
if not isinstance(nda_croppeds, (list, tuple)) or len(nda_croppeds) != len(ndas):
raise ValueError(
f"Pre-computed 'nda_croppeds' must be a list with one entry per image channel "
f"(expected {len(ndas)}, got "
f"{len(nda_croppeds) if isinstance(nda_croppeds, (list, tuple)) else type(nda_croppeds).__name__})."
)

# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
report[ImageStatsKeys.CHANNELS] = len(ndas)
report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
report[ImageStatsKeys.SPACING] = (
affine_to_spacing(data[self.image_key].affine).tolist()
if isinstance(data[self.image_key], MetaTensor)
else [1.0] * min(3, data[self.image_key].ndim)
)

report[ImageStatsKeys.SIZEMM] = [
a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
]
report[ImageStatsKeys.SIZEMM] = [
a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
]

report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
]
report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
]

if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")

d[self.stats_name] = report
d[self.stats_name] = report
finally:
torch.set_grad_enabled(restore_grad_state)

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get image stats spent {time.time() - start}")
return d

Expand Down
59 changes: 59 additions & 0 deletions tests/apps/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ def __call__(self, data):


class TestDataAnalyzer(unittest.TestCase):
"""Integration tests for the auto3dseg analyzer pipeline."""

def setUp(self):
"""Create temporary directory and write simulated datalist JSON file."""
self.test_dir = tempfile.TemporaryDirectory()
work_dir = self.test_dir.name
self.dataroot_dir = os.path.join(work_dir, "sim_dataroot")
Expand All @@ -188,6 +191,7 @@ def setUp(self):

@parameterized.expand(SIM_CPU_TEST_CASES)
def test_data_analyzer_cpu(self, input_params):
"""Verify DataAnalyzer produces per-case stats on CPU across dim/label combinations."""
sim_dim = input_params["sim_dim"]
label_key = input_params["label_key"]
image_only = not bool(label_key)
Expand All @@ -204,6 +208,7 @@ def test_data_analyzer_cpu(self, input_params):
assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"])

def test_data_analyzer_histogram(self):
"""Verify DataAnalyzer runs in histogram_only mode with no label key."""
create_sim_data(
self.dataroot_dir, sim_datalist, [32] * 3, image_only=True, rad_max=8, rad_min=1, num_seg_classes=1
)
Expand All @@ -221,6 +226,7 @@ def test_data_analyzer_histogram(self):
@parameterized.expand(SIM_GPU_TEST_CASES)
@skip_if_no_cuda
def test_data_analyzer_gpu(self, input_params):
"""Verify DataAnalyzer produces per-case stats on GPU (skipped if CUDA unavailable)."""
sim_dim = input_params["sim_dim"]
label_key = input_params["label_key"]
image_only = not bool(label_key)
Expand All @@ -236,6 +242,7 @@ def test_data_analyzer_gpu(self, input_params):
assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"])

def test_basic_operation_class(self):
"""Verify Operations.evaluate returns correct stat keys and shapes with and without axis."""
op = TestOperations()
test_data = np.random.rand(10, 10).astype(np.float64)
test_ret_1 = op.evaluate(test_data)
Expand All @@ -250,6 +257,7 @@ def test_basic_operation_class(self):
assert test_ret_2["max"].ndim == 1

def test_sample_operations(self):
"""Verify SampleOperations works with both numpy arrays and MetaTensors."""
op = SampleOperations()
test_data_np = np.random.rand(10, 10).astype(np.float64)
test_data_mt = MetaTensor(test_data_np, device=device)
Expand All @@ -265,6 +273,7 @@ def test_sample_operations(self):
assert "sum" in test_ret_np

def test_summary_operations(self):
"""Verify SummaryOperations reduces a stat dict to scalar summary values."""
op = SummaryOperations()
test_dict = {"min": [0, 1, 2, 3], "max": [2, 3, 4, 5], "mean": [1, 2, 3, 4], "sum": [2, 4, 6, 8]}
test_ret = op.evaluate(test_dict)
Expand All @@ -277,6 +286,7 @@ def test_summary_operations(self):
assert isinstance(test_ret["sum"], Number)

def test_basic_analyzer_class(self):
"""Verify a custom Analyzer subclass computes and stores stats in the output dict."""
test_data = {}
test_data["image_test"] = np.random.rand(10, 10)
report_format = {"stats": None}
Expand All @@ -288,6 +298,7 @@ def test_basic_analyzer_class(self):
assert result["test"]["stats"]["mean"] == np.mean(test_data["image_test"])

def test_transform_analyzer_class(self):
"""Verify a custom Analyzer integrates correctly as a step in a Compose transform."""
transform = Compose([LoadImaged(keys=["image"]), TestImageAnalyzer(image_key="image")])
create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
Expand All @@ -302,6 +313,7 @@ def test_transform_analyzer_class(self):
assert "mean" in d["test_image"]["test_stats"]

def test_image_stats_case_analyzer(self):
"""Verify ImageStats produces a report matching the expected format for 3-D images."""
analyzer = ImageStats(image_key="image")
transform = Compose(
[
Expand All @@ -323,6 +335,7 @@ def test_image_stats_case_analyzer(self):
assert verify_report_format(d["image_stats"], report_format)

def test_foreground_image_stats_cases_analyzer(self):
"""Verify FgImageStats produces a valid foreground stats report."""
analyzer = FgImageStats(image_key="image", label_key="label")
transform_list = [
LoadImaged(keys=["image", "label"]),
Expand All @@ -345,6 +358,7 @@ def test_foreground_image_stats_cases_analyzer(self):
assert verify_report_format(d["image_foreground_stats"], report_format)

def test_label_stats_case_analyzer(self):
"""Verify LabelStats produces a valid report including per-label statistics."""
analyzer = LabelStats(image_key="image", label_key="label")
transform = Compose(
[
Expand All @@ -369,6 +383,7 @@ def test_label_stats_case_analyzer(self):

@parameterized.expand(LABEL_STATS_DEVICE_TEST_CASES)
def test_label_stats_mixed_device_analyzer(self, input_params):
"""Verify LabelStats handles tensors split across CPU and CUDA devices."""
image_device = torch.device(input_params["image_device"])
label_device = torch.device(input_params["label_device"])

Expand Down Expand Up @@ -413,6 +428,7 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)

def test_filename_case_analyzer(self):
"""Verify FilenameStats records both image and label paths in the output dict."""
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)
transform_list = [LoadImaged(keys=["image", "label"]), analyzer_image, analyzer_label]
Expand All @@ -426,6 +442,7 @@ def test_filename_case_analyzer(self):
assert DataStatsKeys.BY_CASE_IMAGE_PATH in d

def test_filename_case_analyzer_image_only(self):
"""Verify FilenameStats handles image-only input and stores 'None' for the label path."""
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
analyzer_label = FilenameStats(None, DataStatsKeys.BY_CASE_IMAGE_PATH)
transform_list = [LoadImaged(keys=["image"]), analyzer_image, analyzer_label]
Expand All @@ -440,6 +457,7 @@ def test_filename_case_analyzer_image_only(self):
assert d[DataStatsKeys.BY_CASE_IMAGE_PATH] == "None"

def test_image_stats_summary_analyzer(self):
"""Verify ImageStatsSumm correctly aggregates per-case image stats."""
summary_analyzer = ImageStatsSumm("image_stats")

transform_list = [
Expand All @@ -463,6 +481,7 @@ def test_image_stats_summary_analyzer(self):
assert verify_report_format(summary_report, report_format)

def test_fg_image_stats_summary_analyzer(self):
"""Verify FgImageStatsSumm correctly aggregates per-case foreground stats."""
summary_analyzer = FgImageStatsSumm("image_foreground_stats")

transform_list = [
Expand All @@ -488,6 +507,7 @@ def test_fg_image_stats_summary_analyzer(self):
assert verify_report_format(summary_report, report_format)

def test_label_stats_summary_analyzer(self):
"""Verify LabelStatsSumm correctly aggregates per-case label stats."""
summary_analyzer = LabelStatsSumm("label_stats")

transform_list = [
Expand All @@ -513,6 +533,7 @@ def test_label_stats_summary_analyzer(self):
assert verify_report_format(summary_report, report_format)

def test_seg_summarizer(self):
"""Verify SegSummarizer produces a summary with image, foreground, and label stat keys."""
summarizer = SegSummarizer("image", "label")
keys = ["image", "label"]
transform_list = [
Expand All @@ -539,7 +560,45 @@ def test_seg_summarizer(self):
assert str(DataStatsKeys.FG_IMAGE_STATS) in report
assert str(DataStatsKeys.LABEL_STATS) in report

def test_image_stats_precomputed_nda_croppeds(self):
    """Verify ImageStats accepts an externally supplied ``nda_croppeds`` entry.

    Regression check: the analyzer used to raise UnboundLocalError when
    ``nda_croppeds`` was already present in the input dict, because the
    variable was bound only in the not-present branch yet read unconditionally.
    """
    stats_analyzer = ImageStats(image_key="image")
    sim_image = MetaTensor(torch.rand(1, 10, 10, 10))
    # one pre-cropped foreground array per image channel
    cropped = [np.random.rand(8, 8, 8)]
    out = stats_analyzer({"image": sim_image, "nda_croppeds": cropped})
    assert "image_stats" in out
    assert verify_report_format(out["image_stats"], stats_analyzer.get_report_format())

def test_analyzer_grad_state_restored_after_call(self):
    """Verify ImageStats leaves torch's grad-enabled flag exactly as it found it.

    Exercises the try/finally guard for both possible entry states: grad
    enabled and grad disabled before the analyzer is invoked.
    """
    stats_analyzer = ImageStats(image_key="image")
    sample = {"image": MetaTensor(torch.rand(1, 10, 10, 10))}

    # entering with grad enabled: flag must remain enabled afterwards
    torch.set_grad_enabled(True)
    stats_analyzer(sample)
    assert torch.is_grad_enabled(), "grad state was not restored after ImageStats call"

    # entering with grad disabled: flag must remain disabled afterwards
    torch.set_grad_enabled(False)
    try:
        stats_analyzer(sample)
        assert not torch.is_grad_enabled(), "grad state was not restored after ImageStats call"
    finally:
        # leave grad enabled so later tests see the default state
        torch.set_grad_enabled(True)

def tearDown(self) -> None:
    """Dispose of the temporary working directory created in setUp."""
    # TemporaryDirectory.cleanup removes the directory tree and its contents.
    self.test_dir.cleanup()


Expand Down
Loading