diff --git a/src/dodal/beamlines/i05.py b/src/dodal/beamlines/i05.py index cba83c2b6cb..1ad7d939e36 100644 --- a/src/dodal/beamlines/i05.py +++ b/src/dodal/beamlines/i05.py @@ -2,8 +2,9 @@ from dodal.common.beamlines.beamline_utils import set_beamline as set_utils_beamline from dodal.device_manager import DeviceManager from dodal.devices.beamlines.i05 import I05Goniometer -from dodal.devices.beamlines.i05_shared import M4M5Mirror +from dodal.devices.beamlines.i05_shared import LensMode, M4M5Mirror, PassEnergy from dodal.devices.common_mirror import XYZSwitchingMirror +from dodal.devices.electron_analyser.mbs import MbsAnalyserDriverIO from dodal.devices.hutch_shutter import HutchShutter from dodal.devices.temperture_controller import Lakeshore336 from dodal.log import set_beamline as set_log_beamline @@ -46,3 +47,12 @@ def sa() -> I05Goniometer: y_infix="SAY", z_infix="SAZ", ) + + +@devices.factory +def analyser_driver() -> MbsAnalyserDriverIO: + return MbsAnalyserDriverIO[LensMode, PassEnergy]( + prefix=f"{PREFIX.beamline_prefix}-EA-DET-02:CAM:", + lens_mode_type=LensMode, + pass_energy_type=PassEnergy, + ) diff --git a/src/dodal/beamlines/i05_1.py b/src/dodal/beamlines/i05_1.py index 763b64988fb..0623dc15c26 100644 --- a/src/dodal/beamlines/i05_1.py +++ b/src/dodal/beamlines/i05_1.py @@ -2,8 +2,9 @@ from dodal.common.beamlines.beamline_utils import set_beamline as set_utils_beamline from dodal.device_manager import DeviceManager from dodal.devices.beamlines.i05_1 import XYZAzimuthPolarDefocusStage -from dodal.devices.beamlines.i05_shared import Mj7j8Mirror +from dodal.devices.beamlines.i05_shared import LensMode, Mj7j8Mirror, PassEnergy from dodal.devices.common_mirror import XYZPiezoSwitchingMirror +from dodal.devices.electron_analyser.mbs import MbsAnalyserDriverIO from dodal.devices.hutch_shutter import HutchShutter from dodal.log import set_beamline as set_log_beamline from dodal.utils import BeamlinePrefix, get_beamline_name @@ -35,3 +36,12 @@ def nano_shutter() -> HutchShutter: def sm() -> XYZAzimuthPolarDefocusStage: """Sample Manipulator.""" return XYZAzimuthPolarDefocusStage(prefix=f"{PREFIX.beamline_prefix}-EA-SM-01:") + + +@devices.factory +def analyser_driver() -> MbsAnalyserDriverIO[LensMode, PassEnergy]: + return MbsAnalyserDriverIO[LensMode, PassEnergy]( + prefix=f"{PREFIX.beamline_prefix}-EA-DET-04:CAM:", + lens_mode_type=LensMode, + pass_energy_type=PassEnergy, + ) diff --git a/src/dodal/common/data_util.py b/src/dodal/common/data_util.py index 7a0b3ab7ad2..989641940f6 100644 --- a/src/dodal/common/data_util.py +++ b/src/dodal/common/data_util.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from os.path import isabs, isfile, join, split -from typing import Protocol, Self, TypeVar +from typing import Generic, Self, TypeVar from pydantic import BaseModel @@ -36,11 +37,14 @@ def save_class_to_json_file(model: BaseModel, file: str) -> None: f.write(model.model_dump_json()) -class JsonModelLoader(Protocol[TBaseModel]): - def __call__(self, file: str | None = None) -> TBaseModel: ... +LoadModelFromFile = Callable[[str], TBaseModel] -class JsonLoaderConfig(BaseModel): +def json_model_loader(model: type[TBaseModel]) -> LoadModelFromFile[TBaseModel]: + return lambda file: load_json_file_to_class(model, file) + + +class ModelLoaderConfig(BaseModel): default_path: str default_file: str | None @@ -60,15 +64,40 @@ def update_config_from_file(self, new_file: str) -> None: self.default_path, self.default_file = split(new_file) -def json_model_loader( - model: type[TBaseModel], config: JsonLoaderConfig | None = None -) -> JsonModelLoader[TBaseModel]: - """Factory to create a function that loads a json file into a configured pydantic - model and with optional configuration for default path and file to use. +class ModelLoader(Generic[TBaseModel]): + """A generic model loader that can be configured with any kind of method to read in + a file and convert the data into a pydantic model. It also takes configuration + to handle the file paths before they are passed to the method to convert to a + pydantic model. """ - def load_json(file: str | None = None) -> TBaseModel: - """Load a json file and return it is as the configured pydantic model. + def __init__( + self, + load_model_from_file: LoadModelFromFile[TBaseModel], + cfg: ModelLoaderConfig | None = None, + ): + self._load_model_from_file = load_model_from_file + self._cfg = cfg + + def _handle_file_path(self, file: str | None) -> str: + """Handle the file path based on the configuration provided. If a default path + is given and a relative file path used, it will join the default path and + relative path together. If a default file is configured, then you don't need to + provide a file when using __call__. + """ + if file is None: + if self._cfg is None or self._cfg.default_file is None: + raise RuntimeError( + "Model loader has no default file configured and no file was provided." + ) + file = self._cfg.default_file + + if not isabs(file) and self._cfg is not None: + file = join(self._cfg.default_path, file) + return file + + def __call__(self, file: str | None = None) -> TBaseModel: + """Load a file and return it is as the configured pydantic model. Args: file (str, optional): The file to load into a pydantic class. If None @@ -77,16 +106,5 @@ def load_json(file: str | None = None) -> TBaseModel: Returns: An instance of the configurated pydantic base_model type. """ - if file is None: - if config is None or config.default_file is None: - raise RuntimeError( - f"{model.__name__} loader has no default file configured " - "and no file was provided." - ) - file = config.default_file - - if not isabs(file) and config is not None: - file = join(config.default_path, file) - return load_json_file_to_class(model, file) - - return load_json + file = self._handle_file_path(file) + return self._load_model_from_file(file) diff --git a/src/dodal/devices/beamlines/i05/__init__.py b/src/dodal/devices/beamlines/i05/__init__.py index 760a9c47a47..3675d760ef6 100644 --- a/src/dodal/devices/beamlines/i05/__init__.py +++ b/src/dodal/devices/beamlines/i05/__init__.py @@ -1,5 +1,3 @@ from .i05_motors import I05Goniometer -__all__ = [ - "I05Goniometer", -] +__all__ = ["I05Goniometer"] diff --git a/src/dodal/devices/beamlines/i05_shared/__init__.py b/src/dodal/devices/beamlines/i05_shared/__init__.py index c25434d7eae..5014d0d7a48 100644 --- a/src/dodal/devices/beamlines/i05_shared/__init__.py +++ b/src/dodal/devices/beamlines/i05_shared/__init__.py @@ -1,18 +1,14 @@ -from dodal.devices.beamlines.i05_shared.apple_knot_constants import ( +from .apple_knot_constants import ( APPLE_KNOT_EXCLUSION_ZONES, energy_to_gap_converter, energy_to_phase_converter, ) -from dodal.devices.beamlines.i05_shared.compound_motors import PolynomCompoundMotors -from dodal.devices.beamlines.i05_shared.enums import ( - Grating, - M3MJ6Mirror, - M4M5Mirror, - Mj7j8Mirror, -) +from .compound_motors import PolynomCompoundMotors +from .enums import Grating, LensMode, M3MJ6Mirror, M4M5Mirror, Mj7j8Mirror, PassEnergy __all__ = [ "Grating", + "LensMode", "Mj7j8Mirror", "M3MJ6Mirror", "M4M5Mirror", @@ -20,4 +16,5 @@ "energy_to_gap_converter", "energy_to_phase_converter", "APPLE_KNOT_EXCLUSION_ZONES", + "PassEnergy", ] diff --git a/src/dodal/devices/beamlines/i05_shared/enums.py b/src/dodal/devices/beamlines/i05_shared/enums.py index dd97882c22d..be109fd58f3 100644 --- a/src/dodal/devices/beamlines/i05_shared/enums.py +++ b/src/dodal/devices/beamlines/i05_shared/enums.py @@ -27,3 +27,22 @@ class Mj7j8Mirror(StrictEnum): MJ8 = "MJ8" MJ7 = "MJ7" REFERENCE = "Reference" + + +class LensMode(StrictEnum): + L4_ANG0_D8 = "L4Ang0d8" + L4_ANG1_D6 = "L4Ang1d6" + L4_ANG3_D9 = "L4Ang3d9" + L4M_ANG0_D7 = "L4MAng0d7" + L4M_SPAT_5 = "L4MSpat5" + + +class PassEnergy(StrictEnum): + PE001 = "PE001" + PE002 = "PE002" + PE005 = "PE005" + PE010 = "PE010" + PE020 = "PE020" + PE050 = "PE050" + PE100 = "PE100" + PE200 = "PE200" diff --git a/src/dodal/devices/electron_analyser/base/base_driver_io.py b/src/dodal/devices/electron_analyser/base/base_driver_io.py index 3849ae4b5f2..0a77ac56da5 100644 --- a/src/dodal/devices/electron_analyser/base/base_driver_io.py +++ b/src/dodal/devices/electron_analyser/base/base_driver_io.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Generic, TypeAlias, TypeVar +from dataclasses import dataclass +from typing import ClassVar, Generic, TypeAlias, TypeVar import numpy as np from bluesky.protocols import Movable @@ -29,10 +30,27 @@ TPassEnergy, ) -AnyPsuMode: TypeAlias = SupersetEnum | StrictEnum +AnyPsuMode: TypeAlias = SupersetEnum | StrictEnum | str TPsuMode = TypeVar("TPsuMode", bound=AnyPsuMode) -_PSU = "PSU_MODE" + +@dataclass(frozen=True) +class ElectronAnalyserPVConfig: + """Configuration for PV's. Temporary work around until PV's are standardised between + beamlines. + """ + + low_energy: str = "LOW_ENERGY" + high_energy: str = "HIGH_ENERGY" + centre_energy: str = "CENTRE_ENERGY" + slices: str = "SLICES" + lens_mode: str = "LENS_MODE" + pass_energy: str = "PASS_ENERGY" + energy_step: str = "STEP_SIZE" + iterations: str = "NumExposures" + acquisition_mode: str = "ACQ_MODE" + psu_mode: str = "PSU_MODE" + total_steps: str = "TOTAL_POINTS_RBV" class AbstractAnalyserDriverIO( @@ -57,10 +75,11 @@ class AbstractAnalyserDriverIO( pass_energy_type (type[TPassEnergy]): Can be enum or float, depending on electron analyser model. If enum, it determines the available pass energies for this device. - psu_suffix (str, optional): The psu infix to connect to EPICS. Defaults to PSU_MODE. name (str, optional): Name of the device. """ + PV_CFG: ClassVar[ElectronAnalyserPVConfig] + def __init__( self, prefix: str, @@ -68,7 +87,6 @@ def __init__( lens_mode_type: type[TLensMode], psu_mode_type: type[TPsuMode], pass_energy_type: type[TPassEnergy], - psu_suffix: str = _PSU, name: str = "", ) -> None: self.acquisition_mode_type = acquisition_mode_type @@ -76,16 +94,9 @@ def __init__( self.psu_mode_type = psu_mode_type self.pass_energy_type = pass_energy_type - # must call first to initiate parent variables + # Must call first to initiate parent variables super().__init__(prefix=prefix, name=name) - with self.add_children_as_readables(): - self.image = epics_signal_r(Array1D[np.float64], prefix + "IMAGE") - self.spectrum = epics_signal_r(Array1D[np.float64], prefix + "INT_SPECTRUM") - self.total_intensity = derived_signal_r( - self._calculate_total_intensity, spectrum=self.spectrum - ) - with self.add_children_as_readables(StandardReadableFormat.CONFIG_SIGNAL): # Read once per scan after data acquired # Used for setting up region data acquisition @@ -93,20 +104,27 @@ def __init__( self.energy_mode = soft_signal_rw( EnergyMode, initial_value=EnergyMode.KINETIC ) - self.low_energy = epics_signal_rw(float, prefix + "LOW_ENERGY") - self.centre_energy = epics_signal_rw(float, prefix + "CENTRE_ENERGY") - self.high_energy = epics_signal_rw(float, prefix + "HIGH_ENERGY") - self.slices = epics_signal_rw(int, prefix + "SLICES") - self.lens_mode = epics_signal_rw(lens_mode_type, prefix + "LENS_MODE") - self.pass_energy = epics_signal_rw(pass_energy_type, prefix + "PASS_ENERGY") - self.energy_step = epics_signal_rw(float, prefix + "STEP_SIZE") - self.iterations = epics_signal_rw(int, prefix + "NumExposures") + self.low_energy = epics_signal_rw(float, prefix + self.PV_CFG.low_energy) + self.centre_energy = epics_signal_rw( + float, prefix + self.PV_CFG.centre_energy + ) + self.high_energy = epics_signal_rw(float, prefix + self.PV_CFG.high_energy) + self.slices = epics_signal_rw(int, prefix + self.PV_CFG.slices) + self.lens_mode = epics_signal_rw( + lens_mode_type, prefix + self.PV_CFG.lens_mode + ) + self.pass_energy = epics_signal_rw( + pass_energy_type, prefix + self.PV_CFG.pass_energy + ) + self.energy_step = epics_signal_rw(float, prefix + self.PV_CFG.energy_step) + self.iterations = epics_signal_rw(int, prefix + self.PV_CFG.iterations) self.acquisition_mode = epics_signal_rw( - acquisition_mode_type, prefix + "ACQ_MODE" + acquisition_mode_type, prefix + self.PV_CFG.acquisition_mode ) - # This is used by each electron analyser, however it depends on the electron - # analyser type to know if is moved with region settings. - self.psu_mode = epics_signal_rw(psu_mode_type, prefix + psu_suffix) + # This is used by each electron analyser, however it is not writeable for + # all types and it depends on the electron analyser type to know if is moved + # with region settings. + self.psu_mode = epics_signal_r(psu_mode_type, prefix + self.PV_CFG.psu_mode) # This is defined in the parent class, add it as readable configuration. self.add_readables([self.acquire_time], StandardReadableFormat.CONFIG_SIGNAL) @@ -115,7 +133,7 @@ def __init__( # NOT used for setting up region data acquisition. self.energy_axis = self._create_energy_axis_signal(prefix) self.angle_axis = self._create_angle_axis_signal(prefix) - self.total_steps = epics_signal_r(int, prefix + "TOTAL_POINTS_RBV") + self.total_steps = epics_signal_r(int, prefix + self.PV_CFG.total_steps) self.total_time = derived_signal_r( self._calculate_total_time, "s", diff --git a/src/dodal/devices/electron_analyser/base/base_region.py b/src/dodal/devices/electron_analyser/base/base_region.py index 2496a37a176..29563bd0ec4 100644 --- a/src/dodal/devices/electron_analyser/base/base_region.py +++ b/src/dodal/devices/electron_analyser/base/base_region.py @@ -74,7 +74,6 @@ class BaseRegion( name: str = "New_region" enabled: bool = False - slices: int = 1 iterations: int = 1 excitation_energy_source: SelectedSource = SelectedSource.SOURCE1 # These ones we need subclasses to provide sensible default values @@ -82,8 +81,8 @@ class BaseRegion( pass_energy: TPassEnergy acquisition_mode: TAcquisitionMode low_energy: float - centre_energy: float high_energy: float + centre_energy: float acquire_time: float energy_step: float # in eV energy_mode: EnergyMode = EnergyMode.KINETIC diff --git a/src/dodal/devices/electron_analyser/mbs/__init__.py b/src/dodal/devices/electron_analyser/mbs/__init__.py new file mode 100644 index 00000000000..51bf30be1db --- /dev/null +++ b/src/dodal/devices/electron_analyser/mbs/__init__.py @@ -0,0 +1,5 @@ +from .mbs_driver_io import MbsAnalyserDriverIO +from .mbs_enums import AcquisitionMode +from .mbs_region import MbsRegion, MbsSequence + +__all__ = ["MbsAnalyserDriverIO", "AcquisitionMode", "MbsRegion", "MbsSequence"] diff --git a/src/dodal/devices/electron_analyser/mbs/mbs_driver_io.py b/src/dodal/devices/electron_analyser/mbs/mbs_driver_io.py new file mode 100644 index 00000000000..037129a07aa --- /dev/null +++ b/src/dodal/devices/electron_analyser/mbs/mbs_driver_io.py @@ -0,0 +1,99 @@ +import asyncio +from typing import Generic + +import numpy as np +from ophyd_async.core import Array1D, AsyncStatus, SignalR, StandardReadableFormat +from ophyd_async.epics.core import epics_signal_r, epics_signal_rw + +from dodal.devices.electron_analyser.base.base_driver_io import ( + AbstractAnalyserDriverIO, + ElectronAnalyserPVConfig, +) +from dodal.devices.electron_analyser.base.base_region import TLensMode, TPassEnergy +from dodal.devices.electron_analyser.mbs.mbs_enums import AcquisitionMode +from dodal.devices.electron_analyser.mbs.mbs_region import MbsRegion + + +class MbsAnalyserDriverIO( + AbstractAnalyserDriverIO[ + MbsRegion[TLensMode, TPassEnergy], + AcquisitionMode, + TLensMode, + str, + TPassEnergy, + ], + Generic[TLensMode, TPassEnergy], +): + PV_CFG = ElectronAnalyserPVConfig( + lens_mode="LensMode", + pass_energy="PassEnergy", + acquisition_mode="AcqMode", + energy_step="StepSize", + low_energy="StartKE", + centre_energy="CentreKE", + high_energy="EndKE", + psu_mode="PsuMode_RBV", + slices="NumSlice", + iterations="NumExposures", + total_steps="NumSteps", + ) + + def __init__( + self, + prefix: str, + lens_mode_type: type[TLensMode], + pass_energy_type: type[TPassEnergy], + name: str = "", + ): + super().__init__( + prefix=prefix, + acquisition_mode_type=AcquisitionMode, + lens_mode_type=lens_mode_type, + psu_mode_type=str, + pass_energy_type=pass_energy_type, + name=name, + ) + with self.add_children_as_readables(StandardReadableFormat.CONFIG_SIGNAL): + self.deflector_x = epics_signal_rw(float, prefix + "DeflX") + self.dither_steps = epics_signal_rw(int, prefix + "DithSteps") + self.spin_offset = epics_signal_rw(float, prefix + "SpinOffs") + # Region origin + self.min_x = epics_signal_rw(int, prefix + "MinX") + self.min_y = epics_signal_rw(int, prefix + "MinY") + # Sensor size + self.max_x = epics_signal_r(int, prefix + "MaxSizeX_RBV") + self.max_y = epics_signal_r(int, prefix + "MaxSizeY_RBV") + + region_size = [self.array_size_x, self.array_size_y] + self.add_readables( + region_size + [self.acquire_period], StandardReadableFormat.CONFIG_SIGNAL + ) + + def _create_angle_axis_signal(self, prefix: str) -> SignalR[Array1D[np.float64]]: + return epics_signal_r(Array1D[np.float64], prefix + "LensScale_RBV") + + def _create_energy_axis_signal(self, prefix: str) -> SignalR[Array1D[np.float64]]: + return epics_signal_r(Array1D[np.float64], prefix + "EScale_RBV") + + @AsyncStatus.wrap + async def set(self, epics_region: MbsRegion[TLensMode, TPassEnergy]): + coroutines = [ + self.region_name.set(epics_region.name), + self.energy_mode.set(epics_region.energy_mode), + self.acquisition_mode.set(epics_region.acquisition_mode), + self.pass_energy.set(epics_region.pass_energy), + self.lens_mode.set(epics_region.lens_mode), + # Start stop and centre energy are always set even though start and stop are + # used in swept and centre is used in fixed because the readback values are + # saved into the data file. + self.low_energy.set(epics_region.low_energy), + self.centre_energy.set(epics_region.centre_energy), + self.high_energy.set(epics_region.high_energy), + self.deflector_x.set(epics_region.deflector_x), # go in sub class? + self.acquire_time.set(epics_region.acquire_time), + self.iterations.set(epics_region.iterations), + ] + if epics_region.acquisition_mode == AcquisitionMode.SWEPT: + coroutines.append(self.energy_step.set(epics_region.energy_step)) + + await asyncio.gather(*coroutines) diff --git a/src/dodal/devices/electron_analyser/mbs/mbs_enums.py b/src/dodal/devices/electron_analyser/mbs/mbs_enums.py new file mode 100644 index 00000000000..d7f18f88875 --- /dev/null +++ b/src/dodal/devices/electron_analyser/mbs/mbs_enums.py @@ -0,0 +1,7 @@ +from ophyd_async.core import StrictEnum + + +class AcquisitionMode(StrictEnum): + FIXED = "Fixed" + SWEPT = "Swept" + DITHER = "Dither" diff --git a/src/dodal/devices/electron_analyser/mbs/mbs_region.py b/src/dodal/devices/electron_analyser/mbs/mbs_region.py new file mode 100644 index 00000000000..889a2e42d9e --- /dev/null +++ b/src/dodal/devices/electron_analyser/mbs/mbs_region.py @@ -0,0 +1,70 @@ +from os.path import basename, splitext +from typing import Generic, Self + +import xmltodict +from ophyd_async.core import StrictEnum +from pydantic import Field, field_validator + +from dodal.devices.electron_analyser.base.base_region import ( + BaseRegion, + BaseSequence, + TLensMode, + TPassEnergy, +) +from dodal.devices.electron_analyser.mbs.mbs_enums import AcquisitionMode + + +class MbsRegion( + BaseRegion[AcquisitionMode, TLensMode, TPassEnergy], + Generic[TLensMode, TPassEnergy], +): + # Override base class with defaults + lens_mode: TLensMode + pass_energy: TPassEnergy + acquisition_mode: AcquisitionMode = AcquisitionMode.FIXED + low_energy: float = Field(default=800, alias="start_energy") + high_energy: float = Field(default=850, alias="end_energy") + centre_energy: float = Field( + default_factory=lambda data: (data["high_energy"] + data["low_energy"]) / 2 + ) + acquire_time: float = Field(default=1.0, alias="time_per_step") + energy_step: float = Field(default=0.1, alias="step_energy") + # Default is True as mbs ususally only uses one region. + enabled: bool = True + + # Specific to this class + deflector_x: float = 0 + + @field_validator("pass_energy", mode="before") + @classmethod + def convert_pass_energy(cls, value): + # Allow for using enum or int. + if isinstance(value, StrictEnum): + return value + return f"PE{int(value):03d}" + + @classmethod + def from_xml(cls, file: str) -> Self: + name = splitext(basename(file))[0] + with open(file) as f: + data = xmltodict.parse(f.read()) + region = cls.model_validate(data["ARPESScanBean"]) + region.name = name + # Convert from meV to eV + region.energy_step = round(region.energy_step / 1000.0, 6) + return region + + +class MbsSequence( + BaseSequence[MbsRegion[TLensMode, TPassEnergy]], Generic[TLensMode, TPassEnergy] +): + @classmethod + def from_xml(cls, file: str) -> Self: + regions = [] + # Must find the region type annotation because reconstructing the generic + # manually doing MbsRegion[TLensMode, TPassEnergy].from_xml(file) will not work. + annotation = cls.model_fields["regions"].annotation + assert annotation is not None + region_type = annotation.__args__[0] + regions = [region_type.from_xml(file)] + return cls.model_validate({"regions": regions}) diff --git a/src/dodal/devices/electron_analyser/specs/specs_driver_io.py b/src/dodal/devices/electron_analyser/specs/specs_driver_io.py index 6911cf65356..2c484b09eb0 100644 --- a/src/dodal/devices/electron_analyser/specs/specs_driver_io.py +++ b/src/dodal/devices/electron_analyser/specs/specs_driver_io.py @@ -5,21 +5,34 @@ from ophyd_async.core import ( Array1D, AsyncStatus, + DeviceMock, SignalR, StandardReadableFormat, + callback_on_mock_put, + default_mock_class, derived_signal_r, + set_mock_value, ) -from ophyd_async.epics.core import epics_signal_r, epics_signal_rw +from ophyd_async.epics.core import epics_signal_r, epics_signal_rw, epics_signal_w from dodal.devices.electron_analyser.base.base_driver_io import ( - _PSU, AbstractAnalyserDriverIO, + ElectronAnalyserPVConfig, ) from dodal.devices.electron_analyser.base.base_region import TLensMode, TPsuMode from dodal.devices.electron_analyser.specs.specs_enums import AcquisitionMode from dodal.devices.electron_analyser.specs.specs_region import SpecsRegion +class MockSpecsAnalyserDriverIO(DeviceMock["SpecsAnalyserDriverIO"]): + async def connect(self, device: "SpecsAnalyserDriverIO"): + def _sync_psu_mode_rbv(value): + set_mock_value(device.psu_mode, value) + + callback_on_mock_put(device.psu_mode_w, _sync_psu_mode_rbv) + + +@default_mock_class(MockSpecsAnalyserDriverIO) class SpecsAnalyserDriverIO( AbstractAnalyserDriverIO[ SpecsRegion[TLensMode, TPsuMode], @@ -30,12 +43,13 @@ class SpecsAnalyserDriverIO( ], Generic[TLensMode, TPsuMode], ): + PV_CFG = ElectronAnalyserPVConfig() + def __init__( self, prefix: str, lens_mode_type: type[TLensMode], psu_mode_type: type[TPsuMode], - psu_suffix: str = _PSU, name: str = "", ) -> None: with self.add_children_as_readables(StandardReadableFormat.CONFIG_SIGNAL): @@ -50,6 +64,7 @@ def __init__( self.energy_channels = epics_signal_r( int, prefix + "TOTAL_POINTS_ITERATION_RBV" ) + self.psu_mode_w = epics_signal_w(psu_mode_type, prefix + self.PV_CFG.psu_mode) super().__init__( prefix=prefix, @@ -57,7 +72,6 @@ def __init__( lens_mode_type=lens_mode_type, psu_mode_type=psu_mode_type, pass_energy_type=float, - psu_suffix=psu_suffix, name=name, ) @@ -74,7 +88,7 @@ async def set(self, epics_region: SpecsRegion[TLensMode, TPsuMode]): self.iterations.set(epics_region.iterations), self.acquisition_mode.set(epics_region.acquisition_mode), self.snapshot_values.set(epics_region.values), - self.psu_mode.set(epics_region.psu_mode), + self.psu_mode_w.set(epics_region.psu_mode), self.energy_mode.set(epics_region.energy_mode), ) if epics_region.acquisition_mode == AcquisitionMode.FIXED_TRANSMISSION: diff --git a/src/dodal/devices/electron_analyser/specs/specs_region.py b/src/dodal/devices/electron_analyser/specs/specs_region.py index cfb6bc426fc..5204997fb18 100644 --- a/src/dodal/devices/electron_analyser/specs/specs_region.py +++ b/src/dodal/devices/electron_analyser/specs/specs_region.py @@ -26,6 +26,7 @@ class SpecsRegion( energy_step: float = Field(default=0.1, alias="step_energy") # Specific to this class + slices: int = 1 values: int = 1 psu_mode: TPsuMode estimated_time_in_ms: float = 0 @@ -34,4 +35,4 @@ class SpecsRegion( class SpecsSequence( BaseSequence[SpecsRegion[TLensMode, TPsuMode]], Generic[TLensMode, TPsuMode] ): - regions: list[SpecsRegion[TLensMode, TPsuMode]] = Field(default_factory=lambda: []) + pass diff --git a/src/dodal/devices/electron_analyser/vgscienta/vgscienta_driver_io.py b/src/dodal/devices/electron_analyser/vgscienta/vgscienta_driver_io.py index 4756baec460..cce7f6023c8 100644 --- a/src/dodal/devices/electron_analyser/vgscienta/vgscienta_driver_io.py +++ b/src/dodal/devices/electron_analyser/vgscienta/vgscienta_driver_io.py @@ -12,6 +12,7 @@ from dodal.devices.electron_analyser.base.base_driver_io import ( AbstractAnalyserDriverIO, + ElectronAnalyserPVConfig, ) from dodal.devices.electron_analyser.base.base_region import TLensMode, TPsuMode from dodal.devices.electron_analyser.vgscienta.vgscienta_enums import ( @@ -34,13 +35,14 @@ class VGScientaAnalyserDriverIO( ], Generic[TLensMode, TPsuMode, TPassEnergyEnum], ): + PV_CFG = ElectronAnalyserPVConfig(psu_mode="ELEMENT_SET") + def __init__( self, prefix: str, lens_mode_type: type[TLensMode], psu_mode_type: type[TPsuMode], pass_energy_type: type[TPassEnergyEnum], - psu_suffix: str = "ELEMENT_SET", name: str = "", ) -> None: with self.add_children_as_readables(StandardReadableFormat.CONFIG_SIGNAL): @@ -61,7 +63,6 @@ def __init__( lens_mode_type=lens_mode_type, psu_mode_type=psu_mode_type, pass_energy_type=pass_energy_type, - psu_suffix=psu_suffix, name=name, ) diff --git a/src/dodal/devices/electron_analyser/vgscienta/vgscienta_region.py b/src/dodal/devices/electron_analyser/vgscienta/vgscienta_region.py index 556da7bfa78..228edd54683 100644 --- a/src/dodal/devices/electron_analyser/vgscienta/vgscienta_region.py +++ b/src/dodal/devices/electron_analyser/vgscienta/vgscienta_region.py @@ -7,7 +7,6 @@ BaseRegion, BaseSequence, TLensMode, - TPsuMode, ) from dodal.devices.electron_analyser.vgscienta.vgscienta_enums import ( AcquisitionMode, @@ -31,6 +30,7 @@ class VGScientaRegion( acquire_time: float = Field(default=1.0, alias="step_time") energy_step: float = Field(default=200.0) # Specific to this class + slices: int = 1 total_steps: float = 13.0 total_time: float = 13.0 min_x: int = Field(alias="first_x_channel", default=1) @@ -57,9 +57,6 @@ def validate_pass_energy(cls, val): class VGScientaSequence( BaseSequence[VGScientaRegion[TLensMode, TPassEnergyEnum]], - Generic[TLensMode, TPsuMode, TPassEnergyEnum], + Generic[TLensMode, TPassEnergyEnum], ): - psu_mode: TPsuMode = Field(alias="element_set") - regions: list[VGScientaRegion[TLensMode, TPassEnergyEnum]] = Field( - default_factory=lambda: [] - ) + pass diff --git a/tests/common/test_data_util.py b/tests/common/test_data_util.py index c22e49a376a..656005eb130 100644 --- a/tests/common/test_data_util.py +++ b/tests/common/test_data_util.py @@ -4,8 +4,8 @@ from pydantic import BaseModel from dodal.common.data_util import ( - JsonLoaderConfig, - JsonModelLoader, + ModelLoader, + ModelLoaderConfig, json_model_loader, save_class_to_json_file, ) @@ -50,14 +50,15 @@ def tmp_file(tmp_path, other_model: MyModel) -> str: @pytest.fixture def load_json_model_with_default_file_only( default_tmp_file: str, -) -> JsonModelLoader[MyModel]: - return json_model_loader( - MyModel, JsonLoaderConfig.from_default_file(default_tmp_file) +) -> ModelLoader[MyModel]: + return ModelLoader( + json_model_loader(MyModel), + ModelLoaderConfig.from_default_file(default_tmp_file), ) def test_json_model_loader_with_configured_default_file_only( - load_json_model_with_default_file_only: JsonModelLoader[MyModel], + load_json_model_with_default_file_only: ModelLoader[MyModel], tmp_file: str, other_model: MyModel, default_model: MyModel, @@ -78,13 +79,15 @@ def test_json_model_loader_with_configured_default_file_only( @pytest.fixture def load_json_model_with_default_path_only( default_tmp_file: str, -) -> JsonModelLoader[MyModel]: +) -> ModelLoader[MyModel]: path, file = split(default_tmp_file) - return json_model_loader(MyModel, JsonLoaderConfig.from_default_path(path)) + return ModelLoader( + json_model_loader(MyModel), ModelLoaderConfig.from_default_path(path) + ) def test_load_json_model_with_configued_path_only( - load_json_model_with_default_path_only: JsonModelLoader[MyModel], + load_json_model_with_default_path_only: ModelLoader[MyModel], tmp_file: str, other_model: MyModel, ) -> None: @@ -99,7 +102,7 @@ def test_load_json_model_with_configued_path_only( with pytest.raises( RuntimeError, - match="MyModel loader has no default file configured and no file was provided.", + match="Model loader has no default file configured and no file was provided.", ): load_json_model_with_default_path_only() @@ -107,15 +110,16 @@ def test_load_json_model_with_configued_path_only( @pytest.fixture def load_json_model_with_default_path_and_file( default_tmp_file: str, -) -> JsonModelLoader[MyModel]: +) -> ModelLoader[MyModel]: path, file = split(default_tmp_file) - return json_model_loader( - MyModel, JsonLoaderConfig(default_path=path, default_file=file) + return ModelLoader( + json_model_loader(MyModel), + ModelLoaderConfig(default_path=path, default_file=file), ) def test_load_json_model_with_configued_path_and_file( - load_json_model_with_default_path_and_file: JsonModelLoader[MyModel], + load_json_model_with_default_path_and_file: ModelLoader[MyModel], tmp_file: str, other_model: MyModel, default_model: MyModel, @@ -135,18 +139,18 @@ def test_load_json_model_with_configued_path_and_file( @pytest.fixture -def load_json_model_no_config() -> JsonModelLoader[MyModel]: - return json_model_loader(MyModel) +def load_json_model_no_config() -> ModelLoader[MyModel]: + return ModelLoader(json_model_loader(MyModel)) def test_json_model_loader_with_no_config( - load_json_model_no_config: JsonModelLoader[MyModel], + load_json_model_no_config: ModelLoader[MyModel], tmp_file: str, other_model: MyModel, ) -> None: with pytest.raises( RuntimeError, - match="MyModel loader has no default file configured and no file was provided.", + match="Model loader has no default file configured and no file was provided.", ): load_json_model_no_config() @@ -163,8 +167,8 @@ def test_json_model_loader_with_no_config( def test_updating_config_updates_factory_function( default_tmp_file: str, tmp_file: str, default_model: MyModel, other_model: MyModel ) -> None: - config = JsonLoaderConfig.from_default_file(default_tmp_file) - model_loader = json_model_loader(MyModel, config) + config = ModelLoaderConfig.from_default_file(default_tmp_file) + model_loader = ModelLoader(json_model_loader(MyModel), config) # Test uses default file model_result = model_loader() @@ -178,11 +182,11 @@ def test_updating_config_updates_factory_function( @pytest.fixture def all_json_model_loaders( - load_json_model_with_default_file_only: JsonModelLoader[MyModel], - load_json_model_with_default_path_only: JsonModelLoader[MyModel], - load_json_model_with_default_path_and_file: JsonModelLoader[MyModel], - load_json_model_no_config: JsonModelLoader[MyModel], -) -> list[JsonModelLoader[MyModel]]: + load_json_model_with_default_file_only: ModelLoader[MyModel], + load_json_model_with_default_path_only: ModelLoader[MyModel], + load_json_model_with_default_path_and_file: ModelLoader[MyModel], + load_json_model_no_config: ModelLoader[MyModel], +) -> list[ModelLoader[MyModel]]: return [ load_json_model_with_default_file_only, load_json_model_with_default_path_only, @@ -193,7 +197,7 @@ def all_json_model_loaders( @pytest.mark.parametrize("loader_position", range(4)) def test_all_json_model_loader_raise_error_if_invalid_file( - all_json_model_loaders: list[JsonModelLoader[MyModel]], + all_json_model_loaders: list[ModelLoader[MyModel]], loader_position: int, ) -> None: json_loader = all_json_model_loaders[loader_position] diff --git a/tests/devices/electron_analyser/base/test_base_detector.py b/tests/devices/electron_analyser/base/test_base_detector.py index 9eceaeb6621..66a877e0f35 100644 --- a/tests/devices/electron_analyser/base/test_base_detector.py +++ b/tests/devices/electron_analyser/base/test_base_detector.py @@ -14,17 +14,20 @@ GenericSequence, ) from tests.devices.electron_analyser.helper_util import ( - TEST_SEQUENCE_REGION_NAMES, - get_test_sequence, + generate_fixture_regions_pair, + load_b07_specs_test_seq, + load_i09_vgscienta_test_seq, ) +DETECTOR_REGIONS_PAIR = [ + *generate_fixture_regions_pair("ew4000", load_i09_vgscienta_test_seq().regions), + *generate_fixture_regions_pair("b07b_specs150", load_b07_specs_test_seq().regions), +] -@pytest.fixture -def sequence(sim_detector: GenericElectronAnalyserDetector) -> GenericSequence: - return get_test_sequence(type(sim_detector)) - -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) +@pytest.mark.parametrize( + ("sim_detector", "region"), DETECTOR_REGIONS_PAIR, indirect=["sim_detector"] +) async def test_base_analyser_detector_describe_configuration( sim_detector: GenericElectronAnalyserDetector, region: GenericRegion ) -> None: @@ -46,14 +49,11 @@ async def test_base_analyser_detector_describe_configuration( await assert_value(sim_detector.binding_energy_axis, expected_binding_energy_axis) -# def test_analyser_detector_set_called_region_logic_setup_with_region( async def test_analyser_detector_stage( sim_detector: GenericElectronAnalyserDetector, ) -> None: sim_detector.sequence.stage = AsyncMock() - await sim_detector.stage() - sim_detector.sequence.stage.assert_awaited_once() @@ -61,23 +61,31 @@ async def test_analyser_detector_unstage( sim_detector: GenericElectronAnalyserDetector, ) -> None: sim_detector.sequence.unstage = AsyncMock() - await sim_detector.unstage() - sim_detector.sequence.unstage.assert_awaited_once() +@pytest.mark.parametrize( + ("sim_detector", "region"), DETECTOR_REGIONS_PAIR, indirect=["sim_detector"] +) def test_analyser_detector_set_called_region_logic_setup_with_region( sim_detector: GenericElectronAnalyserDetector, - sequence: GenericSequence, + region: GenericRegion, run_engine: RunEngine, ) -> None: - region = sequence.get_enabled_regions()[0] sim_detector._region_logic.setup_with_region = AsyncMock() run_engine(bps.mv(sim_detector, region), wait=True) sim_detector._region_logic.setup_with_region.assert_awaited_once_with(region) +@pytest.mark.parametrize( + ("sim_detector", "sequence"), + [ + pytest.param("ew4000", load_i09_vgscienta_test_seq()), + pytest.param("b07b_specs150", load_b07_specs_test_seq()), + ], + indirect=["sim_detector"], +) def test_analyser_read_configuration_is_unique_per_region( sim_detector: GenericElectronAnalyserDetector, sequence: GenericSequence, diff --git a/tests/devices/electron_analyser/base/test_base_driver_io.py b/tests/devices/electron_analyser/base/test_base_driver_io.py index 0f588658e62..a6968d48b0e 100644 --- a/tests/devices/electron_analyser/base/test_base_driver_io.py +++ b/tests/devices/electron_analyser/base/test_base_driver_io.py @@ -41,16 +41,3 @@ class AcquisitionModeTestEnum(StrictEnum): acq_datatype_name = acq_datatype.__name__ if acq_datatype is not None else "" with pytest.raises(FailedStatus, match=f"is not a valid {acq_datatype_name}"): run_engine(bps.mv(sim_driver.acquisition_mode, AcquisitionModeTestEnum.TEST_1)) - - -def test_driver_throws_error_with_wrong_psu_mode( - sim_driver: GenericAnalyserDriverIO, - run_engine: RunEngine, -) -> None: - class PsuModeTestEnum(StrictEnum): - TEST_1 = "Invalid mode" - - psu_datatype = sim_driver.psu_mode.datatype - psu_datatype_name = psu_datatype.__name__ if psu_datatype is not None else "" - with pytest.raises(FailedStatus, match=f"is not a valid {psu_datatype_name}"): - run_engine(bps.mv(sim_driver.psu_mode, PsuModeTestEnum.TEST_1)) diff --git a/tests/devices/electron_analyser/base/test_base_region.py b/tests/devices/electron_analyser/base/test_base_region.py index bcf7b595d98..ef5697c0c38 100644 --- a/tests/devices/electron_analyser/base/test_base_region.py +++ b/tests/devices/electron_analyser/base/test_base_region.py @@ -2,7 +2,7 @@ import pytest -from dodal.devices.beamlines import b07, b07_shared, i09 +from dodal.devices.beamlines import b07, b07_shared, i05_shared, i09 from dodal.devices.electron_analyser.base import ( BaseRegion, EnergyMode, @@ -12,41 +12,24 @@ to_binding_energy, to_kinetic_energy, ) -from dodal.devices.electron_analyser.specs import ( - SpecsRegion, - SpecsSequence, -) -from dodal.devices.electron_analyser.vgscienta import VGScientaRegion, VGScientaSequence +from dodal.devices.electron_analyser.mbs import MbsRegion +from dodal.devices.electron_analyser.specs import SpecsRegion +from dodal.devices.electron_analyser.vgscienta import VGScientaRegion from tests.devices.electron_analyser.helper_util import ( - TEST_SEQUENCE_REGION_NAMES, - get_test_sequence, + load_b07_specs_test_seq, + load_i05_mbs_test_xml_seq, + load_i09_vgscienta_test_seq, ) -@pytest.fixture( - params=[ - SpecsSequence[b07.LensMode, b07_shared.PsuMode], - VGScientaSequence[i09.LensMode, i09.PsuMode, i09.PassEnergy], +@pytest.mark.parametrize( + "sequence, expected_region_names", + [ + (load_b07_specs_test_seq(), ["New_Region", "New_Region1", "New_Region2"]), + (load_i09_vgscienta_test_seq(), ["New_Region", "New_Region1", "New_Region2"]), + (load_i05_mbs_test_xml_seq(), ["mbs_region1"]), ], ) -def sequence(request: pytest.FixtureRequest) -> GenericSequence: - return get_test_sequence(request.param) - - -@pytest.fixture -def expected_region_class(sequence: GenericSequence) -> type[BaseRegion]: - if isinstance(sequence, SpecsSequence): - return SpecsRegion[b07.LensMode, b07_shared.PsuMode] - elif isinstance(sequence, VGScientaSequence): - return VGScientaRegion[i09.LensMode, i09.PassEnergy] - raise TypeError(f"Unknown sequence type {type(sequence)}") - - -@pytest.fixture -def expected_region_names() -> list[str]: - return TEST_SEQUENCE_REGION_NAMES - - def test_sequence_get_expected_region_from_name( sequence: GenericSequence, expected_region_names: list[str] ) -> None: @@ -55,6 +38,33 @@ def test_sequence_get_expected_region_from_name( assert sequence.get_region_by_name("region name should not be in sequence") is None +@pytest.mark.parametrize( + "sequence, expected_enabled_region_names", + [ + (load_b07_specs_test_seq(), ["New_Region1", "New_Region2"]), + (load_i09_vgscienta_test_seq(), ["New_Region", "New_Region2"]), + (load_i05_mbs_test_xml_seq(), ["mbs_region1"]), + ], +) +def test_load_sequence_has_expected_enabled_region_names( + sequence: GenericSequence, expected_enabled_region_names: list[str] +) -> None: + assert sequence.get_enabled_region_names() == expected_enabled_region_names + for i, region in enumerate(sequence.get_enabled_regions()): + assert region.name == expected_enabled_region_names[i] + + +@pytest.mark.parametrize( + "sequence, expected_region_class", + [ + (load_b07_specs_test_seq(), SpecsRegion[b07.LensMode, b07_shared.PsuMode]), + (load_i09_vgscienta_test_seq(), VGScientaRegion[i09.LensMode, i09.PassEnergy]), + ( + load_i05_mbs_test_xml_seq(), + MbsRegion[i05_shared.LensMode, i05_shared.PassEnergy], + ), + ], +) def test_sequence_get_expected_region_type( sequence: GenericSequence, expected_region_class: type[TBaseRegion], @@ -69,27 +79,40 @@ def test_sequence_get_expected_region_type( ) +@pytest.mark.parametrize( + "sequence, expected_region_names", + [ + (load_b07_specs_test_seq(), ["New_Region", "New_Region1", "New_Region2"]), + (load_i09_vgscienta_test_seq(), ["New_Region", "New_Region1", "New_Region2"]), + (load_i05_mbs_test_xml_seq(), ["mbs_region1"]), + ], +) def test_sequence_get_expected_region_names( sequence: GenericSequence, expected_region_names: list[str] ) -> None: assert sequence.get_region_names() == expected_region_names -def test_region_kinetic_and_binding_energy( - sequence: GenericSequence, -) -> None: - for r in sequence.regions: - is_binding_energy = r.energy_mode == EnergyMode.BINDING - is_kinetic_energy = r.energy_mode == EnergyMode.KINETIC - assert r.is_binding_energy() == is_binding_energy - assert r.is_binding_energy() != is_kinetic_energy - assert r.is_kinetic_energy() == is_kinetic_energy - assert r.is_kinetic_energy() != is_binding_energy +ALL_REGION_TESTS_CASES = [ + *load_b07_specs_test_seq().regions, + *load_i09_vgscienta_test_seq().regions, + *load_i05_mbs_test_xml_seq().regions, +] + + +@pytest.mark.parametrize("region", ALL_REGION_TESTS_CASES) +def test_region_kinetic_and_binding_energy(region: GenericRegion) -> None: + is_binding_energy = region.energy_mode == EnergyMode.BINDING + is_kinetic_energy = region.energy_mode == EnergyMode.KINETIC + assert region.is_binding_energy() == is_binding_energy + assert region.is_binding_energy() != is_kinetic_energy + assert region.is_kinetic_energy() == is_kinetic_energy + assert region.is_kinetic_energy() != is_binding_energy @pytest.mark.parametrize("field", ["low_energy", "centre_energy", "high_energy"]) @pytest.mark.parametrize("copy", [True, False]) -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) +@pytest.mark.parametrize("region", ALL_REGION_TESTS_CASES) def test_each_energy_field_for_region_is_correct_when_switching_energy_modes( region: GenericRegion, field: str, copy: bool ) -> None: @@ -132,7 +155,7 @@ def test_each_energy_field_for_region_is_correct_when_switching_energy_modes( @pytest.mark.parametrize("copy", [True, False]) -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) +@pytest.mark.parametrize("region", ALL_REGION_TESTS_CASES) def test_region_prepare_for_epics(region: GenericRegion, copy: bool) -> None: # Patch switch_energy_mode so we can spy on if it was called while also returning # true function return value diff --git a/tests/devices/electron_analyser/base/test_detector_logic.py b/tests/devices/electron_analyser/base/test_detector_logic.py index be4f8aab176..1118cf630d6 100644 --- a/tests/devices/electron_analyser/base/test_detector_logic.py +++ b/tests/devices/electron_analyser/base/test_detector_logic.py @@ -13,30 +13,34 @@ from ophyd_async.epics.adcore import ADImageMode from ophyd_async.testing import assert_configuration, partial_reading -from dodal.devices.beamlines import b07, b07_shared, i09 +from dodal.devices.beamlines import b07, b07_shared, i05_shared, i09 from dodal.devices.electron_analyser.base import ( AbstractAnalyserDriverIO, AbstractEnergySource, BaseRegion, - BaseSequence, ) from dodal.devices.electron_analyser.base.detector_logic import ( ElectronAnalayserTriggerLogic, RegionLogic, ShutterCoordinatorADArmLogic, ) +from dodal.devices.electron_analyser.mbs import MbsAnalyserDriverIO from dodal.devices.electron_analyser.specs import SpecsAnalyserDriverIO from dodal.devices.electron_analyser.vgscienta import VGScientaAnalyserDriverIO from dodal.devices.fast_shutter import GenericFastShutter from dodal.devices.selectable_source import SourceSelector from tests.devices.electron_analyser.helper_util import ( - TEST_SEQUENCE_REGION_NAMES, - get_test_sequence, + generate_fixture_regions_pair, + load_b07_specs_test_seq, + load_i05_mbs_test_xml_seq, + load_i09_vgscienta_test_seq, ) @pytest.fixture -def vgscienta_driver() -> VGScientaAnalyserDriverIO: +def vgscienta_driver() -> VGScientaAnalyserDriverIO[ + i09.LensMode, i09.PsuMode, i09.PassEnergy +]: with init_devices(mock=True): vgscienta_driver = VGScientaAnalyserDriverIO( "TEST:", i09.LensMode, i09.PsuMode, i09.PassEnergy @@ -45,13 +49,22 @@ def vgscienta_driver() -> VGScientaAnalyserDriverIO: @pytest.fixture -def specs_driver() -> SpecsAnalyserDriverIO: +def specs_driver() -> SpecsAnalyserDriverIO[b07.LensMode, b07_shared.PsuMode]: with init_devices(mock=True): specs_driver = SpecsAnalyserDriverIO("TEST:", b07.LensMode, b07_shared.PsuMode) return specs_driver -@pytest.fixture(params=["specs_driver", "vgscienta_driver"]) +@pytest.fixture +def mbs_driver() -> MbsAnalyserDriverIO[i05_shared.LensMode, i05_shared.PassEnergy]: + with init_devices(mock=True): + mbs_driver = MbsAnalyserDriverIO( + "TEST:", i05_shared.LensMode, i05_shared.PassEnergy + ) + return mbs_driver + + +@pytest.fixture(params=["specs_driver", "vgscienta_driver", "mbs_driver"]) def driver(request: pytest.FixtureRequest) -> AbstractAnalyserDriverIO: return request.getfixturevalue(request.param) @@ -105,11 +118,6 @@ def energy_source(request: pytest.FixtureRequest) -> AbstractEnergySource: return request.getfixturevalue(request.param) -@pytest.fixture -def sequence(driver: AbstractAnalyserDriverIO) -> BaseSequence: - return get_test_sequence(type(driver)) - - @pytest.fixture def region_logic( driver: AbstractAnalyserDriverIO, @@ -119,12 +127,20 @@ def region_logic( return RegionLogic(driver, energy_source, source_selector) -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) +DRIVER_REGIONS_PAIR = [ + *generate_fixture_regions_pair("specs_driver", load_b07_specs_test_seq().regions), + *generate_fixture_regions_pair( + "vgscienta_driver", load_i09_vgscienta_test_seq().regions + ), + *generate_fixture_regions_pair("mbs_driver", load_i05_mbs_test_xml_seq().regions), +] + + +@pytest.mark.parametrize(("driver", "region"), DRIVER_REGIONS_PAIR, indirect=["driver"]) async def test_region_logic_setup_with_region_sets_region_for_epics_and_sets_driver( region: BaseRegion, region_logic: RegionLogic, ) -> None: - region_logic.driver.set = AsyncMock() # Patch switch_energy_mode so we can check on calls, but still run the real function @@ -152,7 +168,7 @@ async def test_region_logic_setup_with_region_sets_region_for_epics_and_sets_dri region_logic.driver.set.assert_called_once_with(epics_region) -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) +@pytest.mark.parametrize(("driver", "region"), DRIVER_REGIONS_PAIR, indirect=["driver"]) async def test_region_logic_setup_with_region_moves_selected_source_if_not_none( region: BaseRegion, region_logic: RegionLogic ) -> None: diff --git a/tests/devices/electron_analyser/conftest.py b/tests/devices/electron_analyser/conftest.py index 897255f4590..42f2adb92ba 100644 --- a/tests/devices/electron_analyser/conftest.py +++ b/tests/devices/electron_analyser/conftest.py @@ -11,12 +11,7 @@ PitchAndRollCrystal, StationaryCrystal, ) -from dodal.devices.electron_analyser.base import ( - BaseRegion, - BaseSequence, - DualEnergySource, - EnergySource, -) +from dodal.devices.electron_analyser.base import DualEnergySource, EnergySource from dodal.devices.electron_analyser.specs import SpecsDetector from dodal.devices.electron_analyser.vgscienta import VGScientaDetector from dodal.devices.fast_shutter import DualFastShutter, FastShutter @@ -138,17 +133,6 @@ async def ew4000( return ew4000 -@pytest.fixture -def region( - request: pytest.FixtureRequest, - sequence: BaseSequence[BaseRegion], -) -> BaseRegion: - region = sequence.get_region_by_name(request.param) - if region is None: - raise ValueError("Region " + request.param + " is not found.") - return region - - @pytest.fixture def expected_enabled_region_names( expected_region_values: list[dict[str, Any]], diff --git a/tests/devices/electron_analyser/helper_util/__init__.py b/tests/devices/electron_analyser/helper_util/__init__.py index c3026b1e71d..4958262123a 100644 --- a/tests/devices/electron_analyser/helper_util/__init__.py +++ b/tests/devices/electron_analyser/helper_util/__init__.py @@ -1,8 +1,15 @@ from .assert_func import assert_region_has_expected_values -from .sequence import TEST_SEQUENCE_REGION_NAMES, get_test_sequence +from .sequence import ( + generate_fixture_regions_pair, + load_b07_specs_test_seq, + load_i05_mbs_test_xml_seq, + load_i09_vgscienta_test_seq, +) __all__ = [ "assert_region_has_expected_values", - "get_test_sequence", - "TEST_SEQUENCE_REGION_NAMES", + "generate_fixture_regions_pair", + "load_b07_specs_test_seq", + "load_i05_mbs_test_xml_seq", + "load_i09_vgscienta_test_seq", ] diff --git a/tests/devices/electron_analyser/helper_util/sequence.py b/tests/devices/electron_analyser/helper_util/sequence.py index 1981fbbee30..0e321a2a73a 100644 --- a/tests/devices/electron_analyser/helper_util/sequence.py +++ b/tests/devices/electron_analyser/helper_util/sequence.py @@ -1,47 +1,52 @@ -from dodal.common.data_util import JsonLoaderConfig, json_model_loader -from dodal.devices.beamlines import b07, b07_shared, i09 -from dodal.devices.electron_analyser.specs import ( - SpecsAnalyserDriverIO, - SpecsDetector, - SpecsSequence, -) -from dodal.devices.electron_analyser.vgscienta import ( - VGScientaAnalyserDriverIO, - VGScientaDetector, - VGScientaSequence, -) +from collections.abc import Sequence + +import pytest + +from dodal.common.data_util import ModelLoader, ModelLoaderConfig, json_model_loader +from dodal.devices.beamlines import b07, b07_shared, i05_shared, i09 +from dodal.devices.electron_analyser.base import BaseRegion +from dodal.devices.electron_analyser.mbs import MbsSequence +from dodal.devices.electron_analyser.specs import SpecsSequence +from dodal.devices.electron_analyser.vgscienta import VGScientaSequence from tests.devices.electron_analyser.test_data import ( + TEST_MBS_XML_SEQUENCE, TEST_SPECS_SEQUENCE, TEST_VGSCIENTA_SEQUENCE, ) -TEST_SEQUENCE_REGION_NAMES = ["New_Region", "New_Region1", "New_Region2"] +B07SpecsSequence = SpecsSequence[b07.LensMode, b07_shared.PsuMode] +I09VGScientaSequence = VGScientaSequence[i09.LensMode, i09.PassEnergy] +I05MbsSequence = MbsSequence[i05_shared.LensMode, i05_shared.PassEnergy] -load_b07_specs_test_sequence = json_model_loader( - SpecsSequence[b07.LensMode, b07_shared.PsuMode], - JsonLoaderConfig.from_default_file(TEST_SPECS_SEQUENCE), +load_b07_specs_test_seq = ModelLoader[B07SpecsSequence]( + json_model_loader(B07SpecsSequence), + ModelLoaderConfig.from_default_file(TEST_SPECS_SEQUENCE), ) -load_i09_vgscienta_test_sequence = json_model_loader( - VGScientaSequence[i09.LensMode, i09.PsuMode, i09.PassEnergy], - JsonLoaderConfig.from_default_file(TEST_VGSCIENTA_SEQUENCE), +load_i09_vgscienta_test_seq = ModelLoader[I09VGScientaSequence]( + json_model_loader(I09VGScientaSequence), + ModelLoaderConfig.from_default_file(TEST_VGSCIENTA_SEQUENCE), +) +load_i05_mbs_test_xml_seq = ModelLoader[I05MbsSequence]( + lambda file: I05MbsSequence.from_xml(file), + ModelLoaderConfig.from_default_file(TEST_MBS_XML_SEQUENCE), ) -# Map to know what function to load in sequence an analyser driver should use. -TEST_SEQUENCES = { - SpecsDetector: load_b07_specs_test_sequence, - SpecsAnalyserDriverIO: load_b07_specs_test_sequence, - SpecsSequence: load_b07_specs_test_sequence, - VGScientaDetector: load_i09_vgscienta_test_sequence, - VGScientaAnalyserDriverIO: load_i09_vgscienta_test_sequence, - VGScientaSequence: load_i09_vgscienta_test_sequence, -} - - -def get_test_sequence(key: type): - for cls in key.__mro__: - # Check for unscripted class only - if cls in TEST_SEQUENCES: - return TEST_SEQUENCES[cls]() - raise KeyError(f"Found no match with type {key}") +def generate_fixture_regions_pair( + fixture_name: str, regions: Sequence[BaseRegion] +) -> list: + """Generate a parameterised pytest with a fixture name with the assoicated regions. + Useful for tests where you need to test each driver or detector with the paried + sequence file. + """ + test_cases = [] + for region in regions: + test_cases.append( + pytest.param( + fixture_name, + region, + id=f"{fixture_name}-{type(region).__name__}-{region.name}", + ) + ) + return test_cases diff --git a/tests/devices/electron_analyser/mbs/__init__.py b/tests/devices/electron_analyser/mbs/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/devices/electron_analyser/mbs/test_mbs_driver_io.py b/tests/devices/electron_analyser/mbs/test_mbs_driver_io.py new file mode 100644 index 00000000000..6412eb044a7 --- /dev/null +++ b/tests/devices/electron_analyser/mbs/test_mbs_driver_io.py @@ -0,0 +1,90 @@ +from unittest.mock import ANY + +import pytest +from bluesky import plan_stubs as bps +from bluesky.run_engine import RunEngine +from ophyd_async.core import get_mock_put, init_devices +from ophyd_async.testing import assert_configuration, partial_reading + +from dodal.devices.beamlines.i05_shared import LensMode, PassEnergy +from dodal.devices.electron_analyser.mbs import ( + AcquisitionMode, + MbsAnalyserDriverIO, + MbsRegion, +) +from tests.devices.electron_analyser.helper_util import load_i05_mbs_test_xml_seq + + +@pytest.fixture +async def sim_driver() -> MbsAnalyserDriverIO[LensMode, PassEnergy]: + with init_devices(mock=True): + sim_driver = MbsAnalyserDriverIO("TEST:", LensMode, PassEnergy) + return sim_driver + + +@pytest.mark.parametrize("region", load_i05_mbs_test_xml_seq().regions) +async def test_analyser_sets_region_correctly( + sim_driver: MbsAnalyserDriverIO[LensMode, PassEnergy], + region: MbsRegion[LensMode, PassEnergy], + run_engine: RunEngine, +) -> None: + run_engine(bps.mv(sim_driver, region), wait=True) + + get_mock_put(sim_driver.region_name).assert_called_once_with(region.name) + get_mock_put(sim_driver.energy_mode).assert_called_once_with(region.energy_mode) + get_mock_put(sim_driver.acquisition_mode).assert_called_once_with( + region.acquisition_mode + ) + get_mock_put(sim_driver.lens_mode).assert_called_once_with(region.lens_mode) + get_mock_put(sim_driver.low_energy).assert_called_once_with(region.low_energy) + get_mock_put(sim_driver.centre_energy).assert_called_once_with(region.centre_energy) + get_mock_put(sim_driver.high_energy).assert_called_once_with(region.high_energy) + get_mock_put(sim_driver.deflector_x).assert_awaited_once_with(region.deflector_x) + get_mock_put(sim_driver.pass_energy).assert_called_once_with(region.pass_energy) + get_mock_put(sim_driver.acquire_time).assert_called_once_with(region.acquire_time) + get_mock_put(sim_driver.iterations).assert_called_once_with(region.iterations) + if region.acquisition_mode == AcquisitionMode.SWEPT: + get_mock_put(sim_driver.energy_step).assert_called_once_with(region.energy_step) + + +@pytest.mark.parametrize("region", load_i05_mbs_test_xml_seq().regions) +async def test_analyser_sets_region_and_read_configuration_is_correct( + sim_driver: MbsAnalyserDriverIO[LensMode, PassEnergy], + region: MbsRegion[LensMode, PassEnergy], + run_engine: RunEngine, +) -> None: + run_engine(bps.mv(sim_driver, region), wait=True) + + prefix = sim_driver.name + "-" + await assert_configuration( + sim_driver, + { + f"{prefix}region_name": partial_reading(region.name), + f"{prefix}energy_mode": partial_reading(region.energy_mode), + f"{prefix}acquisition_mode": partial_reading(region.acquisition_mode), + f"{prefix}lens_mode": partial_reading(region.lens_mode), + f"{prefix}low_energy": partial_reading(region.low_energy), + f"{prefix}centre_energy": partial_reading(region.centre_energy), + f"{prefix}high_energy": partial_reading(region.high_energy), + f"{prefix}deflector_x": partial_reading(region.deflector_x), + f"{prefix}energy_step": partial_reading(region.energy_step), + f"{prefix}pass_energy": partial_reading(region.pass_energy), + f"{prefix}slices": partial_reading(ANY), + f"{prefix}iterations": partial_reading(region.iterations), + f"{prefix}total_steps": partial_reading(ANY), + f"{prefix}acquire_time": partial_reading(region.acquire_time), + f"{prefix}acquire_period": partial_reading(ANY), + f"{prefix}total_time": partial_reading(ANY), + f"{prefix}energy_axis": partial_reading(ANY), + f"{prefix}angle_axis": partial_reading(ANY), + f"{prefix}psu_mode": partial_reading(ANY), + f"{prefix}dither_steps": partial_reading(0), + f"{prefix}spin_offset": partial_reading(0), + f"{prefix}array_size_x": partial_reading(0), + f"{prefix}array_size_y": partial_reading(0), + f"{prefix}min_x": partial_reading(0), + f"{prefix}min_y": partial_reading(0), + f"{prefix}max_x": partial_reading(0), + f"{prefix}max_y": partial_reading(0), + }, + ) diff --git a/tests/devices/electron_analyser/mbs/test_mbs_region.py b/tests/devices/electron_analyser/mbs/test_mbs_region.py new file mode 100644 index 00000000000..955a414f29a --- /dev/null +++ b/tests/devices/electron_analyser/mbs/test_mbs_region.py @@ -0,0 +1,58 @@ +from typing import Any + +import pytest + +from dodal.devices.beamlines.i05_shared import LensMode, PassEnergy +from dodal.devices.electron_analyser.base import EnergyMode +from dodal.devices.electron_analyser.mbs import AcquisitionMode, MbsRegion, MbsSequence +from dodal.devices.selectable_source import SelectedSource +from tests.devices.electron_analyser.helper_util import ( + assert_region_has_expected_values, +) +from tests.devices.electron_analyser.helper_util.sequence import ( + load_i05_mbs_test_xml_seq, +) + + +@pytest.fixture +def expected_xml_region_values() -> list[dict[str, Any]]: + return [ + { + "name": "mbs_region1", + "enabled": True, + "lens_mode": LensMode.L4_ANG0_D8, + "pass_energy": PassEnergy.PE005, + "iterations": 3, + "acquisition_mode": AcquisitionMode.SWEPT, + "excitation_energy_source": SelectedSource.SOURCE1, + "energy_mode": EnergyMode.KINETIC, + "low_energy": 72.386, + "high_energy": 73.814, + "centre_energy": 73.1, + "acquire_time": 1.0, + "energy_step": 0.000405, + "deflector_x": 0.0, + }, + ] + + +def test_mbs_sequence_from_xml( + expected_xml_region_values: list[dict[str, Any]], +) -> None: + sequence = load_i05_mbs_test_xml_seq() + for i, r in zip(sequence.regions, expected_xml_region_values, strict=True): + assert_region_has_expected_values(i, r) + + +def test_mbs_region_load_using_field_names_has_expected_values( + expected_xml_region_values: list[dict[str, Any]], +) -> None: + for expected_region in expected_xml_region_values: + r = MbsRegion[LensMode, PassEnergy].model_validate(expected_region) + assert_region_has_expected_values(r, expected_region) + + seq = MbsSequence[LensMode, PassEnergy].model_validate( + {"regions": expected_xml_region_values} + ) + for r, expected_r in zip(seq.regions, expected_xml_region_values, strict=True): + assert_region_has_expected_values(r, expected_r) diff --git a/tests/devices/electron_analyser/specs/test_specs_driver_io.py b/tests/devices/electron_analyser/specs/test_specs_driver_io.py index a8b964fc9aa..f637afb8712 100644 --- a/tests/devices/electron_analyser/specs/test_specs_driver_io.py +++ b/tests/devices/electron_analyser/specs/test_specs_driver_io.py @@ -1,16 +1,10 @@ from unittest.mock import ANY -import numpy as np import pytest from bluesky import plan_stubs as bps from bluesky.run_engine import RunEngine from ophyd_async.core import get_mock_put, set_mock_value -from ophyd_async.testing import ( - assert_configuration, - assert_reading, - assert_value, - partial_reading, -) +from ophyd_async.testing import assert_configuration, assert_value, partial_reading from dodal.devices.beamlines.b07 import LensMode from dodal.devices.beamlines.b07_shared import PsuMode @@ -20,10 +14,7 @@ SpecsDetector, SpecsRegion, ) -from tests.devices.electron_analyser.helper_util import ( - TEST_SEQUENCE_REGION_NAMES, - get_test_sequence, -) +from tests.devices.electron_analyser.helper_util.sequence import load_b07_specs_test_seq @pytest.fixture @@ -33,12 +24,7 @@ async def sim_driver( return b07b_specs150.driver -@pytest.fixture -def sequence(sim_driver: SpecsAnalyserDriverIO[LensMode, PsuMode]): - return get_test_sequence(type(sim_driver)) - - -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) +@pytest.mark.parametrize("region", load_b07_specs_test_seq().get_enabled_regions()) async def test_analyser_sets_region_correctly( sim_driver: SpecsAnalyserDriverIO[LensMode, PsuMode], region: SpecsRegion[LensMode, PsuMode], @@ -71,11 +57,11 @@ async def test_analyser_sets_region_correctly( else: get_mock_put(sim_driver.energy_step).assert_not_called() - get_mock_put(sim_driver.psu_mode).assert_called_once_with(region.psu_mode) + get_mock_put(sim_driver.psu_mode_w).assert_called_once_with(region.psu_mode) get_mock_put(sim_driver.snapshot_values).assert_called_once_with(region.values) -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) +@pytest.mark.parametrize("region", load_b07_specs_test_seq().get_enabled_regions()) async def test_analyser_sets_region_and_read_configuration_is_correct( sim_driver: SpecsAnalyserDriverIO[LensMode, PsuMode], region: SpecsRegion[LensMode, PsuMode], @@ -110,28 +96,6 @@ async def test_analyser_sets_region_and_read_configuration_is_correct( ) -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) -async def test_analyser_sets_region_and_read_is_correct( - sim_driver: SpecsAnalyserDriverIO[LensMode, PsuMode], - region: SpecsRegion[LensMode, PsuMode], - run_engine: RunEngine, -) -> None: - run_engine(bps.mv(sim_driver, region)) - spectrum = np.array([1, 2, 3, 4, 5], dtype=float) - expected_total_intensity = np.sum(spectrum) - set_mock_value(sim_driver.spectrum, spectrum) - - prefix = sim_driver.name + "-" - await assert_reading( - sim_driver, - { - f"{prefix}image": partial_reading([]), - f"{prefix}spectrum": partial_reading(spectrum), - f"{prefix}total_intensity": partial_reading(expected_total_intensity), - }, - ) - - async def test_specs_analyser_energy_axis( sim_driver: SpecsAnalyserDriverIO[LensMode, PsuMode], run_engine: RunEngine, diff --git a/tests/devices/electron_analyser/specs/test_specs_region.py b/tests/devices/electron_analyser/specs/test_specs_region.py index 5151ce141ff..083301fd199 100644 --- a/tests/devices/electron_analyser/specs/test_specs_region.py +++ b/tests/devices/electron_analyser/specs/test_specs_region.py @@ -13,13 +13,8 @@ from dodal.devices.selectable_source import SelectedSource from tests.devices.electron_analyser.helper_util import ( assert_region_has_expected_values, - get_test_sequence, ) - - -@pytest.fixture -def sequence() -> SpecsSequence[LensMode, PsuMode]: - return get_test_sequence(SpecsSequence[LensMode, PsuMode]) +from tests.devices.electron_analyser.helper_util.sequence import load_b07_specs_test_seq @pytest.fixture @@ -85,21 +80,13 @@ def expected_region_values() -> list[dict[str, Any]]: ] -def test_load_sequence_using_alias_field_names_has_expected_enabled_region_names( - sequence: SpecsSequence[LensMode, PsuMode], - expected_enabled_region_names: list[str], -) -> None: - assert sequence.get_enabled_region_names() == expected_enabled_region_names - for i, region in enumerate(sequence.get_enabled_regions()): - assert region.name == expected_enabled_region_names[i] - - def test_load_sequence_using_alias_field_names_has_expected_values( - sequence: SpecsSequence[LensMode, PsuMode], expected_region_values: list[dict[str, Any]], ) -> None: - for i, r in zip(sequence.regions, expected_region_values, strict=True): - assert_region_has_expected_values(i, r) + for r, expected_r in zip( + load_b07_specs_test_seq().regions, expected_region_values, strict=True + ): + assert_region_has_expected_values(r, expected_r) def test_region_loads_using_field_names_has_expected_values( @@ -108,3 +95,9 @@ def test_region_loads_using_field_names_has_expected_values( for expected_region in expected_region_values: r = SpecsRegion[LensMode, PsuMode].model_validate(expected_region) assert_region_has_expected_values(r, expected_region) + + seq = SpecsSequence[LensMode, PsuMode].model_validate( + {"regions": expected_region_values} + ) + for r, expected_r in zip(seq.regions, expected_region_values, strict=True): + assert_region_has_expected_values(r, expected_r) diff --git a/tests/devices/electron_analyser/test_data/__init__.py b/tests/devices/electron_analyser/test_data/__init__.py index 3ffb5858e30..e0127f00454 100644 --- a/tests/devices/electron_analyser/test_data/__init__.py +++ b/tests/devices/electron_analyser/test_data/__init__.py @@ -4,5 +4,6 @@ TEST_DATA_PATH = Path(__file__).parent TEST_VGSCIENTA_SEQUENCE = join(TEST_DATA_PATH, "vgscienta_sequence.seq") TEST_SPECS_SEQUENCE = join(TEST_DATA_PATH, "specs_sequence.seq") +TEST_MBS_XML_SEQUENCE = join(TEST_DATA_PATH, "mbs_region1.arpes") -__all__ = ["TEST_SPECS_SEQUENCE", "TEST_VGSCIENTA_SEQUENCE"] +__all__ = ["TEST_SPECS_SEQUENCE", "TEST_VGSCIENTA_SEQUENCE", "TEST_MBS_XML_SEQUENCE"] diff --git a/tests/devices/electron_analyser/test_data/mbs_region1.arpes b/tests/devices/electron_analyser/test_data/mbs_region1.arpes new file mode 100644 index 00000000000..355c1c8569c --- /dev/null +++ b/tests/devices/electron_analyser/test_data/mbs_region1.arpes @@ -0,0 +1,13 @@ + + + L4Ang0d8 + Swept + 5 + 72.386 + 73.814 + 0.405 + 1.0 + 3 + 0.0 + true + diff --git a/tests/devices/electron_analyser/vgscienta/test_vgscienta_driver_io.py b/tests/devices/electron_analyser/vgscienta/test_vgscienta_driver_io.py index fda8f88eb8e..deca833bf39 100644 --- a/tests/devices/electron_analyser/vgscienta/test_vgscienta_driver_io.py +++ b/tests/devices/electron_analyser/vgscienta/test_vgscienta_driver_io.py @@ -1,16 +1,11 @@ from unittest.mock import ANY -import numpy as np import pytest from bluesky import plan_stubs as bps from bluesky.run_engine import RunEngine from bluesky.utils import FailedStatus -from ophyd_async.core import StrictEnum, get_mock_put, set_mock_value -from ophyd_async.testing import ( - assert_configuration, - assert_reading, - partial_reading, -) +from ophyd_async.core import StrictEnum, get_mock_put +from ophyd_async.testing import assert_configuration, partial_reading from dodal.devices.beamlines.i09 import LensMode, PassEnergy, PsuMode from dodal.devices.electron_analyser.vgscienta import ( @@ -18,10 +13,7 @@ VGScientaDetector, VGScientaRegion, ) -from tests.devices.electron_analyser.helper_util import ( - TEST_SEQUENCE_REGION_NAMES, - get_test_sequence, -) +from tests.devices.electron_analyser.helper_util import load_i09_vgscienta_test_seq @pytest.fixture @@ -31,12 +23,7 @@ async def sim_driver( return ew4000.driver -@pytest.fixture -def sequence(sim_driver: VGScientaAnalyserDriverIO[LensMode, PsuMode, PassEnergy]): - return get_test_sequence(type(sim_driver)) - - -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) +@pytest.mark.parametrize("region", load_i09_vgscienta_test_seq().regions) async def test_analyser_sets_region_correctly( sim_driver: VGScientaAnalyserDriverIO[LensMode, PsuMode, PassEnergy], region: VGScientaRegion[LensMode, PassEnergy], @@ -66,7 +53,7 @@ async def test_analyser_sets_region_correctly( get_mock_put(sim_driver.region_size_y).assert_called_once_with(region.size_y) -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) +@pytest.mark.parametrize("region", load_i09_vgscienta_test_seq().regions) async def test_analyser_sets_region_and_read_configuration_is_correct( sim_driver: VGScientaAnalyserDriverIO[LensMode, PsuMode, PassEnergy], region: VGScientaRegion[LensMode, PassEnergy], @@ -106,30 +93,6 @@ async def test_analyser_sets_region_and_read_configuration_is_correct( ) -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) -async def test_analyser_sets_region_and_read_is_correct( - sim_driver: VGScientaAnalyserDriverIO[LensMode, PsuMode, PassEnergy], - region: VGScientaRegion[LensMode, PassEnergy], - run_engine: RunEngine, -) -> None: - run_engine(bps.mv(sim_driver, region), wait=True) - - spectrum = np.array([1, 2, 3, 4, 5], dtype=float) - expected_total_intensity = np.sum(spectrum) - set_mock_value(sim_driver.spectrum, spectrum) - - prefix = sim_driver.name + "-" - - await assert_reading( - sim_driver, - { - f"{prefix}image": partial_reading(ANY), - f"{prefix}spectrum": partial_reading(spectrum), - f"{prefix}total_intensity": partial_reading(expected_total_intensity), - }, - ) - - def test_driver_throws_error_with_wrong_pass_energy( sim_driver: VGScientaAnalyserDriverIO[LensMode, PsuMode, PassEnergy], run_engine: RunEngine, diff --git a/tests/devices/electron_analyser/vgscienta/test_vgscienta_region.py b/tests/devices/electron_analyser/vgscienta/test_vgscienta_region.py index 9933b9a0919..1e24a4ddc73 100644 --- a/tests/devices/electron_analyser/vgscienta/test_vgscienta_region.py +++ b/tests/devices/electron_analyser/vgscienta/test_vgscienta_region.py @@ -2,7 +2,7 @@ import pytest -from dodal.devices.beamlines.i09 import LensMode, PassEnergy, PsuMode +from dodal.devices.beamlines.i09 import LensMode, PassEnergy from dodal.devices.electron_analyser.base import EnergyMode from dodal.devices.electron_analyser.vgscienta import ( AcquisitionMode, @@ -13,20 +13,10 @@ from dodal.devices.selectable_source import SelectedSource from tests.devices.electron_analyser.helper_util import ( assert_region_has_expected_values, - get_test_sequence, + load_i09_vgscienta_test_seq, ) -@pytest.fixture -def sequence() -> VGScientaSequence[LensMode, PsuMode, PassEnergy]: - return get_test_sequence(VGScientaSequence[LensMode, PsuMode, PassEnergy]) - - -@pytest.fixture -def expected_region_class() -> type[VGScientaRegion[LensMode, PassEnergy]]: - return VGScientaRegion[LensMode, PassEnergy] - - @pytest.fixture def expected_region_values() -> list[dict[str, Any]]: return [ @@ -102,20 +92,12 @@ def expected_region_values() -> list[dict[str, Any]]: ] -def test_load_sequence_using_alias_field_names_has_expected_enabled_region_names( - sequence: VGScientaSequence[LensMode, PsuMode, PassEnergy], - expected_enabled_region_names: list[str], -) -> None: - assert sequence.get_enabled_region_names() == expected_enabled_region_names - for i, region in enumerate(sequence.get_enabled_regions()): - assert region.name == expected_enabled_region_names[i] - - def test_load_sequence_using_alias_field_names_has_expected_values( - sequence: VGScientaSequence[LensMode, PsuMode, PassEnergy], expected_region_values: list[dict[str, Any]], ) -> None: - for i, r in zip(sequence.regions, expected_region_values, strict=True): + for i, r in zip( + load_i09_vgscienta_test_seq().regions, expected_region_values, strict=True + ): assert_region_has_expected_values(i, r) @@ -125,3 +107,8 @@ def test_region_loads_using_field_names_has_expected_values( for expected_region in expected_region_values: r = VGScientaRegion[LensMode, PassEnergy].model_validate(expected_region) assert_region_has_expected_values(r, expected_region) + seq = VGScientaSequence[LensMode, PassEnergy].model_validate( + {"regions": expected_region_values} + ) + for r, expected_r in zip(seq.regions, expected_region_values, strict=True): + assert_region_has_expected_values(r, expected_r)