Skip to content

Commit d5c24fd

Browse files
committed
Address reviewer feedback: use MONAI utilities and simplify data setup

- Replace custom as_numpy() with monai.utils.convert_data_type
- Replace custom ensure_channel_first_2d() with EnsureChannelFirstd
- Remove MONAI_DATA_DIRECTORY/tempfile pattern, use direct path
- Add error handling for missing data directory and .h5 files
- Fix Part 4 slice mismatch for multi-slice MONAI output
- Remove cleanup cell (no longer needed)

Signed-off-by: Vidya Sagar <[email protected]>
1 parent df38e67 commit d5c24fd

1 file changed

Lines changed: 33 additions & 79 deletions

File tree

reconstruction/MRI_reconstruction/tutorials/01_kspace_basics_fastmri_knee.ipynb

Lines changed: 33 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@
6565
"outputs": [],
6666
"source": [
6767
"import os\n",
68-
"import shutil\n",
69-
"import tempfile\n",
7068
"\n",
7169
"import h5py\n",
7270
"import matplotlib.pyplot as plt\n",
@@ -83,39 +81,17 @@
8381
"from monai.transforms import (\n",
8482
" CenterSpatialCropd,\n",
8583
" Compose,\n",
84+
" EnsureChannelFirstd,\n",
8685
" EnsureTyped,\n",
8786
" Lambdad,\n",
8887
" LoadImaged,\n",
8988
" ThresholdIntensityd,\n",
9089
")\n",
90+
"from monai.utils.type_conversion import convert_data_type\n",
9191
"\n",
9292
"print_config()"
9393
]
9494
},
95-
{
96-
"cell_type": "markdown",
97-
"metadata": {},
98-
"source": [
99-
"## Setup data directory\n",
100-
"\n",
101-
"You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.\n",
102-
"This allows you to save results and reuse downloads.\n",
103-
"If not specified a temporary directory will be used."
104-
]
105-
},
106-
{
107-
"cell_type": "code",
108-
"execution_count": null,
109-
"metadata": {},
110-
"outputs": [],
111-
"source": [
112-
"directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
113-
"if directory is not None:\n",
114-
" os.makedirs(directory, exist_ok=True)\n",
115-
"root_dir = tempfile.mkdtemp() if directory is None else directory\n",
116-
"print(root_dir)"
117-
]
118-
},
11995
{
12096
"cell_type": "markdown",
12197
"metadata": {},
@@ -132,7 +108,7 @@
132108
"**How to obtain the data:**\n",
133109
"1. Register at https://fastmri.org/dataset\n",
134110
"2. Download the knee single-coil validation set (`knee_singlecoil_val.tar.gz`)\n",
135-
"3. Extract to a folder (e.g., `<MONAI_DATA_DIRECTORY>/knee_singlecoil_val/`)\n",
111+
"3. Extract to a folder and update `data_path` in the cell below\n",
136112
"\n",
137113
"**Note:** This dataset is under a non-commercial license. You may not use it for commercial purposes.\n",
138114
"\n",
@@ -149,11 +125,24 @@
149125
"source": [
150126
"# Update this path to where your fastMRI knee single-coil data is stored.\n",
151127
"# You only need ONE .h5 file from the knee_singlecoil_val set.\n",
152-
"data_path = os.path.join(root_dir, \"knee_singlecoil_val\")\n",
128+
"data_path = os.path.join(\"YOUR_DIR_HERE\", \"knee_singlecoil_val\")\n",
129+
"\n",
130+
"if not os.path.isdir(data_path):\n",
131+
" raise FileNotFoundError(\n",
132+
" f\"Data directory not found: {data_path}\\n\"\n",
133+
" \"Please download the fastMRI knee single-coil validation set from \"\n",
134+
" \"https://fastmri.org/dataset and update the path above.\"\n",
135+
" )\n",
153136
"\n",
154137
"sample_files = sorted(\n",
155138
" [f for f in os.listdir(data_path) if f.endswith(\".h5\")]\n",
156139
")\n",
140+
"if len(sample_files) == 0:\n",
141+
" raise FileNotFoundError(\n",
142+
" f\"No .h5 files found in {data_path}\\n\"\n",
143+
" \"Please place at least one .h5 file from the knee_singlecoil_val set.\"\n",
144+
" )\n",
145+
"\n",
157146
"sample_file = os.path.join(data_path, sample_files[0])\n",
158147
"print(f\"Using sample file: {sample_file}\")\n",
159148
"print(f\"Total .h5 files found: {len(sample_files)}\")"
@@ -460,10 +449,9 @@
460449
"outputs": [],
461450
"source": [
462451
"def as_numpy(x):\n",
463-
" \"\"\"Convert torch tensor or array to numpy.\"\"\"\n",
464-
" if isinstance(x, torch.Tensor):\n",
465-
" return x.detach().cpu().numpy()\n",
466-
" return np.asarray(x)\n",
452+
" \"\"\"Convert torch tensor or array to numpy using MONAI utility.\"\"\"\n",
453+
" arr, *_ = convert_data_type(x, np.ndarray)\n",
454+
" return arr\n",
467455
"\n",
468456
"\n",
469457
"fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
@@ -476,7 +464,10 @@
476464
"axes[0, 0].set_title(\"Random mask pattern\")\n",
477465
"axes[0, 0].axis(\"off\")\n",
478466
"\n",
479-
"rand_recon = np.abs(as_numpy(random_result[\"kspace_masked_ifft\"]).squeeze())\n",
467+
"rand_recon = as_numpy(random_result[\"kspace_masked_ifft\"])\n",
468+
"if rand_recon.ndim == 3:\n",
469+
" rand_recon = rand_recon[rand_recon.shape[0] // 2]\n",
470+
"rand_recon = np.abs(rand_recon.squeeze())\n",
480471
"axes[0, 1].imshow(rand_recon, cmap=\"gray\")\n",
481472
"axes[0, 1].set_title(\"Zero-filled recon (random 4x)\")\n",
482473
"axes[0, 1].axis(\"off\")\n",
@@ -494,9 +485,10 @@
494485
"axes[1, 0].set_title(\"Equispaced mask pattern\")\n",
495486
"axes[1, 0].axis(\"off\")\n",
496487
"\n",
497-
"equi_recon = np.abs(\n",
498-
" as_numpy(equispaced_result[\"kspace_masked_ifft\"]).squeeze()\n",
499-
")\n",
488+
"equi_recon = as_numpy(equispaced_result[\"kspace_masked_ifft\"])\n",
489+
"if equi_recon.ndim == 3:\n",
490+
" equi_recon = equi_recon[equi_recon.shape[0] // 2]\n",
491+
"equi_recon = np.abs(equi_recon.squeeze())\n",
500492
"axes[1, 1].imshow(equi_recon, cmap=\"gray\")\n",
501493
"axes[1, 1].set_title(\"Zero-filled recon (equispaced 4x)\")\n",
502494
"axes[1, 1].axis(\"off\")\n",
@@ -558,21 +550,6 @@
558550
" return x\n",
559551
"\n",
560552
"\n",
561-
"def to_numpy(x):\n",
562-
" \"\"\"Convert torch tensor to numpy safely.\"\"\"\n",
563-
" if isinstance(x, torch.Tensor):\n",
564-
" return x.detach().cpu().numpy()\n",
565-
" return np.asarray(x)\n",
566-
"\n",
567-
"\n",
568-
"def ensure_channel_first_2d(x):\n",
569-
" \"\"\"Promote (H, W) to (1, H, W) for CenterSpatialCropd.\"\"\"\n",
570-
" x = np.asarray(x)\n",
571-
" if x.ndim == 2:\n",
572-
" return x[None, ...]\n",
573-
" return x\n",
574-
"\n",
575-
"\n",
576553
"def complex_to_magnitude(x):\n",
577554
" \"\"\"Take magnitude if complex; otherwise return as-is.\"\"\"\n",
578555
" x = np.asarray(x)\n",
@@ -609,11 +586,10 @@
609586
"\n",
610587
"# Stage 2: Shape fixes, crop to 320x320, normalize, clamp\n",
611588
"post_transform = Compose([\n",
612-
" Lambdad(\n",
613-
" keys=[\"kspace_masked_ifft\", \"reconstruction_esc\"],\n",
614-
" func=to_numpy,\n",
589+
" EnsureChannelFirstd(\n",
590+
" keys=[\"kspace_masked_ifft\"],\n",
591+
" channel_dim=\"no_channel\",\n",
615592
" ),\n",
616-
" Lambdad(keys=[\"kspace_masked_ifft\"], func=ensure_channel_first_2d),\n",
617593
" CenterSpatialCropd(\n",
618594
" keys=[\"kspace_masked_ifft\", \"reconstruction_esc\"],\n",
619595
" roi_size=(320, 320),\n",
@@ -677,10 +653,8 @@
677653
"source": [
678654
"def as_numpy_2d(x):\n",
679655
" \"\"\"Return a 2D numpy image from (1, H, W) torch/numpy.\"\"\"\n",
680-
" if isinstance(x, torch.Tensor):\n",
681-
" x = x.detach().cpu().numpy()\n",
682-
" x = np.asarray(x)\n",
683-
" return x[0] if x.ndim == 3 else x\n",
656+
" arr, *_ = convert_data_type(x, np.ndarray)\n",
657+
" return arr[0] if arr.ndim == 3 else arr\n",
684658
"\n",
685659
"\n",
686660
"input_img = as_numpy_2d(result[\"kspace_masked_ifft\"])\n",
@@ -751,26 +725,6 @@
751725
"\n",
752726
"Data used in the preparation of this tutorial were obtained from the NYU fastMRI Initiative database (fastmri.med.nyu.edu). [Citation: Knoll et al., Radiol Artif Intell. 2020 Jan 29;2(1):e190007. doi: 10.1148/ryai.2020190007. (https://pubs.rsna.org/doi/10.1148/ryai.2020190007), and the arXiv paper: https://arxiv.org/abs/1811.08839] As such, NYU fastMRI investigators provided data but did not participate in analysis or writing of this tutorial. A listing of NYU fastMRI investigators, subject to updates, can be found at: fastmri.med.nyu.edu. The primary goal of fastMRI is to test whether machine learning can aid in the reconstruction of medical images."
753727
]
754-
},
755-
{
756-
"cell_type": "markdown",
757-
"metadata": {},
758-
"source": [
759-
"## Cleanup data directory\n",
760-
"\n",
761-
"Remove directory if a temporary was used."
762-
]
763-
},
764-
{
765-
"cell_type": "code",
766-
"execution_count": null,
767-
"metadata": {},
768-
"outputs": [],
769-
"source": [
770-
"if directory is None:\n",
771-
" shutil.rmtree(root_dir)\n",
772-
"print(\"Cleanup complete!\")"
773-
]
774728
}
775729
],
776730
"metadata": {

0 commit comments

Comments
 (0)