|
65 | 65 | "outputs": [], |
66 | 66 | "source": [ |
67 | 67 | "import os\n", |
68 | | - "import shutil\n", |
69 | | - "import tempfile\n", |
70 | 68 | "\n", |
71 | 69 | "import h5py\n", |
72 | 70 | "import matplotlib.pyplot as plt\n", |
|
83 | 81 | "from monai.transforms import (\n", |
84 | 82 | " CenterSpatialCropd,\n", |
85 | 83 | " Compose,\n", |
| 84 | + " EnsureChannelFirstd,\n", |
86 | 85 | " EnsureTyped,\n", |
87 | 86 | " Lambdad,\n", |
88 | 87 | " LoadImaged,\n", |
89 | 88 | " ThresholdIntensityd,\n", |
90 | 89 | ")\n", |
| 90 | + "from monai.utils.type_conversion import convert_data_type\n", |
91 | 91 | "\n", |
92 | 92 | "print_config()" |
93 | 93 | ] |
94 | 94 | }, |
95 | | - { |
96 | | - "cell_type": "markdown", |
97 | | - "metadata": {}, |
98 | | - "source": [ |
99 | | - "## Setup data directory\n", |
100 | | - "\n", |
101 | | - "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.\n", |
102 | | - "This allows you to save results and reuse downloads.\n", |
103 | | - "If not specified a temporary directory will be used." |
104 | | - ] |
105 | | - }, |
106 | | - { |
107 | | - "cell_type": "code", |
108 | | - "execution_count": null, |
109 | | - "metadata": {}, |
110 | | - "outputs": [], |
111 | | - "source": [ |
112 | | - "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", |
113 | | - "if directory is not None:\n", |
114 | | - " os.makedirs(directory, exist_ok=True)\n", |
115 | | - "root_dir = tempfile.mkdtemp() if directory is None else directory\n", |
116 | | - "print(root_dir)" |
117 | | - ] |
118 | | - }, |
119 | 95 | { |
120 | 96 | "cell_type": "markdown", |
121 | 97 | "metadata": {}, |
|
132 | 108 | "**How to obtain the data:**\n", |
133 | 109 | "1. Register at https://fastmri.org/dataset\n", |
134 | 110 | "2. Download the knee single-coil validation set (`knee_singlecoil_val.tar.gz`)\n", |
135 | | - "3. Extract to a folder (e.g., `<MONAI_DATA_DIRECTORY>/knee_singlecoil_val/`)\n", |
| 111 | + "3. Extract to a folder and update `data_path` in the cell below\n", |
136 | 112 | "\n", |
137 | 113 | "**Note:** This dataset is under a non-commercial license. You may not use it for commercial purposes.\n", |
138 | 114 | "\n", |
|
149 | 125 | "source": [ |
150 | 126 | "# Update this path to where your fastMRI knee single-coil data is stored.\n", |
151 | 127 | "# You only need ONE .h5 file from the knee_singlecoil_val set.\n", |
152 | | - "data_path = os.path.join(root_dir, \"knee_singlecoil_val\")\n", |
| 128 | + "data_path = os.path.join(\"YOUR_DIR_HERE\", \"knee_singlecoil_val\")\n", |
| 129 | + "\n", |
| 130 | + "if not os.path.isdir(data_path):\n", |
| 131 | + " raise FileNotFoundError(\n", |
| 132 | + " f\"Data directory not found: {data_path}\\n\"\n", |
| 133 | + " \"Please download the fastMRI knee single-coil validation set from \"\n", |
| 134 | + " \"https://fastmri.org/dataset and update the path above.\"\n", |
| 135 | + " )\n", |
153 | 136 | "\n", |
154 | 137 | "sample_files = sorted(\n", |
155 | 138 | " [f for f in os.listdir(data_path) if f.endswith(\".h5\")]\n", |
156 | 139 | ")\n", |
| 140 | + "if len(sample_files) == 0:\n", |
| 141 | + " raise FileNotFoundError(\n", |
| 142 | + " f\"No .h5 files found in {data_path}\\n\"\n", |
| 143 | + " \"Please place at least one .h5 file from the knee_singlecoil_val set.\"\n", |
| 144 | + " )\n", |
| 145 | + "\n", |
157 | 146 | "sample_file = os.path.join(data_path, sample_files[0])\n", |
158 | 147 | "print(f\"Using sample file: {sample_file}\")\n", |
159 | 148 | "print(f\"Total .h5 files found: {len(sample_files)}\")" |
|
460 | 449 | "outputs": [], |
461 | 450 | "source": [ |
462 | 451 | "def as_numpy(x):\n", |
463 | | - " \"\"\"Convert torch tensor or array to numpy.\"\"\"\n", |
464 | | - " if isinstance(x, torch.Tensor):\n", |
465 | | - " return x.detach().cpu().numpy()\n", |
466 | | - " return np.asarray(x)\n", |
| 452 | + "    \"\"\"Convert a torch tensor or array to numpy via MONAI's convert_data_type.\"\"\"\n", |
| 453 | + " arr, *_ = convert_data_type(x, np.ndarray)\n", |
| 454 | + " return arr\n", |
467 | 455 | "\n", |
468 | 456 | "\n", |
469 | 457 | "fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n", |
|
476 | 464 | "axes[0, 0].set_title(\"Random mask pattern\")\n", |
477 | 465 | "axes[0, 0].axis(\"off\")\n", |
478 | 466 | "\n", |
479 | | - "rand_recon = np.abs(as_numpy(random_result[\"kspace_masked_ifft\"]).squeeze())\n", |
| 467 | + "rand_recon = as_numpy(random_result[\"kspace_masked_ifft\"])\n", |
| 468 | + "if rand_recon.ndim == 3:\n", |
| 469 | + " rand_recon = rand_recon[rand_recon.shape[0] // 2]\n", |
| 470 | + "rand_recon = np.abs(rand_recon.squeeze())\n", |
480 | 471 | "axes[0, 1].imshow(rand_recon, cmap=\"gray\")\n", |
481 | 472 | "axes[0, 1].set_title(\"Zero-filled recon (random 4x)\")\n", |
482 | 473 | "axes[0, 1].axis(\"off\")\n", |
|
494 | 485 | "axes[1, 0].set_title(\"Equispaced mask pattern\")\n", |
495 | 486 | "axes[1, 0].axis(\"off\")\n", |
496 | 487 | "\n", |
497 | | - "equi_recon = np.abs(\n", |
498 | | - " as_numpy(equispaced_result[\"kspace_masked_ifft\"]).squeeze()\n", |
499 | | - ")\n", |
| 488 | + "equi_recon = as_numpy(equispaced_result[\"kspace_masked_ifft\"])\n", |
| 489 | + "if equi_recon.ndim == 3:\n", |
| 490 | + " equi_recon = equi_recon[equi_recon.shape[0] // 2]\n", |
| 491 | + "equi_recon = np.abs(equi_recon.squeeze())\n", |
500 | 492 | "axes[1, 1].imshow(equi_recon, cmap=\"gray\")\n", |
501 | 493 | "axes[1, 1].set_title(\"Zero-filled recon (equispaced 4x)\")\n", |
502 | 494 | "axes[1, 1].axis(\"off\")\n", |
|
558 | 550 | " return x\n", |
559 | 551 | "\n", |
560 | 552 | "\n", |
561 | | - "def to_numpy(x):\n", |
562 | | - " \"\"\"Convert torch tensor to numpy safely.\"\"\"\n", |
563 | | - " if isinstance(x, torch.Tensor):\n", |
564 | | - " return x.detach().cpu().numpy()\n", |
565 | | - " return np.asarray(x)\n", |
566 | | - "\n", |
567 | | - "\n", |
568 | | - "def ensure_channel_first_2d(x):\n", |
569 | | - " \"\"\"Promote (H, W) to (1, H, W) for CenterSpatialCropd.\"\"\"\n", |
570 | | - " x = np.asarray(x)\n", |
571 | | - " if x.ndim == 2:\n", |
572 | | - " return x[None, ...]\n", |
573 | | - " return x\n", |
574 | | - "\n", |
575 | | - "\n", |
576 | 553 | "def complex_to_magnitude(x):\n", |
577 | 554 | " \"\"\"Take magnitude if complex; otherwise return as-is.\"\"\"\n", |
578 | 555 | " x = np.asarray(x)\n", |
|
609 | 586 | "\n", |
610 | 587 | "# Stage 2: Shape fixes, crop to 320x320, normalize, clamp\n", |
611 | 588 | "post_transform = Compose([\n", |
612 | | - " Lambdad(\n", |
613 | | - " keys=[\"kspace_masked_ifft\", \"reconstruction_esc\"],\n", |
614 | | - " func=to_numpy,\n", |
| 589 | + " EnsureChannelFirstd(\n", |
| 590 | + " keys=[\"kspace_masked_ifft\"],\n", |
| 591 | + " channel_dim=\"no_channel\",\n", |
615 | 592 | " ),\n", |
616 | | - " Lambdad(keys=[\"kspace_masked_ifft\"], func=ensure_channel_first_2d),\n", |
617 | 593 | " CenterSpatialCropd(\n", |
618 | 594 | " keys=[\"kspace_masked_ifft\", \"reconstruction_esc\"],\n", |
619 | 595 | " roi_size=(320, 320),\n", |
|
677 | 653 | "source": [ |
678 | 654 | "def as_numpy_2d(x):\n", |
679 | 655 | " \"\"\"Return a 2D numpy image from (1, H, W) torch/numpy.\"\"\"\n", |
680 | | - " if isinstance(x, torch.Tensor):\n", |
681 | | - " x = x.detach().cpu().numpy()\n", |
682 | | - " x = np.asarray(x)\n", |
683 | | - " return x[0] if x.ndim == 3 else x\n", |
| 656 | + " arr, *_ = convert_data_type(x, np.ndarray)\n", |
| 657 | + " return arr[0] if arr.ndim == 3 else arr\n", |
684 | 658 | "\n", |
685 | 659 | "\n", |
686 | 660 | "input_img = as_numpy_2d(result[\"kspace_masked_ifft\"])\n", |
|
751 | 725 | "\n", |
752 | 726 | "Data used in the preparation of this tutorial were obtained from the NYU fastMRI Initiative database (fastmri.med.nyu.edu). [Citation: Knoll et al., Radiol Artif Intell. 2020 Jan 29;2(1):e190007. doi: 10.1148/ryai.2020190007. (https://pubs.rsna.org/doi/10.1148/ryai.2020190007), and the arXiv paper: https://arxiv.org/abs/1811.08839] As such, NYU fastMRI investigators provided data but did not participate in analysis or writing of this tutorial. A listing of NYU fastMRI investigators, subject to updates, can be found at: fastmri.med.nyu.edu. The primary goal of fastMRI is to test whether machine learning can aid in the reconstruction of medical images." |
753 | 727 | ] |
754 | | - }, |
755 | | - { |
756 | | - "cell_type": "markdown", |
757 | | - "metadata": {}, |
758 | | - "source": [ |
759 | | - "## Cleanup data directory\n", |
760 | | - "\n", |
761 | | - "Remove directory if a temporary was used." |
762 | | - ] |
763 | | - }, |
764 | | - { |
765 | | - "cell_type": "code", |
766 | | - "execution_count": null, |
767 | | - "metadata": {}, |
768 | | - "outputs": [], |
769 | | - "source": [ |
770 | | - "if directory is None:\n", |
771 | | - " shutil.rmtree(root_dir)\n", |
772 | | - "print(\"Cleanup complete!\")" |
773 | | - ] |
774 | 728 | } |
775 | 729 | ], |
776 | 730 | "metadata": { |
|
0 commit comments