diff --git a/docs/notebooks/pet_motion_estimation.ipynb b/docs/notebooks/pet_motion_estimation.ipynb index ba528406d..3e22afbd7 100644 --- a/docs/notebooks/pet_motion_estimation.ipynb +++ b/docs/notebooks/pet_motion_estimation.ipynb @@ -406,9 +406,9 @@ } ], "source": [ - "from nifreeze.model import PETModel\n", + "from nifreeze.model import BSplinePETModel\n", "\n", - "model = PETModel(dataset=pet_dataset, timepoints=pet_dataset.midframe, xlim=7000)" + "model = BSplinePETModel(dataset=pet_dataset)" ] }, { @@ -429,7 +429,7 @@ "outputs": [], "source": [ "index = 2\n", - "predicted = model.fit_predict(pet_dataset.midframe[index])" + "predicted = model.fit_predict(index)" ] }, { diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index 93622bf40..488c78c3f 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -40,7 +40,7 @@ from nifreeze.data.base import BaseDataset from nifreeze.data.pet import PET from nifreeze.model.base import BaseModel, ModelFactory -from nifreeze.model.pet import PETModel +from nifreeze.model.pet import BSplinePETModel from nifreeze.registration.ants import ( Registration, _prepare_registration_data, @@ -234,7 +234,7 @@ def __init__(self, align_kwargs: dict | None = None, strategy: str = "lofo"): def run(self, pet_dataset: PET, omp_nthreads: int | None = None) -> list: n_frames = len(pet_dataset) - frame_indices = np.arange(n_frames) + frame_indices = np.arange(n_frames).astype(int) if omp_nthreads: self.align_kwargs["num_threads"] = omp_nthreads @@ -261,18 +261,14 @@ def run(self, pet_dataset: PET, omp_nthreads: int | None = None) -> list: total_duration=pet_dataset.total_duration, ) - # Instantiate PETModel explicitly - model = PETModel( - dataset=train_dataset, - timepoints=train_times, - xlim=pet_dataset.total_duration, - ) + # Instantiate the PET model explicitly + model = BSplinePETModel(dataset=train_dataset) # Fit the model once on the training dataset model.fit_predict(None) # Predict the reference volume at the test frame's timepoint - predicted = model.fit_predict(test_time) + predicted = model.fit_predict(idx) fixed_image_path = tmp_path / f"fixed_frame_{idx:03d}.nii.gz" moving_image_path = tmp_path / f"moving_frame_{idx:03d}.nii.gz" diff --git a/src/nifreeze/model/__init__.py b/src/nifreeze/model/__init__.py index 42c41217f..2bdaccb52 100644 --- a/src/nifreeze/model/__init__.py +++ b/src/nifreeze/model/__init__.py @@ -33,7 +33,7 @@ DTIModel, GPModel, ) -from nifreeze.model.pet import PETModel +from nifreeze.model.pet import BSplinePETModel __all__ = ( "ModelFactory", @@ -43,5 +43,5 @@ "DTIModel", "GPModel", "TrivialModel", - "PETModel", + "BSplinePETModel", ) diff --git a/src/nifreeze/model/pet.py b/src/nifreeze/model/pet.py index c4668bd5a..c38ff9c6d 100644 --- a/src/nifreeze/model/pet.py +++ b/src/nifreeze/model/pet.py @@ -22,6 +22,7 @@ # """Models for nuclear imaging.""" +from abc import ABC, ABCMeta, abstractmethod from os import cpu_count from typing import Union @@ -35,125 +36,160 @@ from nifreeze.data.pet import PET from nifreeze.model.base import BaseModel -TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG = """\ -'timepoints' and 'xlim' must be specified, found: {timepoints} and {xlim}.""" -"""PET model underspecification error.""" +PET_OBJECT_ERROR_MSG = "Dataset MUST be a PET object." +"""PET object error message.""" -FIRST_TIMEPOINT_VALUE_ERROR_MSG = """\ -First frame 'timepoint' should not be zero or negative, found: {timepoints}.""" -"""PET model timepoint value error message.""" - -LAST_TIMEPOINT_CONSISTENCY_ERROR_MSG = """\ -Last frame 'timepoints' value should not be equal or greater than 'xlim' \ -duration, found: {timepoints} and {xlim}.""" -"""PET model parameter consistency error message.""" +PET_MIDFRAME_ERROR_MSG = "Dataset MUST have a 'midframe'." +"""PET midframe error message.""" DEFAULT_TIMEPOINT_TOL = 1e-2 """Time frame tolerance in seconds.""" -class PETModel(BaseModel): +def _exec_fit(model, data, chunk=None, **kwargs): + return model.fit(data, **kwargs), chunk + + +def _exec_predict(model, chunk=None, **kwargs): + """Propagate model parameters and call predict.""" + return np.squeeze(model.predict(**kwargs)), chunk + + +class BasePETModel(BaseModel, ABC): + """Interface and default methods for PET models.""" + + __metaclass__ = ABCMeta + + __slots__ = { + "_data_mask": "A mask for the voxels that will be fitted and predicted", + "_smooth_fwhm": "FWHM in mm over which to smooth", + "_thresh_pct": "Thresholding percentile for the signal", + "_model_class": "Defining a model class", + "_modelargs": "Arguments acceptable by the underlying model", + "_models": "List with one or more (if parallel execution) model instances", + } + + def __init__( + self, + dataset: PET, + smooth_fwhm: float = 10.0, + thresh_pct: float = 20.0, + **kwargs, + ): + """Initialization. + + Parameters + ---------- + smooth_fwhm : obj:`float` + FWHM in mm over which to smooth the signal. + thresh_pct : obj:`float` + Thresholding percentile for the signal. + """ + + super().__init__(dataset, **kwargs) + + # Duck typing, instead of explicitly testing for PET type + if not hasattr(dataset, "total_duration"): + raise TypeError(PET_OBJECT_ERROR_MSG) + + if not hasattr(dataset, "midframe"): + raise ValueError(PET_MIDFRAME_ERROR_MSG) + + self._data_mask = ( + dataset.brainmask + if dataset.brainmask is not None + else np.ones(dataset.dataobj.shape[:3], dtype=bool) + ) + + self._smooth_fwhm = smooth_fwhm + self._thresh_pct = thresh_pct + + def _preprocess_data(self) -> np.ndarray: + # ToDo + # data, _, gtab = self._dataset[idxmask] ### This needs the PET data model to be changed + data = self._dataset.dataobj + brainmask = self._dataset.brainmask + + # Preprocess the data + if self._smooth_fwhm > 0: + smoothed_img = smooth_image( + nb.Nifti1Image(data, self._dataset.affine), self._smooth_fwhm + ) + data = smoothed_img.get_fdata() + + if self._thresh_pct > 0: + thresh_val = np.percentile(data, self._thresh_pct) + data[data < thresh_val] = 0 + + # Convert data into V (voxels) x T (timepoints) + return data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask] + + @property + def is_fitted(self) -> bool: + return self._locked_fit is not None + + @abstractmethod + def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]: + """Predict the corrected volume.""" + return None + + +class BSplinePETModel(BasePETModel): """A PET imaging realignment model based on B-Spline approximation.""" __slots__ = ( "_t", - "_x", - "_xlim", "_order", "_n_ctrl", - "_datashape", - "_mask", - "_smooth_fwhm", - "_thresh_pct", ) def __init__( self, dataset: PET, - timepoints: list | np.ndarray, - xlim: float, n_ctrl: int | None = None, order: int = 3, - smooth_fwhm: float = 10.0, - thresh_pct: float = 20.0, **kwargs, ): - """ - Create the B-Spline interpolating matrix. - - Parameters: - ----------- - timepoints : :obj:`list` - The timing (in sec) of each PET volume. - E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330., - 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]`` + """Create the B-Spline interpolating matrix. + Parameters + ---------- n_ctrl : :obj:`int` Number of B-Spline control points. If `None`, then one control point every six timepoints will be used. The less control points, the smoother is the model. - + order : :obj:`int` + Order of the B-Spline approximation. """ - super().__init__(dataset, **kwargs) - - if timepoints is None or xlim is None: - raise ValueError( - TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG.format(timepoints=timepoints, xlim=xlim) - ) - if timepoints[0] < DEFAULT_TIMEPOINT_TOL: - raise ValueError(FIRST_TIMEPOINT_VALUE_ERROR_MSG.format(timepoints=timepoints[0])) - - if timepoints[-1] > xlim - DEFAULT_TIMEPOINT_TOL: - raise ValueError( - LAST_TIMEPOINT_CONSISTENCY_ERROR_MSG.format(timepoints=timepoints, xlim=xlim) - ) + super().__init__(dataset, **kwargs) self._order = order - self._x = np.array(timepoints, dtype="float32") - self._xlim = xlim - self._smooth_fwhm = smooth_fwhm - self._thresh_pct = thresh_pct # Calculate index coordinates in the B-Spline grid - self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1 + self._n_ctrl = n_ctrl or (len(self._dataset.midframe) // 4) + 1 # B-Spline knots - self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") - - self._datashape = None - self._mask = None - - @property - def is_fitted(self) -> bool: - return self._locked_fit is not None + self._t = np.arange(-3, self._n_ctrl + 4, dtype="float32") def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int: """Fit the model.""" + n_jobs = n_jobs or min(cpu_count() or 1, 8) + if self._locked_fit is not None: return n_jobs if index is not None: raise NotImplementedError("Fitting with held-out data is not supported") - timepoints = kwargs.get("timepoints", None) or self._x - x = np.asarray((np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl) - - data = self._dataset.dataobj - brainmask = self._dataset.brainmask - - if self._smooth_fwhm > 0: - smoothed_img = smooth_image( - nb.Nifti1Image(data, self._dataset.affine), self._smooth_fwhm - ) - data = smoothed_img.get_fdata() - if self._thresh_pct > 0: - thresh_val = np.percentile(data, self._thresh_pct) - data[data < thresh_val] = 0 + data = self._preprocess_data() - # Convert data into V (voxels) x T (timepoints) - data = data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask] + x = ( + np.asarray(self._dataset.midframe, dtype="float32") + / self._dataset.total_duration + * self._n_ctrl + ) # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding) A = BSpline.design_matrix(x, self._t, k=self._order) @@ -161,16 +197,22 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int: ATdotA = AT @ A # Parallelize process with joblib - with Parallel(n_jobs=n_jobs or min(cpu_count() or 1, 8)) as executor: + with Parallel(n_jobs=n_jobs) as executor: results = executor(delayed(cg)(ATdotA, AT @ v) for v in data) - self._locked_fit = np.array([r[0] for r in results]) + self._locked_fit = np.asarray([r[0] for r in results]) return n_jobs def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]: """Return the corrected volume using B-spline interpolation.""" + # ToDo + # Does the below apply to PET ? Martin has the return None statement + # if index is None: + # raise RuntimeError( + # f"Model {self.__class__.__name__} does not allow locking.") + # Fit the BSpline basis on all data if self._locked_fit is None: self._fit(index, n_jobs=kwargs.pop("n_jobs", None), **kwargs) @@ -183,7 +225,10 @@ def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, N return None # Project sample timing into B-Spline coordinates - x = np.asarray((index / self._xlim) * self._n_ctrl) + # ToDo: x is not really a matrix ... + x = np.asarray( + (self._dataset.midframe[index] / self._dataset.total_duration) * self._n_ctrl + ) A = BSpline.design_matrix(x, self._t, k=self._order) # A is 1 (num. timepoints) x C (num. coeff) diff --git a/test/test_integration_pet.py b/test/test_integration_pet.py index 207de2831..470e32b93 100644 --- a/test/test_integration_pet.py +++ b/test/test_integration_pet.py @@ -102,7 +102,7 @@ def test_pet_motion_estimator_run(monkeypatch, setup_random_pet_data): ) class DummyModel: - def __init__(self, dataset, timepoints, xlim): + def __init__(self, dataset): self.dataset = dataset def fit_predict(self, index): @@ -110,7 +110,7 @@ def fit_predict(self, index): return None return np.zeros(self.dataset.shape3d, dtype=np.float32) - monkeypatch.setattr("nifreeze.estimator.PETModel", DummyModel) + monkeypatch.setattr("nifreeze.estimator.BSplinePETModel", DummyModel) class DummyRegistration: def __init__(self, *args, **kwargs): diff --git a/test/test_model_pet.py b/test/test_model_pet.py index 852cffb78..77747a47c 100644 --- a/test/test_model_pet.py +++ b/test/test_model_pet.py @@ -22,97 +22,91 @@ # import re -import sys import numpy as np import pytest +from nifreeze.data.base import BaseDataset from nifreeze.data.pet import PET from nifreeze.model.pet import ( - DEFAULT_TIMEPOINT_TOL, - FIRST_TIMEPOINT_VALUE_ERROR_MSG, - LAST_TIMEPOINT_CONSISTENCY_ERROR_MSG, - TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG, - PETModel, + PET_MIDFRAME_ERROR_MSG, + PET_OBJECT_ERROR_MSG, + BSplinePETModel, ) -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([1.0, 2.0, 3.0, 4.0, 5.0])) -@pytest.mark.parametrize( - "none_params", [("timepoints",), ("xlim",), ("timepoints=timepoints", "xlim")] -) -def test_petmodel_init_parameters_error(request, setup_random_pet_data, none_params): - rng = request.node.rng - pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data - - pet_obj = PET( - dataobj=pet_dataobj, - affine=affine, - brainmask=brainmask_dataobj, - midframe=midframe, - total_duration=total_duration, - ) - - timepoints = rng.random(len(pet_obj)) if "timepoints" in none_params else None - xlim = rng.random(1).item() if "xlim" in none_params else None +def test_pet_base_model(): + from nifreeze.model.pet import BasePETModel with pytest.raises( - ValueError, + TypeError, match=re.escape( - TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG.format(timepoints=timepoints, xlim=xlim) + "Can't instantiate abstract class BasePETModel without an implementation " + "for abstract method 'fit_predict'" ), ): - PETModel(dataset=pet_obj, timepoints=timepoints, xlim=xlim) # type: ignore[arg-type] + BasePETModel(None, xlim=None) # type: ignore[abstract, arg-type] @pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([1.0, 2.0, 3.0, 4.0, 5.0])) -def test_petmodel_init_timepoint_value_error(request, setup_random_pet_data): +def test_petmodel_init_dataset_error(request, setup_random_pet_data, monkeypatch): rng = request.node.rng - pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + pet_dataobj, _affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data - pet_obj = PET( - dataobj=pet_dataobj, - affine=affine, - brainmask=brainmask_dataobj, - midframe=midframe, - total_duration=total_duration, - ) + xlim = rng.random(pet_dataobj.shape[-1]) - timepoints = rng.random(len(pet_obj)) - xlim = rng.random(1).item() + # Create a dummy dataset class without attributes + class AttributelessPETDataset(BaseDataset[np.ndarray]): + def __init__(self, dataobj, affine, brainmask): + self.dataobj = dataobj + self.affine = affine + self.brainmask = brainmask - timepoints[0] = DEFAULT_TIMEPOINT_TOL - sys.float_info.epsilon + # Monkeypatch the PET dataset + monkeypatch.setattr("nifreeze.data.pet.PET", AttributelessPETDataset) - with pytest.raises( - ValueError, match=FIRST_TIMEPOINT_VALUE_ERROR_MSG.format(timepoints=timepoints) - ): - PETModel(dataset=pet_obj, timepoints=timepoints, xlim=xlim) + pet_obj_attless = AttributelessPETDataset( + dataobj=pet_dataobj, affine=_affine, brainmask=brainmask_dataobj + ) + with pytest.raises(TypeError, match=PET_OBJECT_ERROR_MSG): + BSplinePETModel(dataset=pet_obj_attless, xlim=xlim) # type:ignore[arg-type] -@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([1.0, 2.0, 3.0, 4.0, 5.0])) -def test_petmodel_parameter_consistency_error(request, setup_random_pet_data): - rng = request.node.rng - pet_dataobj, affine, brainmask_dataobj, _, midframe, total_duration = setup_random_pet_data + # Create a dummy dataset class without total_duration data + class MidframePETDataset(BaseDataset[np.ndarray]): + def __init__(self, dataobj, affine, brainmask): + self.dataobj = dataobj + self.affine = affine + self.brainmask = brainmask + self.midframe = np.ones_like(dataobj.shape[-1]) - pet_obj = PET( - dataobj=pet_dataobj, - affine=affine, - brainmask=brainmask_dataobj, - midframe=midframe, - total_duration=total_duration, + # Monkeypatch the PET dataset + monkeypatch.setattr("nifreeze.data.pet.PET", MidframePETDataset) + + pet_obj_midf = MidframePETDataset( + dataobj=pet_dataobj, affine=_affine, brainmask=brainmask_dataobj ) - xlim = rng.random(1).item() - timepoints = np.ones(len(pet_obj)) * DEFAULT_TIMEPOINT_TOL - timepoints[-1] = xlim - DEFAULT_TIMEPOINT_TOL + sys.float_info.epsilon + with pytest.raises(TypeError, match=PET_OBJECT_ERROR_MSG): + BSplinePETModel(dataset=pet_obj_midf, xlim=xlim) # type:ignore[arg-type] - with pytest.raises( - ValueError, - match=re.escape( - LAST_TIMEPOINT_CONSISTENCY_ERROR_MSG.format(timepoints=timepoints, xlim=xlim) - ), - ): - PETModel(dataset=pet_obj, timepoints=timepoints, xlim=xlim) + # Create a dummy dataset class without midframe data + class TotalDurationPETDataset(BaseDataset[np.ndarray]): + def __init__(self, dataobj, affine, brainmask): + self.dataobj = dataobj + self.affine = affine + self.brainmask = brainmask + self.total_duration = np.ones_like(dataobj.shape[-1]) + + # Monkeypatch the PET dataset + monkeypatch.setattr("nifreeze.data.pet.PET", TotalDurationPETDataset) + + pet_obj_totald = TotalDurationPETDataset( + dataobj=pet_dataobj, affine=_affine, brainmask=brainmask_dataobj + ) + + with pytest.raises(ValueError, match=PET_MIDFRAME_ERROR_MSG): + BSplinePETModel(dataset=pet_obj_totald, xlim=xlim) # type:ignore[arg-type] @pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0])) @@ -127,20 +121,15 @@ def test_petmodel_fit_predict(setup_random_pet_data): total_duration=total_duration, ) - model = PETModel( - dataset=pet_obj, - timepoints=pet_obj.midframe, - xlim=pet_obj.total_duration, - smooth_fwhm=0, - thresh_pct=0, - ) + model = BSplinePETModel(dataset=pet_obj, smooth_fwhm=0, thresh_pct=0) # Fit on all data model.fit_predict(None) assert model.is_fitted # Predict at a specific timepoint - vol = model.fit_predict(pet_obj.midframe[2]) + index = 2 + vol = model.fit_predict(index) assert vol is not None assert vol.shape == pet_obj.shape3d assert vol.dtype == pet_obj.dataobj.dtype