diff --git a/docs/notebooks/pet_motion_estimation.ipynb b/docs/notebooks/pet_motion_estimation.ipynb index 95195f903..0884ce564 100644 --- a/docs/notebooks/pet_motion_estimation.ipynb +++ b/docs/notebooks/pet_motion_estimation.ipynb @@ -74,343 +74,6 @@ "pet_dataset" ] }, - { - "cell_type": "code", - "execution_count": 4, - "id": "09e5053e-4565-4892-a0a9-a8410fbe6748", - "metadata": {}, - "outputs": [], - "source": [ - "data_train, data_test = pet_dataset.lofo_split(15)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "a51fcb74-0e57-4ee2-8c24-1f89ff6f879c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[[[-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., -0., 0.],\n", - " ...,\n", - " [ 0., 0., -0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., -0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., -0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., 0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [ 0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., 0., 0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., -0., ..., -0., 0., 0.],\n", - " [ 0., -0., 0., ..., 0., 0., 0.]]],\n", - "\n", - "\n", - " [[[-0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., -0., ..., -0., -0., 0.],\n", - " ...,\n", - " [-0., 0., -0., ..., 0., -0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., -0.],\n", - " [-0., 0., -0., ..., 0., -0., 0.],\n", - " ...,\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., -0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., -0.],\n", - " [ 0., 0., -0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., -0., ..., -0., 0., 0.]]],\n", - "\n", - "\n", - " [[[-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., -0., 0.],\n", - " ...,\n", - " [ 0., 0., -0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., -0., -0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " ...,\n", - " [-0., -0., 0., ..., -0., 0., 0.],\n", - " [ 0., 0., -0., ..., 0., -0., 0.],\n", - " [ 0., -0., 0., ..., -0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., -0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., -0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., -0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [ 0., -0., 0., ..., 0., 0., 0.]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., 0., -0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., 0., -0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]]],\n", - "\n", - "\n", - " [[[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., -0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.]]],\n", - "\n", - "\n", - " [[[-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., -0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]]]], dtype=float32)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data_train[0]" - ] - }, { "metadata": {}, "cell_type": "markdown", @@ -2332,7 +1995,7 @@ "source": [ "from nifreeze.estimator import PETMotionEstimator\n", "\n", - "estimator = PETMotionEstimator()\n", + "estimator = PETMotionEstimator(model)\n", "\n", "# Run the estimator\n", "affines = estimator.run(\n", diff --git a/src/nifreeze/data/pet/base.py b/src/nifreeze/data/pet/base.py index d7d0b1fd5..695a92271 100644 --- a/src/nifreeze/data/pet/base.py +++ b/src/nifreeze/data/pet/base.py @@ -280,44 +280,6 @@ def __getitem__( """ return super().__getitem__(idx) - def lofo_split(self, index): - """ - Leave-one-frame-out (LOFO) for PET data. - - Parameters - ---------- - index : int - Index of the PET frame to be left out in this fold. - - Returns - ------- - (train_data, train_timings) : tuple - Training data and corresponding timings, excluding the left-out frame. - (test_data, test_timing) : tuple - Test data (one PET frame) and corresponding timing. - """ - - if not Path(self._filepath).exists(): - self.to_filename(self._filepath) - - # Read original PET data - with h5py.File(self._filepath, "r") as in_file: - root = in_file["/0"] - pet_frame = np.asanyarray(root["dataobj"][..., index]) - timing_frame = np.asanyarray(root["midframe"][..., index]) - - # Mask to exclude the selected frame - mask = np.ones(self.dataobj.shape[-1], dtype=bool) - mask[index] = False - - train_data = self.dataobj[..., mask] - train_timings = self.midframe[mask] - - test_data = pet_frame - test_timing = timing_frame - - return (train_data, train_timings), (test_data, test_timing) - def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None: """Set an affine, and update data object and gradients.""" ImageGrid = namedtuple("ImageGrid", ("shape", "affine")) diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index 488c78c3f..b9d0b971f 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -40,7 +40,6 @@ from nifreeze.data.base import BaseDataset from nifreeze.data.pet import PET from nifreeze.model.base import BaseModel, ModelFactory -from nifreeze.model.pet import BSplinePETModel from nifreeze.registration.ants import ( Registration, _prepare_registration_data, @@ -228,87 +227,93 @@ def run(self, dataset: DatasetT, **kwargs) -> Self: class PETMotionEstimator: """Estimates motion within PET imaging data aligned with generic Estimator workflow.""" - def __init__(self, align_kwargs: dict | None = None, strategy: str = "lofo"): - self.align_kwargs = align_kwargs or {} - self.strategy = strategy + def __init__( + self, + model: BaseModel | str, + strategy="linear", + model_kwargs: dict | None = None, + align_kwargs: dict | None = None, + ): + self._model = model + self._strategy = strategy + self._model_kwargs = model_kwargs or {} + self._align_kwargs = align_kwargs or {} + + def run(self, dataset: PET, omp_nthreads: int | None = None) -> list: + # Prepare iterator + iterfunc = getattr(iterators, f"{self._strategy}_iterator") + index_iter = iterfunc(size=len(dataset), seed=self._align_kwargs.get("seed", None)) + + # Initialize model + if isinstance(self._model, str): + model = ModelFactory.init( + model=self._model, + dataset=dataset, + **self._model_kwargs, + ) + else: + model = self._model - def run(self, pet_dataset: PET, omp_nthreads: int | None = None) -> list: - n_frames = len(pet_dataset) - frame_indices = np.arange(n_frames).astype(int) + dataset_length = len(dataset) if omp_nthreads: - self.align_kwargs["num_threads"] = omp_nthreads + self._align_kwargs["num_threads"] = omp_nthreads affine_matrices = [] with TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) - for idx in tqdm(frame_indices, desc="Estimating PET motion"): - (train_data, train_times), (test_data, test_time) = pet_dataset.lofo_split(idx) + with tqdm(total=dataset_length, unit="vols.") as pbar: + for i in index_iter: + pbar.set_description_str(f"{FIT_MSG: <16} vol. <{i}>") + + # Fit the model once on the training dataset + model.fit_predict(None) + + # Predict the reference volume at the test frame's timepoint + predicted = model.fit_predict(i) - if train_times is None: - raise ValueError( - f"train_times is None at index {idx}, check midframe initialization." + fixed_image_path = tmp_path / f"fixed_frame_{i:03d}.nii.gz" + moving_image_path = tmp_path / f"moving_frame_{i:03d}.nii.gz" + + fixed_img = nb.Nifti1Image(predicted, dataset.affine) + moving_img = nb.Nifti1Image(dataset[i][0], dataset.affine) + + moving_img = nb.as_closest_canonical(moving_img, enforce_diag=True) + + fixed_img.to_filename(fixed_image_path) + moving_img.to_filename(moving_image_path) + + registration_config = files("nifreeze.registration.config").joinpath( + "pet-to-pet_level1.json" ) - # Build a temporary dataset excluding the test frame - train_dataset = PET( - dataobj=train_data, - affine=pet_dataset.affine, - brainmask=pet_dataset.brainmask, - midframe=train_times, - total_duration=pet_dataset.total_duration, - ) - - # Instantiate the PET model explicitly - model = BSplinePETModel(dataset=train_dataset) - - # Fit the model once on the training dataset - model.fit_predict(None) - - # Predict the reference volume at the test frame's timepoint - predicted = model.fit_predict(idx) - - fixed_image_path = tmp_path / f"fixed_frame_{idx:03d}.nii.gz" - moving_image_path = tmp_path / f"moving_frame_{idx:03d}.nii.gz" - - fixed_img = nb.Nifti1Image(predicted, pet_dataset.affine) - moving_img = nb.Nifti1Image(test_data, pet_dataset.affine) - - moving_img = nb.as_closest_canonical(moving_img, enforce_diag=True) - - fixed_img.to_filename(fixed_image_path) - moving_img.to_filename(moving_image_path) - - registration_config = files("nifreeze.registration.config").joinpath( - "pet-to-pet_level1.json" - ) - - registration = Registration( - from_file=registration_config, - fixed_image=str(fixed_image_path), - moving_image=str(moving_image_path), - output_warped_image=True, - output_transform_prefix=f"ants_{idx:03d}", - **self.align_kwargs, - ) - - try: - result = registration.run(cwd=str(tmp_path)) - if result.outputs.forward_transforms: - transform = nt.io.itk.ITKLinearTransform.from_filename( - result.outputs.forward_transforms[0] - ) - matrix = transform.to_ras( - reference=str(fixed_image_path), moving=str(moving_image_path) - ) - affine_matrices.append(matrix) - else: - affine_matrices.append(np.eye(4)) - print(f"No transforms produced for index {idx}") - except Exception as e: - affine_matrices.append(np.eye(4)) - print(f"Failed to process frame {idx} due to {e}") - - return affine_matrices + registration = Registration( + from_file=registration_config, + fixed_image=str(fixed_image_path), + moving_image=str(moving_image_path), + output_warped_image=True, + output_transform_prefix=f"ants_{i:03d}", + **self._align_kwargs, + ) + + try: + result = registration.run(cwd=str(tmp_path)) + if result.outputs.forward_transforms: + transform = nt.io.itk.ITKLinearTransform.from_filename( + result.outputs.forward_transforms[0] + ) + matrix = transform.to_ras( + reference=str(fixed_image_path), moving=str(moving_image_path) + ) + affine_matrices.append(matrix) + else: + affine_matrices.append(np.eye(4)) + print(f"No transforms produced for index {i}") + except Exception as e: + print(f"Failed to process frame {i} due to {e}") + + pbar.update() + + return affine_matrices diff --git a/test/test_integration_pet.py b/test/test_integration_pet.py index 470e32b93..76a7ed8f0 100644 --- a/test/test_integration_pet.py +++ b/test/test_integration_pet.py @@ -30,26 +30,7 @@ from nifreeze.data.pet import PET from nifreeze.estimator import PETMotionEstimator - - -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 13.0, 17.0, 30.0, 33.0])) -def test_lofo_split_shapes(tmp_path, setup_random_pet_data): - pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data - - pet_obj = PET( - dataobj=pet_dataobj, - affine=affine, - brainmask=brainmask_dataobj, - midframe=midframe, - total_duration=total_duration, - ) - - idx = 2 - (train_data, train_times), (test_data, test_time) = pet_obj.lofo_split(idx) - assert train_data.shape[-1] == pet_obj.dataobj.shape[-1] - 1 - np.testing.assert_array_equal(test_data, pet_obj.dataobj[..., idx]) - np.testing.assert_array_equal(train_times, np.delete(pet_obj.midframe, idx)) - assert test_time == pet_obj.midframe[idx] +from nifreeze.model.base import BaseModel @pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0])) @@ -101,16 +82,17 @@ def test_pet_motion_estimator_run(monkeypatch, setup_random_pet_data): total_duration=total_duration, ) - class DummyModel: + class DummyModel(BaseModel): def __init__(self, dataset): + super().__init__(dataset) self.dataset = dataset - def fit_predict(self, index): + def fit_predict(self, index=None, **kawargs): if index is None: return None return np.zeros(self.dataset.shape3d, dtype=np.float32) - monkeypatch.setattr("nifreeze.estimator.BSplinePETModel", DummyModel) + model = DummyModel(pet_obj) class DummyRegistration: def __init__(self, *args, **kwargs): @@ -121,7 +103,7 @@ def run(self, cwd=None): monkeypatch.setattr("nifreeze.estimator.Registration", DummyRegistration) - estimator = PETMotionEstimator(None) + estimator = PETMotionEstimator(model) affines = estimator.run(pet_obj) assert len(affines) == len(pet_obj) for mat in affines: