diff --git a/CHANGES.md b/CHANGES.md index 02814d6..0e786d7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,3 +8,13 @@ - add support for UDP data stream (only json for now) - add support for RTT interface + +## 1.0.0 (18/03/2026) + +- support for nxslib 1.0.0 +- move nxscli-np plugin to nxscli +- support for dynamic plugins management +- add provider stream routing support +- add support for control server +- add shared sample and window transform modules +- add virtual channel runtime and vadd command diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a62ce18..4a9a06a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ ### Recommended tools -We use [tox](https://github.com/tox-dev/tox) to automate tedious developer's tasks, +We use [tox](https://github.com/tox-dev/tox) to automate tedious developer's tasks, thus installing it is highly recommended. ``` @@ -31,7 +31,43 @@ source venv/bin/activate ### Code style and running tests -Code formatting is ensured by [black](https://github.com/psf/black) and [isort](https://github.com/PyCQA/isort). +#### Docstring Format + +This project uses Sphinx-style docstrings exclusively. +Do not use Google-style or NumPy-style docstrings. + +Correct (Sphinx style): +```python +def example_function(param1, param2): + """Brief description of function. + + Longer description if needed. + + :param param1: description of param1 + :param param2: description of param2 + :return: description of return value + :raises ValueError: description of when this is raised + """ +``` + +Incorrect (Google style) - DO NOT USE: +```python +def example_function(param1, param2): + """Brief description of function. + + Args: + param1: description of param1 + param2: description of param2 + + Returns: + description of return value + """ +``` + +#### Code Formatting + +Code formatting is ensured by [black](https://github.com/psf/black) +and [isort](https://github.com/PyCQA/isort). 
To reformat your changes, use:

```
@@ -51,7 +87,7 @@ Flake8 linter is available with:
tox -e flake8
```

-CI requres 100% coverage to pass. If some of your changes can't be easy tested,
+CI requires 100% coverage to pass. If some of your changes can't be easily tested,
you can exclude code from coverage with `#pragma: no cover` comment.

To run tests with coverage report run:
diff --git a/README.md b/README.md
index f684ed2..7e16eb1 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,11 @@
Nxscli is a command-line client package for the
[Apache NuttX](https://nuttx.apache.org/) NxScope real-time logging module.

+It is also a reusable Python runtime layer for NxScope streaming, channel
+control, triggers, and plugin orchestration. The `nxscli` internals are used
+by other tools (for example GUI applications) to build more advanced workflows
+without re-implementing NxScope communication logic.
+
Compatible with Python 3.10+.

## Features
@@ -11,35 +16,37 @@ Compatible with Python 3.10+.

* Plugins architecture, extendable through ``nxscli.extensions`` entrypoint
* Client-based triggering (global and per-channel triggers)
* Save data to CSV files
+* Save data to Numpy files (`pnpsave`) and memmap files (`pnpmem`)
* Print samples
* Stream data over UDP (compatible with [PlotJuggler](https://github.com/facontidavide/PlotJuggler))
* NxScope protocol via serial port or Segger RTT interface
+* Virtual channels and math operations on channels data
+* Optional control server (`--control-server`) for extensions

## Features Planned

* More triggering types
* Boolean operations on triggers
-* Virtual channels and math operations on channels data
* Improve `pdevinfo` output (human-readable prints)
* Interactive mode

## Plugins

-By default, we only support features that depend on the standard Python libraries.
-The functionality is expadned by installing plugins.
-Plugins are automatically deteceted by Nxscli.
+By default, `nxscli` ships with core plugins including CSV, printer, UDP, +and NumPy file capture (`pnpsave` and `pnpmem`). +Additional functionality is expanded by installing optional plugins. +Plugins are automatically detected by Nxscli. Available plugins: * [nxscli-mpl](https://github.com/railab/nxscli-mpl) - Matplotlib extension -* [nxscli-np](https://github.com/railab/nxscli-np) - Numpy extension ## Plugins Planned * Stream data as audio (inspired by audio knock detection systems) * PyQtGraph support -## Instalation +## Installation Nxscli can be installed by running `pip install nxscli`. @@ -51,10 +58,22 @@ To install latest development version, use: Look at [docs/usage](docs/usage.rst). +## Reuse as a Library + +`nxscli` is not only a CLI frontend. It can be imported and reused by external +applications that need: + +* NxScope connection handling (serial/RTT and compatible interfaces) +* channel configuration and stream lifecycle control +* plugin loading and runtime execution +* trigger and data-processing orchestration + +This makes `nxscli` the integration layer for higher-level tools such as +custom dashboards, GUIs, and automation scripts. + ## Contributing All contributions are welcome to this project. To get started with developing Nxscli, see [CONTRIBUTING.md](CONTRIBUTING.md). - diff --git a/docs/library.rst b/docs/library.rst new file mode 100644 index 0000000..78b3b17 --- /dev/null +++ b/docs/library.rst @@ -0,0 +1,44 @@ +Using Nxscli as a Library +------------------------- + +Nxscli is not only a CLI frontend. It can also be reused as a Python +integration layer for higher-level tools such as GUIs, dashboards, and +automation scripts. + +Typical reusable components include: + +* NxScope connection and interface handling +* channel configuration and stream lifecycle control +* plugin loading and runtime orchestration +* trigger and data-processing integration + +Minimal example +=============== + +.. 
code-block:: python + + from nxscli.plugins_loader import plugins_list + from nxscli.phandler import PluginHandler + from nxslib.intf.dummy import DummyDev + from nxslib.proto.parse import Parser + from nxslib.nxscope import NxscopeHandler + + intf = DummyDev() + parse = Parser() + + with NxscopeHandler(intf, parse) as nxscope: + with PluginHandler(plugins_list) as phandler: + phandler.nxscope_connect(nxscope) + + # Configure and run as needed by your application: + nxscope.ch_enable([0], writenow=True) + nxscope.stream_start() + pid = phandler.plugin_start_dynamic("pprinter", channels=[0]) + + # ... do work ... + + phandler.plugin_stop_dynamic(pid) + nxscope.stream_stop() + # phandler.cleanup() called automatically on exit + # nxscope.disconnect() called automatically on exit + diff --git a/docs/usage.rst b/docs/usage.rst index a7a66d5..c775dd2 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -1,8 +1,9 @@ +===== Usage ------ +===== Commands ---------- +======== You can run Nxscli as a Python module: @@ -24,6 +25,13 @@ the same time. For commands details use ``--help`` option. +Global options include: + +* ``--control-server`` - enable optional control server plugin + (disabled by default). +* ``--control-endpoint`` - server endpoint + (``unix://``, ``unix-abstract://`` or ``tcp://``). 
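For a concrete invocation, the control server could be enabled alongside a normal dummy-device capture roughly as follows (the endpoint value, channel, and sample count here are illustrative, not defaults):

```shell
# Hypothetical example: enable the optional control server on a TCP endpoint
# while printing 10 samples from dummy channel 0.
python -m nxscli --control-server \
    --control-endpoint tcp://127.0.0.1:9000 \
    dummy chan 0 pprinter 10
```

Note that the global `--control-server`/`--control-endpoint` options must appear before the interface command (`dummy`, `serial`, `rtt`), since they belong to the top-level command group.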
+ The following example illustrates how to run multiple plugins simultaneously with various channel configurations (based on ``pcap`` from ``nxscli-mpl``): @@ -31,9 +39,12 @@ with various channel configurations (based on ``pcap`` from ``nxscli-mpl``): python -m nxscli dummy chan 1,2,3,4 pcap --chan 1 100 pcap --chan 2,3 200 pcap 300 +Library integration guide: + +* :doc:`library` -Interace commands -================= +Interface Commands +------------------ Supported interface commands: @@ -41,23 +52,34 @@ Supported interface commands: Available device channels: - - chan0 - vdim = 1, random() - - chan1 - vdim = 1, saw wave - - chan2 - vdim = 1, triangle wave - - chan3 - vdim = 2, random() - - chan4 - vdim = 3, random() - - chan5 - vdim = 3, static vector = [1.0, 0.0, -1.0] - - chan6 - vdim = 1, 'hello' string - - chan7 - vdim = 3, static vector = [1.0, 0.0, -1.0], meta = 1B int - - chan8 - vdim = 0, meta = 'hello string', mlen = 16 - - chan9 - vdim = 3, 3-phase sine wave + - 0: noise_uniform_scalar - vdim = 1, random() + - 1: ramp_saw_up - vdim = 1, saw wave + - 2: ramp_triangle - vdim = 1, triangle wave + - 3: noise_uniform_vec2 - vdim = 2, random() + - 4: noise_uniform_vec3 - vdim = 3, random() + - 5: static_vec3 - vdim = 3, static vector = [1.0, 0.0, -1.0] + - 6: text_hello_sparse - vdim = 1, sparse 'hello' string + - 7: static_vec3_meta_counter - vdim = 3, static vec + 1B meta counter + - 8: meta_hello_only - vdim = 0, mlen = 16, meta = 'hello string' + - 9: sine_three_phase - vdim = 3, 3-phase sine wave + - 10: reserved (undefined) + - 11: fft_multitone - vdim = 1, deterministic multi-tone + - 12: fft_chirp - vdim = 1, deterministic chirp-like signal + - 13: hist_gaussian - vdim = 1, deterministic Gaussian-like + - 14: hist_bimodal - vdim = 1, deterministic bi-modal + - 15: xy_lissajous - vdim = 2, correlated XY signal + - 16: polar_theta_radius - vdim = 2, (theta, radius) signal + - 17: step_up_once - vdim = 1, one rising step + - 18: step_down_once - vdim = 
1, one falling step
+  - 19: pulse_square_20p - vdim = 1, periodic square pulse (20% duty)
+  - 20: pulse_single_sparse - vdim = 1, one-sample pulse every 250 samples

* ``serial`` - select serial port NxScope interface

* ``rtt`` - select Segger RTT as NxScope interface

-Configuratio commands
-=====================
+Configuration Commands
+----------------------

Available configuration commands:

@@ -74,16 +96,98 @@ Available configuration commands:

Triggers can be configured per channel with the option ``--trig``.

+* ``vadd`` - add a virtual channel to the ``nxscli`` virtual runtime.
+
+  This command declares a derived channel computed from one or more inputs.
+  The command is non-interactive and can be chained with plugin commands.
+  Use ``--operator`` to select the transform and ``--params`` to pass
+  comma-separated ``key=value`` operator arguments.
+
+  Example command form:
+
+  .. code-block:: bash
-Plugin commands
-===============
+
+     python -m nxscli dummy vadd --operator scale_offset --params scale=2,offset=1 100 0 pprinter --chan v100 10
+
+  In command chaining, place command options before positional arguments.
+  For virtual data output, select the virtual channel explicitly via the
+  plugin option ``--chan vNN`` (for example ``--chan v100``).
+  Source physical channels from ``vadd`` inputs are auto-configured.
+
+
+Plugin Commands
+---------------

Plugins supported so far:

* ``pcsv`` - store samples in CSV files
+* ``pnpsave`` - store samples in Numpy ``.npy`` files
+* ``pnpmem`` - store samples in Numpy memmap ``.dat`` files
* ``pdevinfo`` - show information about the connected NxScope device
* ``pnone`` - capture data and do nothing with them
* ``pprinter`` - capture data and print samples
* ``pudp`` - stream data over UDP

For more information, use the plugin's ``--help`` option.
+
+Dummy Device Cheatsheet
+=======================
+
+Use these commands for quick local testing without hardware.
+All examples use the ``dummy`` interface and channel ``0``.
+
+Device info
+-----------
+
+.. 
code-block:: bash
+
+   python -m nxscli dummy pdevinfo
+
+Print stream samples
+--------------------
+
+.. code-block:: bash
+
+   python -m nxscli dummy chan 0 pprinter 50
+
+Capture and discard samples
+---------------------------
+
+.. code-block:: bash
+
+   python -m nxscli dummy chan 0 pnone 50000
+
+Store samples to CSV
+--------------------
+
+.. code-block:: bash
+
+   python -m nxscli dummy chan 0 pcsv 200 /tmp/nxscope_csv
+
+Store samples to Numpy files
+----------------------------
+
+.. code-block:: bash
+
+   python -m nxscli dummy chan 0 pnpsave 200 /tmp/nxscope_np
+
+Store samples to Numpy memmap
+-----------------------------
+
+.. code-block:: bash
+
+   python -m nxscli dummy chan 0 pnpmem 200 /tmp/nxscope_mem 100
+
+Stream samples over UDP
+-----------------------
+
+.. code-block:: bash
+
+   python -m nxscli dummy chan 0 pudp 2000 --address 127.0.0.1 --port 9870
+
+Run multiple plugins in one command
+-----------------------------------
+
+.. code-block:: bash
+
+   python -m nxscli dummy chan 0 pprinter 20 pcsv 20 /tmp/nxscope_csv pudp 20
diff --git a/pyproject.toml b/pyproject.toml
index 016f1a7..887c5f5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,15 +4,16 @@ build-backend = 'setuptools.build_meta'

[project]
name = "nxscli"
-version = "0.5.1"
+version = "1.0.0"
authors = [{name = "raiden00", email = "raiden00@railab.me"}]
description = "Nxscope CLI client"
license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
-    "nxslib>=0.9.1",
-    "click>=8.1"
+    "nxslib>=1.0.0",
+    "click>=8.1",
+    "numpy"
]
classifiers = [
    "Development Status :: 4 - Beta",
@@ -41,4 +42,4 @@ target-version = ['py310']

[tool.isort]
profile = "black"
-line_length = 79
\ No newline at end of file
+line_length = 79
diff --git a/src/nxscli/channelref.py b/src/nxscli/channelref.py
new file mode 100644
index 0000000..f605cc4
--- /dev/null
+++ b/src/nxscli/channelref.py
@@ -0,0 +1,55 @@
+"""Typed channel reference model."""
+
+from dataclasses 
import dataclass + + +@dataclass(frozen=True) +class ChannelRef: + """Channel reference for physical/virtual/all channel selectors.""" + + kind: str + value: int | None = None + + @classmethod + def all_channels(cls) -> "ChannelRef": + """Select all physical channels.""" + return cls(kind="all", value=None) + + @classmethod + def physical(cls, channel_id: int) -> "ChannelRef": + """Select physical channel ``channel_id``.""" + return cls(kind="physical", value=channel_id) + + @classmethod + def virtual(cls, virtual_id: int) -> "ChannelRef": + """Select virtual channel ``v{virtual_id}``.""" + return cls(kind="virtual", value=virtual_id) + + @property + def is_all(self) -> bool: + """Return ``True`` if this is all-physical selector.""" + return self.kind == "all" + + @property + def is_physical(self) -> bool: + """Return ``True`` if this is physical channel selector.""" + return self.kind == "physical" + + @property + def is_virtual(self) -> bool: + """Return ``True`` if this is virtual channel selector.""" + return self.kind == "virtual" + + def physical_id(self) -> int: + """Return physical channel id.""" + if not self.is_physical: + raise ValueError("not a physical channel reference") + assert self.value is not None + return self.value + + def virtual_name(self) -> str: + """Return virtual channel name (for example ``v0``).""" + if not self.is_virtual: + raise ValueError("not a virtual channel reference") + assert self.value is not None + return f"v{self.value}" diff --git a/src/nxscli/cli/environment.py b/src/nxscli/cli/environment.py index bacfd51..d2d750f 100644 --- a/src/nxscli/cli/environment.py +++ b/src/nxscli/cli/environment.py @@ -9,6 +9,7 @@ from nxslib.nxscope import NxscopeHandler from nxslib.proto.parse import Parser + from nxscli.channelref import ChannelRef from nxscli.phandler import PluginHandler from nxscli.trigger import DTriggerConfigReq @@ -28,9 +29,10 @@ class DEnvironmentData: parser: "Parser | None" = None interface: bool = False 
needchannels: bool = False - channels: tuple[list[int], Any] | None = None + channels: tuple[list["ChannelRef"], Any] | None = None phandler: "PluginHandler | None" = None triggers: dict[int, "DTriggerConfigReq"] | None = None + nxscope_plugins: list[Any] | None = None ############################################################################### diff --git a/src/nxscli/cli/main.py b/src/nxscli/cli/main.py index 5c22a92..3259e52 100644 --- a/src/nxscli/cli/main.py +++ b/src/nxscli/cli/main.py @@ -7,6 +7,7 @@ from nxslib.proto.parse import Parser from nxscli.cli.environment import Environment, pass_environment +from nxscli.control_server import ControlServerPlugin from nxscli.iplugin import EPluginType, IPlugin from nxscli.logger import logger from nxscli.phandler import PluginHandler @@ -24,8 +25,23 @@ is_flag=True, envvar="NXSCLI_DEBUG", ) +@click.option( + "--control-server/--no-control-server", + default=False, + help="Enable nxscli control server plugin (default: disabled).", +) +@click.option( + "--control-endpoint", + default="unix-abstract://nxscli-control", + help="Control server endpoint: unix://, unix-abstract:// or tcp://.", +) @pass_environment -def main(ctx: Environment, debug: bool) -> bool: +def main( + ctx: Environment, + debug: bool, + control_server: bool, + control_endpoint: str, +) -> bool: """Nxscli - Command-line clinet to the NxScope.""" ctx.debug = debug if debug: # pragma: no cover @@ -37,6 +53,13 @@ def main(ctx: Environment, debug: bool) -> bool: parse = Parser() ctx.parser = parse ctx.triggers = {} + ctx.nxscope_plugins = [] + if control_server: + try: + ctx.nxscope_plugins.append(ControlServerPlugin(control_endpoint)) + except Exception: + ctx.phandler.cleanup() + raise click.get_current_context().call_on_close(cli_on_close) @@ -109,6 +132,23 @@ def wait_for_plugins(ctx: Environment) -> None: # pragma: no cover ctx.phandler.wait_for_plugins() +def _cli_validate_ready(ctx: Environment) -> bool: + """Validate common preconditions before 
plugin start.""" + assert ctx.phandler + if ctx.interface is False: + return False + + if len(ctx.phandler.enabled) == 0: + click.secho("ERROR: No plugins selected !", err=True, fg="red") + return False + + if ctx.needchannels and ctx.channels is None: # pragma: no cover + click.secho("ERROR: No channels selected !", err=True, fg="red") + return False + + return True + + ############################################################################### # Function: cli_on_close ############################################################################### @@ -118,50 +158,49 @@ def wait_for_plugins(ctx: Environment) -> None: # pragma: no cover def cli_on_close(ctx: Environment) -> bool: """Handle requested plugins on click close.""" assert ctx.phandler - # do not show any errors if it was help request - if "--help" in sys.argv: # pragma: no cover - ctx.phandler.cleanup() - return True - - if ctx.interface is False: - ctx.phandler.cleanup() - return False + reg_names: list[str] = [] + with ctx.phandler: + # do not show any errors if it was help request + if "--help" in sys.argv: # pragma: no cover + return True - if len(ctx.phandler.enabled) == 0: - click.secho("ERROR: No plugins selected !", err=True, fg="red") - ctx.phandler.cleanup() - return False - - if ctx.needchannels: - if ctx.channels is None: # pragma: no cover - click.secho("ERROR: No channels selected !", err=True, fg="red") - ctx.phandler.cleanup() + if _cli_validate_ready(ctx) is False: return False - # connect nxscope to phandler - assert ctx.nxscope - ctx.phandler.nxscope_connect(ctx.nxscope) + # connect nxscope to phandler + assert ctx.nxscope + ctx.phandler.nxscope_connect(ctx.nxscope) + + # register optional nxscope-side plugins + if ctx.nxscope_plugins: + for plugin in ctx.nxscope_plugins: + reg_names.append(ctx.nxscope.register_plugin(plugin)) - # configure channles after connected to nxscope - if ctx.needchannels and ctx.channels: - ctx.phandler.channels_configure(ctx.channels[0], ctx.channels[1]) 
+        try:
+            # configure channels after connecting to nxscope
+            if ctx.needchannels and ctx.channels:
+                ctx.phandler.channels_configure(
+                    ctx.channels[0], ctx.channels[1]
+                )

-    # start plugins
-    ctx.phandler.start()
+            # start plugins
+            ctx.phandler.start()

-    if ctx.waitenter:  # pragma: no cover
-        _ = input("wait for Enter key...")
+            if ctx.waitenter:  # pragma: no cover
+                _ = input("wait for Enter key...")

-    # plugins loop
-    plugin_loop(ctx)
+            # plugins loop
+            plugin_loop(ctx)

-    # wait for plugin
-    wait_for_plugins(ctx)
+            # wait for plugins
+            wait_for_plugins(ctx)

-    print("closing...")
-    ctx.phandler.stop()
-    ctx.phandler.nxscope_disconnect()
-    ctx.phandler.cleanup()
+            print("closing...")
+            ctx.phandler.stop()
+        finally:
+            for name in reg_names:
+                ctx.nxscope.unregister_plugin(name)
+            ctx.phandler.nxscope_disconnect()

    return True
diff --git a/src/nxscli/cli/types.py b/src/nxscli/cli/types.py
index f53e1f6..ae66e86 100644
--- a/src/nxscli/cli/types.py
+++ b/src/nxscli/cli/types.py
@@ -1,9 +1,10 @@
-"""Module containing the Click types."""
+"""Module containing the Click types."""  # noqa: A005

from typing import Any

import click

+from nxscli.channelref import ChannelRef
from nxscli.trigger import DTriggerConfigReq

###############################################################################
@@ -53,23 +54,29 @@ class Channels(click.ParamType):

    name = "channels"

-    def convert(self, value: Any, param: Any, ctx: Any) -> list[int]:
+    def convert(self, value: Any, param: Any, ctx: Any) -> list[ChannelRef]:
        """Convert channels argument."""
-        lint = []
+        lint: list[ChannelRef] = []

        if value == "all":
            # special case to indicate all channels
-            lint.append(-1)
+            lint.append(ChannelRef.all_channels())
            return lint

        lstr = get_list_from_str(value)
        for ch in lstr:
+            if ch.startswith("v"):
+                virt = ch[1:]
+                assert virt.isnumeric(), "virtual channel id must be numeric"
+                lint.append(ChannelRef.virtual(int(virt)))
+                continue
+
+            assert ch.isnumeric(), "channel id must be a valid integer"
chan = int(ch)
            if chan < 0 or chan > 255:
                raise click.BadParameter(
                    "channel id must be in range [0, 255]"
                )
-            lint.append(chan)
+            lint.append(ChannelRef.physical(chan))

        return lint
diff --git a/src/nxscli/commands/cmd_npmem.py b/src/nxscli/commands/cmd_npmem.py
new file mode 100644
index 0000000..0def9e7
--- /dev/null
+++ b/src/nxscli/commands/cmd_npmem.py
@@ -0,0 +1,53 @@
+"""Module containing Numpy memmap plugin command."""
+
+from typing import TYPE_CHECKING
+
+import click
+
+from nxscli.cli.environment import Environment, pass_environment
+from nxscli.cli.types import Samples, capture_options
+
+if TYPE_CHECKING:
+    from nxscli.trigger import DTriggerConfigReq
+
+
+@click.command(name="pnpmem")
+@click.argument("samples", type=Samples(), required=True)
+@click.argument("path", type=click.Path(resolve_path=False), required=True)
+@click.argument("shape", type=int, required=True)
+@capture_options
+@pass_environment
+def cmd_pnpmem(
+    ctx: Environment,
+    samples: int,
+    path: str,
+    shape: int,
+    chan: list[int],
+    trig: dict[int, "DTriggerConfigReq"],
+) -> bool:
+    """[plugin] Store samples in Numpy memmap files.
+
+    Each configured channel will be written to a separate file.
+
+    If the SAMPLES argument is set to 'i' then we capture data until
+    Enter is pressed.
+
+    The 'shape' argument defines the second dimension of the memmap array.
+    """  # noqa: D301
+    assert ctx.phandler
+    if samples == 0:  # pragma: no cover
+        ctx.waitenter = True
+
+    ctx.phandler.enable(
+        "npmem",
+        samples=samples,
+        path=path,
+        channels=chan,
+        shape=shape,
+        trig=trig,
+        nostop=ctx.waitenter,
+    )
+
+    ctx.needchannels = True
+
+    return True
diff --git a/src/nxscli/commands/cmd_npsave.py b/src/nxscli/commands/cmd_npsave.py
new file mode 100644
index 0000000..5be6353
--- /dev/null
+++ b/src/nxscli/commands/cmd_npsave.py
@@ -0,0 +1,47 @@
+"""Module containing Numpy capture plugin command."""
+
+from typing import TYPE_CHECKING
+
+import click
+
+from nxscli.cli.environment import Environment, pass_environment
+from nxscli.cli.types import Samples, capture_options
+
+if TYPE_CHECKING:
+    from nxscli.trigger import DTriggerConfigReq
+
+
+@click.command(name="pnpsave")
+@click.argument("samples", type=Samples(), required=True)
+@click.argument("path", type=click.Path(resolve_path=False), required=True)
+@capture_options
+@pass_environment
+def cmd_pnpsave(
+    ctx: Environment,
+    samples: int,
+    path: str,
+    chan: list[int],
+    trig: dict[int, "DTriggerConfigReq"],
+) -> bool:
+    """[plugin] Store samples in Numpy files.
+
+    Each configured channel will be stored in a separate file.
+    If the SAMPLES argument is set to 'i' then we capture data until
+    Enter is pressed.
+ """ # noqa: D301 + assert ctx.phandler + if samples == 0: # pragma: no cover + ctx.waitenter = True + + ctx.phandler.enable( + "npsave", + samples=samples, + path=path, + channels=chan, + trig=trig, + nostop=ctx.waitenter, + ) + + ctx.needchannels = True + + return True diff --git a/src/nxscli/commands/config/cmd_chan.py b/src/nxscli/commands/config/cmd_chan.py index f21c36a..e152e84 100644 --- a/src/nxscli/commands/config/cmd_chan.py +++ b/src/nxscli/commands/config/cmd_chan.py @@ -1,12 +1,15 @@ """Module containint the channels configuration command for CLI.""" -from typing import Any +from typing import TYPE_CHECKING, Any import click from nxscli.cli.environment import Environment, pass_environment from nxscli.cli.types import Channels, Divider, divider_option_help +if TYPE_CHECKING: + from nxscli.channelref import ChannelRef + ############################################################################### # Command: cmd_chan ############################################################################### @@ -18,7 +21,9 @@ "--divider", default="0", type=Divider(), help=divider_option_help ) @pass_environment -def cmd_chan(ctx: Environment, channels: list[int], divider: Any) -> bool: +def cmd_chan( + ctx: Environment, channels: list["ChannelRef"], divider: Any +) -> bool: """[config] Channels declaration and configuration. This command configure and enable given channels. 
diff --git a/src/nxscli/commands/config/cmd_vadd.py b/src/nxscli/commands/config/cmd_vadd.py new file mode 100644 index 0000000..aeeb1ed --- /dev/null +++ b/src/nxscli/commands/config/cmd_vadd.py @@ -0,0 +1,119 @@ +"""Virtual channel declaration command.""" + +from typing import TYPE_CHECKING + +import click + +from nxscli.channelref import ChannelRef +from nxscli.cli.environment import Environment, pass_environment +from nxscli.cli.types import StringList +from nxscli.virtual.services import get_runtime + +if TYPE_CHECKING: + from nxscli.phandler import PluginHandler + + +def _get_phandler(ctx: Environment) -> "PluginHandler": + assert ctx.phandler is not None + return ctx.phandler + + +def _parse_param_value(raw: str) -> object: + low = raw.lower() + if low in ("true", "false"): + return low == "true" + try: + if "." in raw: + return float(raw) + return int(raw) + except ValueError: + return raw + + +def _parse_params(params: list[str]) -> dict[str, object]: + parsed: dict[str, object] = {} + for token in params: + token = token.strip() + if not token: + continue + if "=" not in token: + raise click.BadParameter(f"Invalid param token: {token}") + key, raw = token.split("=", 1) + key = key.strip() + raw = raw.strip() + if not key: + raise click.BadParameter("Parameter key must not be empty") + parsed[key] = _parse_param_value(raw) + return parsed + + +def _merge_required_sources(ctx: Environment, inputs: list[str]) -> None: + """Ensure physical virtual-input sources are configured for streaming.""" + required: list[ChannelRef] = [] + for token in inputs: + tok = token.strip() + if tok.isnumeric(): + required.append(ChannelRef.physical(int(tok))) + + if not required: + return + + if ctx.channels is None: + ctx.channels = (required, 0) + return + + channels, divider = ctx.channels + if any(ref.is_all for ref in channels): + return + + merged = list(channels) + for ref in required: + if ref not in merged: + merged.append(ref) + ctx.channels = (merged, divider) + + 
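The `--params` parsing above can be sketched as follows — a standalone re-sketch of the `_parse_param_value`/`_parse_params` helpers (names shortened, empty-key check omitted): booleans are recognized first, then numeric coercion, with the raw string as a fallback:

```python
# Illustration only: standalone re-sketch of _parse_param_value/_parse_params.
def parse_param_value(raw: str) -> object:
    low = raw.lower()
    if low in ("true", "false"):
        return low == "true"  # booleans first
    try:
        if "." in raw:
            return float(raw)  # dotted numbers become floats
        return int(raw)  # plain digits become ints
    except ValueError:
        return raw  # fall back to the raw string


def parse_params(tokens: list[str]) -> dict[str, object]:
    parsed: dict[str, object] = {}
    for token in tokens:
        token = token.strip()
        if not token:
            continue
        if "=" not in token:
            raise ValueError(f"Invalid param token: {token}")
        key, raw = token.split("=", 1)
        parsed[key.strip()] = parse_param_value(raw.strip())
    return parsed


print(parse_params(["scale=2", "offset=1.5", "enable=true", "name=avg"]))
# {'scale': 2, 'offset': 1.5, 'enable': True, 'name': 'avg'}
```

So `--params scale=2,offset=1.5` reaches the operator with typed values, not strings.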
+@click.command(name="vadd") +@click.argument("channel_id", type=int) +@click.argument("inputs", type=StringList()) +@click.option("--name", type=str, default=None) +@click.option( + "--operator", + type=click.Choice( + [ + "scale_offset", + "math_binary", + "stats_running", + ] + ), + default="scale_offset", +) +@click.option( + "--params", + type=StringList(), + default="", + help="Operator params in key=value format, comma separated", +) +@pass_environment +def cmd_vadd( + ctx: Environment, + channel_id: int, + name: str | None, + operator: str, + inputs: list[str], + params: list[str], +) -> bool: + """[config] Add virtual channel to shared runtime.""" + _merge_required_sources(ctx, inputs) + runtime = get_runtime(_get_phandler(ctx)) + parsed_params = _parse_params(params) + aliases = runtime.add_virtual_channel( + channel_id=channel_id, + name=name or f"virt{channel_id}", + operator=operator, + inputs=tuple(inputs), + params=parsed_params, + ) + for alias, output_id in aliases: + click.echo(f"virtual output {output_id} -> channel {alias}") + return True diff --git a/src/nxscli/commands/interface/cmd_dummy.py b/src/nxscli/commands/interface/cmd_dummy.py index f4a28ab..a39198f 100644 --- a/src/nxscli/commands/interface/cmd_dummy.py +++ b/src/nxscli/commands/interface/cmd_dummy.py @@ -14,10 +14,16 @@ @click.group(name="dummy", chain=True) @click.option("--writepadding", default=0) @click.option( - "--streamsleep", type=float, default=0.001, help="dummy dev parameter. Default: 0.001" + "--streamsleep", + type=float, + default=0.001, + help="dummy dev parameter. Default: 0.001", ) @click.option( - "--samplesnum", type=int, default=100, help="dummy dev parameter. Default: 100" + "--samplesnum", + type=int, + default=100, + help="dummy dev parameter. 
Default: 100", ) @pass_environment def cmd_dummy( @@ -27,16 +33,27 @@ def cmd_dummy( \b Channels data: - chan0 - vdim = 1, random() - chan1 - vdim = 1, saw wave - chan2 - vdim = 1, triangle wave - chan3 - vdim = 2, random() - chan4 - vdim = 3, random() - chan5 - vdim = 3, static vector = [1.0, 0.0, -1.0] - chan6 - vdim = 1, 'hello' string - chan7 - vdim = 3, static vector = [1.0, 0.0, -1.0], meta = 1B int - chan8 - vdim = 0, meta = 'hello string', mlen = 16 - chan9 - vdim = 3, 3-phase sine wave + 0: noise_uniform_scalar - vdim = 1, random() + 1: ramp_saw_up - vdim = 1, saw wave + 2: ramp_triangle - vdim = 1, triangle wave + 3: noise_uniform_vec2 - vdim = 2, random() + 4: noise_uniform_vec3 - vdim = 3, random() + 5: static_vec3 - vdim = 3, static vector = [1.0, 0.0, -1.0] + 6: text_hello_sparse - vdim = 1, sparse 'hello' string + 7: static_vec3_meta_counter - vdim = 3, static vec + 1B meta counter + 8: meta_hello_only - vdim = 0, mlen = 16, meta = 'hello string' + 9: sine_three_phase - vdim = 3, 3-phase sine wave + 10: reserved (undefined) + 11: fft_multitone - vdim = 1, deterministic multi-tone + 12: fft_chirp - vdim = 1, deterministic chirp-like signal + 13: hist_gaussian - vdim = 1, deterministic Gaussian-like + 14: hist_bimodal - vdim = 1, deterministic bi-modal + 15: xy_lissajous - vdim = 2, correlated XY signal + 16: polar_theta_radius - vdim = 2, (theta, radius) signal + 17: step_up_once - vdim = 1, one rising step + 18: step_down_once - vdim = 1, one falling step + 19: pulse_square_20p - vdim = 1, periodic square pulse (20% duty) + 20: pulse_single_sparse - vdim = 1, one-sample pulse every 250 samples """ # noqa: D301 intf = DummyDev( rxpadding=writepadding, @@ -46,7 +63,12 @@ def cmd_dummy( # initialize nxslib communication handler assert ctx.parser - ctx.nxscope = NxscopeHandler(intf, ctx.parser) + ctx.nxscope = NxscopeHandler( + intf, + ctx.parser, + enable_bitrate_tracking=True, + stream_decode_mode="numpy", + ) ctx.interface = True diff --git 
a/src/nxscli/commands/interface/cmd_rtt.py b/src/nxscli/commands/interface/cmd_rtt.py
index 0d0a3e5..bd60790 100644
--- a/src/nxscli/commands/interface/cmd_rtt.py
+++ b/src/nxscli/commands/interface/cmd_rtt.py
@@ -44,7 +44,12 @@ def cmd_rtt(

    # initialize nxslib communication handler
    assert ctx.parser
-    ctx.nxscope = NxscopeHandler(intf, ctx.parser)
+    ctx.nxscope = NxscopeHandler(
+        intf,
+        ctx.parser,
+        enable_bitrate_tracking=True,
+        stream_decode_mode="numpy",
+    )

    ctx.interface = True
diff --git a/src/nxscli/commands/interface/cmd_serial.py b/src/nxscli/commands/interface/cmd_serial.py
index ea45bb7..2665073 100644
--- a/src/nxscli/commands/interface/cmd_serial.py
+++ b/src/nxscli/commands/interface/cmd_serial.py
@@ -25,7 +25,12 @@ def cmd_serial(

    # initialize nxslib communication handler
    assert ctx.parser
-    ctx.nxscope = NxscopeHandler(intf, ctx.parser)
+    ctx.nxscope = NxscopeHandler(
+        intf,
+        ctx.parser,
+        enable_bitrate_tracking=True,
+        stream_decode_mode="numpy",
+    )

    ctx.interface = True
diff --git a/src/nxscli/commands/interface/cmd_udp.py b/src/nxscli/commands/interface/cmd_udp.py
new file mode 100644
index 0000000..e8a6a25
--- /dev/null
+++ b/src/nxscli/commands/interface/cmd_udp.py
@@ -0,0 +1,42 @@
+"""Module containing the UDP interface command for CLI."""
+
+import click
+from nxslib.intf.udp import UdpDevice
+from nxslib.nxscope import NxscopeHandler
+
+from nxscli.cli.environment import Environment, pass_environment
+
+###############################################################################
+# Command: cmd_udp
+###############################################################################
+
+
+@click.group(name="udp", chain=True)
+@click.argument("host", type=str, required=True)
+@click.argument("port", type=int, required=True)
+@click.option("--local-port", type=int, default=0, help="Default: 0")
+@click.option("--writepadding", default=0, help="Default: 0")
+@pass_environment
+def cmd_udp(
+    ctx: Environment,
+    host: str,
+    port: int,
+    
local_port: int, + writepadding: int, +) -> bool: # pragma: no cover + """[interface] Connect with a UDP NxScope devie.""" + intf = UdpDevice(host, port, local_port=local_port) + intf.write_padding = writepadding + + # initialize nxslib communication handler + assert ctx.parser + ctx.nxscope = NxscopeHandler( + intf, + ctx.parser, + enable_bitrate_tracking=True, + stream_decode_mode="numpy", + ) + + ctx.interface = True + + return True diff --git a/src/nxscli/control_server.py b/src/nxscli/control_server.py new file mode 100644 index 0000000..4fb70b0 --- /dev/null +++ b/src/nxscli/control_server.py @@ -0,0 +1,400 @@ +"""Optional control-server plugin for nxscli.""" + +import base64 +import json +import os +import socket +import threading +from dataclasses import dataclass +from json import JSONDecodeError +from typing import TYPE_CHECKING, Any, cast + +from nxslib.comm import AckMode +from nxslib.plugin import INxscopePlugin +from nxslib.proto.iparse import ParseAck + +if TYPE_CHECKING: + from nxslib.plugin import INxscopeControl + + +@dataclass(frozen=True) +class ControlResult: + """Response from control server client calls.""" + + ok: bool + data: dict[str, Any] + error: str | None = None + + +@dataclass(frozen=True) +class _EndpointConfig: + """Parsed endpoint configuration.""" + + family: int + bind_addr: Any + connect_addr: Any + cleanup_path: str | None + + +def _require_af_unix(endpoint: str) -> None: + """Fail early for unix endpoints on platforms without AF_UNIX.""" + if not hasattr(socket, "AF_UNIX"): + raise ValueError( + f"unix endpoint '{endpoint}' is not supported on this platform; " + "use tcp://:" + ) + + +def _parse_endpoint(endpoint: str) -> _EndpointConfig: + """Parse IPC endpoint string into socket configuration.""" + if endpoint.startswith("tcp://"): + host_port = endpoint[len("tcp://") :] + host, port_s = host_port.rsplit(":", 1) + port = int(port_s, 10) + addr_tcp = (host, port) + return _EndpointConfig( + family=socket.AF_INET, + 
bind_addr=addr_tcp, + connect_addr=addr_tcp, + cleanup_path=None, + ) + + if endpoint.startswith("unix-abstract://"): + _require_af_unix(endpoint) + name = endpoint[len("unix-abstract://") :] + if not name: + raise ValueError("unix-abstract endpoint name cannot be empty") + addr_unix_abstract = "\x00" + name + return _EndpointConfig( + family=socket.AF_UNIX, + bind_addr=addr_unix_abstract, + connect_addr=addr_unix_abstract, + cleanup_path=None, + ) + + path = endpoint + if endpoint.startswith("unix://"): + _require_af_unix(endpoint) + path = endpoint[len("unix://") :] + else: + _require_af_unix(endpoint) + if not path: + raise ValueError("unix endpoint path cannot be empty") + return _EndpointConfig( + family=socket.AF_UNIX, + bind_addr=path, + connect_addr=path, + cleanup_path=path, + ) + + +class ControlServerPlugin(INxscopePlugin): + """Nxslib plugin exposing control surface via local IPC endpoint.""" + + name = "control_server" + + def __init__(self, endpoint: str): + """Initialize plugin and parse configured control endpoint.""" + self._endpoint = _parse_endpoint(endpoint) + self._control: "INxscopeControl | None" = None + self._sock: socket.socket | None = None + self._thread: threading.Thread | None = None + self._stop = threading.Event() + + def on_register(self, control: "INxscopeControl") -> None: + """Attach control surface and start IPC server thread.""" + self._control = control + self._start() + + def on_unregister(self) -> None: + """Stop IPC server and detach control surface.""" + self._stop_server() + self._control = None + + def _start(self) -> None: + if self._thread is not None and self._thread.is_alive(): + return + + if self._endpoint.cleanup_path is not None: + os.makedirs( + os.path.dirname(self._endpoint.cleanup_path) or ".", + exist_ok=True, + ) + try: + os.unlink(self._endpoint.cleanup_path) + except FileNotFoundError: + pass + + self._sock = socket.socket(self._endpoint.family, socket.SOCK_STREAM) + 
self._sock.bind(self._endpoint.bind_addr) + self._sock.listen(4) + self._sock.settimeout(0.2) + self._stop.clear() + + self._thread = threading.Thread( + target=self._serve_loop, + name="nxscli_control", + daemon=True, + ) + self._thread.start() + + def _stop_server(self) -> None: + self._stop.set() + if self._sock is not None: + try: + self._sock.close() + except OSError: + pass + self._sock = None + + if self._thread is not None: + self._thread.join(timeout=1.0) + self._thread = None + + if self._endpoint.cleanup_path is not None: + try: + os.unlink(self._endpoint.cleanup_path) + except FileNotFoundError: + pass + + def _serve_loop(self) -> None: + assert self._sock is not None + while not self._stop.is_set(): + try: + conn, _ = self._sock.accept() + except TimeoutError: + continue + except OSError: + break + + with conn: + conn.settimeout(1.0) + try: + req = self._recv_json(conn) + resp = self._handle(req) + except Exception as exc: + resp = {"ok": False, "error": str(exc)} + self._send_json(conn, resp) + + def _recv_json(self, conn: socket.socket) -> dict[str, Any]: + buf = bytearray() + while True: + chunk = conn.recv(4096) + if not chunk: + break + buf.extend(chunk) + if b"\n" in chunk: + break + if not buf: + raise RuntimeError("empty request") + line = bytes(buf).split(b"\n", 1)[0] + obj = json.loads(line.decode("utf-8")) + if not isinstance(obj, dict): + raise ValueError("request must be a JSON object") + return cast("dict[str, Any]", obj) + + def _send_json(self, conn: socket.socket, resp: dict[str, Any]) -> None: + payload = json.dumps(resp, separators=(",", ":")).encode("utf-8") + conn.sendall(payload + b"\n") + + def _handle(self, req: dict[str, Any]) -> dict[str, Any]: + if self._control is None: + raise RuntimeError("control server not attached") + + method = req.get("method") + params = req.get("params", {}) + + if method == "send_user_frame": + payload = base64.b64decode(params["payload_b64"]) + ack = self._control.send_user_frame( + 
int(params["fid"]), + payload, + ack_mode=AckMode( + str(params.get("ack_mode", "disabled")).lower() + ), + ack_timeout=float(params.get("ack_timeout", 1.0)), + ) + return { + "ok": True, + "data": { + "state": bool(ack.state), + "retcode": int(ack.retcode), + }, + } + + if method == "ext_notify": + payload = base64.b64decode(params["payload_b64"]) + ack = self._control.ext_notify( + ext_id=int(params["ext_id"]), + cmd_id=int(params["cmd_id"]), + payload=payload, + fid=int(params.get("fid", 8)), + ack_mode=AckMode( + str(params.get("ack_mode", "disabled")).lower() + ), + ack_timeout=float(params.get("ack_timeout", 1.0)), + ) + return { + "ok": True, + "data": { + "state": bool(ack.state), + "retcode": int(ack.retcode), + }, + } + + if method == "ext_request": + payload = base64.b64decode(params["payload_b64"]) + resp = self._control.ext_request( + ext_id=int(params["ext_id"]), + cmd_id=int(params["cmd_id"]), + payload=payload, + fid=int(params.get("fid", 8)), + timeout=float(params.get("timeout", 1.0)), + ack_mode=AckMode( + str(params.get("ack_mode", "disabled")).lower() + ), + ack_timeout=float(params.get("ack_timeout", 1.0)), + ) + return { + "ok": True, + "data": { + "ext_id": int(resp.ext_id), + "cmd_id": int(resp.cmd_id), + "req_id": int(resp.req_id), + "status": int(resp.status), + "fid": int(resp.fid), + "is_error": bool(resp.is_error), + "payload_b64": base64.b64encode(resp.payload).decode( + "ascii" + ), + }, + } + + raise ValueError(f"unknown method: {method}") + + +class ControlClient: + """Client for nxscli ControlServerPlugin endpoint.""" + + def __init__(self, endpoint: str, timeout: float = 1.0): + """Initialize control client bound to given endpoint.""" + self._endpoint = _parse_endpoint(endpoint) + self._timeout = timeout + self._last_error: str | None = None + + @property + def last_error(self) -> str | None: + """Return last control client error string.""" + return self._last_error + + def _call(self, method: str, params: dict[str, Any]) -> 
ControlResult: + self._last_error = None + try: + with socket.socket( + self._endpoint.family, socket.SOCK_STREAM + ) as sock: + sock.settimeout(self._timeout) + sock.connect(self._endpoint.connect_addr) + + req = {"method": method, "params": params} + wire = json.dumps(req, separators=(",", ":")).encode("utf-8") + sock.sendall(wire + b"\n") + + buf = bytearray() + while True: + chunk = sock.recv(4096) + if not chunk: + break + buf.extend(chunk) + if b"\n" in chunk: + break + if not buf: + self._last_error = "empty response" + return ControlResult(False, {}, self._last_error) + + line = bytes(buf).split(b"\n", 1)[0] + obj = json.loads(line.decode("utf-8")) + return ControlResult( + ok=bool(obj.get("ok", False)), + data=dict(obj.get("data", {})), + error=obj.get("error"), + ) + except (OSError, ValueError, JSONDecodeError) as exc: + self._last_error = str(exc) + return ControlResult(False, {}, self._last_error) + + def send_user_frame( + self, + fid: int, + payload: bytes, + ack_mode: str = "disabled", + ack_timeout: float = 1.0, + ) -> ParseAck: + """Proxy send_user_frame.""" + ret = self._call( + "send_user_frame", + { + "fid": int(fid), + "payload_b64": base64.b64encode(payload).decode("ascii"), + "ack_mode": ack_mode, + "ack_timeout": float(ack_timeout), + }, + ) + if not ret.ok: + return ParseAck(False, -1) + return ParseAck( + bool(ret.data.get("state", False)), + int(ret.data.get("retcode", -1)), + ) + + def ext_notify( + self, + ext_id: int, + cmd_id: int, + payload: bytes, + fid: int = 8, + ack_mode: str = "disabled", + ack_timeout: float = 1.0, + ) -> ParseAck: + """Proxy ext_notify.""" + ret = self._call( + "ext_notify", + { + "ext_id": int(ext_id), + "cmd_id": int(cmd_id), + "payload_b64": base64.b64encode(payload).decode("ascii"), + "fid": int(fid), + "ack_mode": ack_mode, + "ack_timeout": float(ack_timeout), + }, + ) + if not ret.ok: + return ParseAck(False, -1) + return ParseAck( + bool(ret.data.get("state", False)), + int(ret.data.get("retcode", 
-1)), + ) + + def ext_request( + self, + ext_id: int, + cmd_id: int, + payload: bytes, + fid: int = 8, + timeout: float = 1.0, + ack_mode: str = "disabled", + ack_timeout: float = 1.0, + ) -> ControlResult: + """Proxy ext_request.""" + return self._call( + "ext_request", + { + "ext_id": int(ext_id), + "cmd_id": int(cmd_id), + "payload_b64": base64.b64encode(payload).decode("ascii"), + "fid": int(fid), + "timeout": float(timeout), + "ack_mode": ack_mode, + "ack_timeout": float(ack_timeout), + }, + ) diff --git a/src/nxscli/ext_commands.py b/src/nxscli/ext_commands.py index d9bea5e..f6904a5 100644 --- a/src/nxscli/ext_commands.py +++ b/src/nxscli/ext_commands.py @@ -5,10 +5,13 @@ from nxscli.commands.cmd_csv import cmd_pcsv from nxscli.commands.cmd_devinfo import cmd_pdevinfo from nxscli.commands.cmd_none import cmd_pnone +from nxscli.commands.cmd_npmem import cmd_pnpmem +from nxscli.commands.cmd_npsave import cmd_pnpsave from nxscli.commands.cmd_printer import cmd_printer from nxscli.commands.cmd_udp import cmd_pudp from nxscli.commands.config.cmd_chan import cmd_chan from nxscli.commands.config.cmd_trig import cmd_trig +from nxscli.commands.config.cmd_vadd import cmd_vadd if TYPE_CHECKING: import click @@ -16,8 +19,11 @@ commands_list: list["click.Command"] = [ cmd_chan, cmd_trig, + cmd_vadd, cmd_pdevinfo, cmd_pcsv, + cmd_pnpsave, + cmd_pnpmem, cmd_pnone, cmd_printer, cmd_pudp, diff --git a/src/nxscli/ext_interfaces.py b/src/nxscli/ext_interfaces.py index 757c0c6..d45b92a 100644 --- a/src/nxscli/ext_interfaces.py +++ b/src/nxscli/ext_interfaces.py @@ -5,8 +5,14 @@ from nxscli.commands.interface.cmd_dummy import cmd_dummy from nxscli.commands.interface.cmd_rtt import cmd_rtt from nxscli.commands.interface.cmd_serial import cmd_serial +from nxscli.commands.interface.cmd_udp import cmd_udp if TYPE_CHECKING: import click -interfaces_list: list["click.Group"] = [cmd_serial, cmd_rtt, cmd_dummy] +interfaces_list: list["click.Group"] = [ + cmd_serial, + cmd_udp, + cmd_rtt, 
+ cmd_dummy, +] diff --git a/src/nxscli/ext_plugins.py b/src/nxscli/ext_plugins.py index debffb9..bac6eb8 100644 --- a/src/nxscli/ext_plugins.py +++ b/src/nxscli/ext_plugins.py @@ -4,12 +4,16 @@ from nxscli.plugins.csv import PluginCsv from nxscli.plugins.devinfo import PluginDevinfo from nxscli.plugins.none import PluginNone +from nxscli.plugins.npmem import PluginNpmem +from nxscli.plugins.npsave import PluginNpsave from nxscli.plugins.printer import PluginPrinter from nxscli.plugins.udp import PluginUdp plugins_list = [ DPluginDescription("devinfo", PluginDevinfo), DPluginDescription("csv", PluginCsv), + DPluginDescription("npsave", PluginNpsave), + DPluginDescription("npmem", PluginNpmem), DPluginDescription("none", PluginNone), DPluginDescription("printer", PluginPrinter), DPluginDescription("udp", PluginUdp), diff --git a/src/nxscli/idata.py b/src/nxscli/idata.py index bd44957..25eac45 100644 --- a/src/nxscli/idata.py +++ b/src/nxscli/idata.py @@ -10,8 +10,8 @@ from collections.abc import Callable from nxslib.dev import DeviceChannel - from nxslib.nxscope import DNxscopeStream + from nxscli.channelref import ChannelRef from nxscli.trigger import TriggerHandler ############################################################################### @@ -23,7 +23,7 @@ class PluginDataCb: """Plugin data callbacks.""" - stream_sub: "Callable[[int], queue.Queue[Any]]" + stream_sub: "Callable[[ChannelRef], queue.Queue[Any]]" stream_unsub: "Callable[[queue.Queue[Any]], None]" @@ -91,9 +91,7 @@ def mlen(self) -> int: """Return stream metadata dimension.""" return self._channel.data.mlen - def queue_get( - self, block: bool, timeout: float = 1.0 - ) -> list["DNxscopeStream"]: + def queue_get(self, block: bool, timeout: float = 1.0) -> Any: """Get data from a stream queue. 
        :param block: blocking operation
@@ -105,6 +103,7 @@ def queue_get(
             ret = self._queue.get(block=block, timeout=timeout)
         except queue.Empty:
             pass
+
         return self._trigger.data_triggered(ret)
@@ -145,10 +144,23 @@ def __del__(self) -> None:
             pass

     def _qdlist_init(self) -> list[PluginQueueData]:
+        from nxscli.channelref import ChannelRef
+
         ret = []
         for i, channel in enumerate(self._chanlist):
             # get queue with data
-            que = self._cb.stream_sub(channel.data.chan)
+            if channel.data.chan >= 0:
+                cref = ChannelRef.physical(channel.data.chan)
+            else:
+                name = channel.data.name
+                if name.startswith("v") and name[1:].isnumeric():
+                    cref = ChannelRef.virtual(int(name[1:]))
+                else:
+                    raise ValueError(
+                        "invalid virtual channel name "
+                        f"for stream subscription: {name}"
+                    )
+            que = self._cb.stream_sub(cref)
             # initialize queue handler
             pdata = PluginQueueData(que, channel, self._trig[i])
             # add handler to a list
@@ -157,9 +169,9 @@ def _qdlist_init(self) -> list[PluginQueueData]:

     def _queue_deinit(self) -> None:
         """Deinitialize queue."""
-        for i, pdata in enumerate(self._qdlist):
+        for pdata in self._qdlist:
             self._cb.stream_unsub(pdata.queue)
-            self._qdlist.pop(i)
+        self._qdlist.clear()

         # clean up triggers
         # TODO: revisit where this belongs, here or in plugins ?
diff --git a/src/nxscli/iplugin.py b/src/nxscli/iplugin.py
index e24ceae..49a25ec 100644
--- a/src/nxscli/iplugin.py
+++ b/src/nxscli/iplugin.py
@@ -71,10 +71,32 @@ def handled(self, val: bool) -> None:
         """
         self._handled = val

+    def get_plot_handler(self) -> Any:
+        """Return the plot handler for this plugin.
+
+        Plugins with a visual output (plot backends) override this method
+        to expose their plot handler instance.
+
+        :return: plot handler instance, or None if the plugin has no plot
+        """
+        return None
+
     def wait_for_plugin(self) -> bool:
         """Return True if the plugin doesn't need to wait."""
         return True

+    @classmethod
+    def get_inputhook(cls) -> Any:
+        """Get inputhook function for GUI event processing.
+ + Plugin classes that use GUI frameworks (matplotlib, Qt, etc.) should + override this method to return a function that processes GUI events + while waiting for user input in interactive mode. + + :return: inputhook function or None if no GUI event processing needed + """ + return None + @property @abstractmethod def stream(self) -> bool: diff --git a/src/nxscli/istream.py b/src/nxscli/istream.py new file mode 100644 index 0000000..836602c --- /dev/null +++ b/src/nxscli/istream.py @@ -0,0 +1,66 @@ +"""Interfaces for additional stream providers.""" + +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + import queue + + from nxslib.dev import DeviceChannel + from nxslib.nxscope import DNxscopeStreamBlock, NxscopeHandler + + from nxscli.channelref import ChannelRef + + +############################################################################### +# Class: IStreamProvider +############################################################################### + + +class IStreamProvider(Protocol): + """Provider that can expose additional channels and stream queues.""" + + def on_connect(self, nxscope: "NxscopeHandler") -> None: + """Attach to an active ``NxscopeHandler``.""" + + def on_disconnect(self) -> None: + """Detach from ``NxscopeHandler``.""" + + def on_stream_start(self) -> None: + """React to stream start.""" + + def on_stream_stop(self) -> None: + """React to stream stop.""" + + def channel_get(self, channel: "ChannelRef") -> "DeviceChannel | None": + """Return channel metadata managed by this provider.""" + + def channel_list(self) -> tuple["DeviceChannel", ...]: + """Return all channels managed by this provider.""" + + def stream_sub( + self, channel: "ChannelRef" + ) -> "queue.Queue[list[DNxscopeStreamBlock]] | None": + """Subscribe to a provider-managed channel queue.""" + + def stream_unsub( + self, subq: "queue.Queue[list[DNxscopeStreamBlock]]" + ) -> bool: + """Unsubscribe queue. 
        Return ``True`` if queue belonged here."""
+
+
+###############################################################################
+# Class: IServiceRegistry
+###############################################################################
+
+
+class IServiceRegistry(Protocol):
+    """Minimal service registry exposed by ``PluginHandler``."""
+
+    def service_get(self, name: str) -> Any:
+        """Get registered service."""
+
+    def service_set(self, name: str, service: Any) -> None:
+        """Set registered service."""
+
+    def stream_provider_add(self, provider: IStreamProvider) -> None:
+        """Add stream provider."""
diff --git a/src/nxscli/phandler.py b/src/nxscli/phandler.py
index cfc2d3c..2659d01 100644
--- a/src/nxscli/phandler.py
+++ b/src/nxscli/phandler.py
@@ -1,16 +1,21 @@
 """Module containing the nxscli handler implementation."""

-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Sequence

+from nxscli.channelref import ChannelRef
 from nxscli.idata import PluginDataCb
 from nxscli.logger import logger
+from nxscli.stream_hub import SharedStreamProvider
 from nxscli.trigger import DTriggerConfigReq, TriggerHandler, trigger_from_req

 if TYPE_CHECKING:
+    import queue
+
     from nxslib.dev import Device, DeviceChannel
     from nxslib.nxscope import NxscopeHandler

     from nxscli.iplugin import DPluginDescription, IPlugin
+    from nxscli.istream import IStreamProvider


 ###############################################################################
@@ -26,7 +31,7 @@ def __init__(self, plugins: list["DPluginDescription"] | None = None):
         :param plugins: a list with plugins
         """
-        self._nxs: "NxscopeHandler" | None = None
+        self._nxs: "NxscopeHandler" | None = None  # noqa: TC010

         self._plugins = {}
         if plugins:
@@ -43,27 +48,42 @@ def __init__(self, plugins: list["DPluginDescription"] | None = None):
         self._stream = False
         self._cleanup_done = False
+        self._providers: list["IStreamProvider"] = [SharedStreamProvider()]
+        self._services: dict[str, Any] = {}

     def __del__(self) -> None:
"""Raise assertion if not cleaned.""" if not self._cleanup_done: raise AssertionError("PluginHandler not cleaned") - def _chanlist_gen(self, channels: list[int]) -> list["DeviceChannel"]: - assert self._nxs + def __enter__(self) -> "PluginHandler": + """Return self on context manager entry.""" + return self + + def __exit__(self, *_: object) -> None: + """Clean up on context manager exit.""" + self.cleanup() + + def _chanlist_gen( + self, channels: Sequence[ChannelRef] | None + ) -> list["DeviceChannel"]: assert self.dev + refs = self._channel_refs(channels, default_all=False) # convert special keys for all channels - if channels and channels[0] == -1: # pragma: no cover - chanlist = list(range(self.dev.data.chmax)) - else: - assert all(isinstance(x, int) for x in channels) - chanlist = channels - # get channels data - ret = [] - for chan in chanlist: - channel = self._nxs.dev_channel_get(chan) - assert channel + ret: list["DeviceChannel"] = [] + for ref in refs: + if ref.is_all: # pragma: no cover + for chid in range(self.dev.data.chmax): + channel = self.channel_get(ChannelRef.physical(chid)) + if channel is None: + raise AssertionError + ret.append(channel) + continue + + channel = self.channel_get(ref) + if channel is None: + raise AssertionError ret.append(channel) return ret @@ -78,6 +98,12 @@ def _chanlist_enable(self) -> None: ) continue + if channel.data.chan < 0: + continue + + if self._nxs.dev_channel_get(channel.data.chan) is None: + continue + # enable channel self._nxs.ch_enable(channel.data.chan) @@ -85,12 +111,52 @@ def _chanlist_div(self, div: int | list[int]) -> None: assert self._nxs if isinstance(div, int): for channel in self._chanlist: - self._nxs.ch_divider(channel.data.chan, div) + if channel.data.chan < 0: + continue + if self._nxs.dev_channel_get(channel.data.chan) is not None: + self._nxs.ch_divider(channel.data.chan, div) else: # divider list configuration must cover all configured channels assert len(div) == len(self._chanlist) for i, 
channel in enumerate(self._chanlist): - self._nxs.ch_divider(channel.data.chan, div[i]) + if channel.data.chan < 0: + continue + if self._nxs.dev_channel_get(channel.data.chan) is not None: + self._nxs.ch_divider(channel.data.chan, div[i]) + + def _provider_channels(self) -> list["DeviceChannel"]: + channels: list["DeviceChannel"] = [] + for provider in self._providers: + channels.extend(provider.channel_list()) + return channels + + def _channel_ref(self, value: Any) -> ChannelRef: + if isinstance(value, ChannelRef): + return value + if isinstance(value, int): + if value == -1: + return ChannelRef.all_channels() + return ChannelRef.physical(value) + token = value.strip() + if token.startswith("v"): + vnum = token[1:] + if not vnum.isnumeric(): + raise ValueError(f"Invalid virtual channel: {value}") + return ChannelRef.virtual(int(vnum)) + if token.isnumeric(): + return ChannelRef.physical(int(token)) + raise ValueError(f"Invalid channel token: {value}") + + def _channel_refs( + self, + channels: Sequence[ChannelRef] | None, + default_all: bool, + ) -> list[ChannelRef]: + if channels is None: + if default_all: + return [ChannelRef.all_channels()] + return [] + return [self._channel_ref(x) for x in channels] @property def chanlist(self) -> list["DeviceChannel"]: @@ -123,6 +189,138 @@ def enabled(self) -> list[tuple[int, type["IPlugin"], Any]]: """Get enabled plugins.""" return self._enabled + @property + def nxscope(self) -> "NxscopeHandler": + """Get NxScope handler. 
+ + :return: NxscopeHandler instance + :raises AssertionError: If nxscope is not connected + """ + assert self._nxs + return self._nxs + + def get_enabled_channels(self, applied: bool = True) -> tuple[int, ...]: + """Get enabled channels from NxScope.""" + return self.nxscope.get_enabled_channels(applied=applied) + + def get_channel_divider(self, chid: int, applied: bool = True) -> int: + """Get channel divider from NxScope.""" + return self.nxscope.get_channel_divider(chid, applied=applied) + + def get_channel_dividers(self, applied: bool = True) -> tuple[int, ...]: + """Get channel dividers from NxScope.""" + return self.nxscope.get_channel_dividers(applied=applied) + + def get_channels_state(self, applied: bool = True) -> Any: + """Get channels state snapshot from NxScope.""" + return self.nxscope.get_channels_state(applied=applied) + + def get_device_capabilities(self) -> Any: + """Get device capabilities snapshot from NxScope.""" + return self.nxscope.get_device_capabilities() + + def get_stream_stats(self) -> Any: + """Get stream stats snapshot from NxScope.""" + return self.nxscope.get_stream_stats() + + def collect_inputhooks(self) -> list[Any]: + """Collect inputhooks from all loaded plugins. + + :return: list of inputhook functions from plugins that provide them + """ + hooks = [] + for plugin_cls in self._plugins.values(): + hook = plugin_cls.get_inputhook() + if hook is not None: + hooks.append(hook) + return hooks + + def plugin_start_dynamic(self, name: str, **kwargs: Any) -> int: + """Start a plugin dynamically at runtime. 
+ + :param name: Plugin name + :param kwargs: Plugin-specific configuration + + :return: Plugin ID for later reference + """ + from nxscli.iplugin import EPluginType + + # Get plugin class + cls = self._plugins[name] + + # Create plugin instance + plugin = cls() # type: ignore + + # Connect to plugin handler + plugin.connect_phandler(self) + + # Start the plugin + if not plugin.start(kwargs): # pragma: no cover + logger.error("failed to start plugin %s", str(plugin)) + return -1 + + # For plot plugins (STATIC/ANIMATION), call result() to show the plot + # This is equivalent to what handle_plugin() does in the normal flow + if plugin.ptype in (EPluginType.STATIC, EPluginType.ANIMATION): + plugin.result() + + # Add to started list + self._started.append((plugin, kwargs)) + pid = len(self._started) - 1 + + logger.info("dynamically started %s with pid=%d", str(plugin), pid) + return pid + + def plugin_stop_dynamic(self, pid: int) -> None: + """Stop a running plugin by ID. + + :param pid: Plugin ID from plugin_start_dynamic() + + :raises IndexError: If plugin ID is invalid + """ + if pid < 0 or pid >= len(self._started): + raise IndexError(f"Invalid plugin ID: {pid}") + + plugin, _ = self._started[pid] + plugin.stop() + logger.info("stopped plugin with pid=%d", pid) + + # Remove from started list + self._started.pop(pid) + + def plugin_get_instance(self, pid: int) -> "IPlugin | None": + """Return running plugin instance for a given plugin ID. + + :param pid: plugin ID from plugin_start_dynamic() + + :return: IPlugin instance, or None if the ID is out of range + """ + if pid < 0 or pid >= len(self._started): + return None + plugin, _ = self._started[pid] + return plugin + + def get_started_plugins(self) -> tuple[tuple[int, str], ...]: + """Get list of started plugins. 
+ + :return: Tuple of (pid, plugin_name) pairs where plugin_name + is the registered name (not class name) + """ + result = [] + for i, (plugin, _) in enumerate(self._started): + # Find registered name for this plugin class + plugin_class = type(plugin) + name = None + for reg_name, reg_class in self._plugins.items(): + if reg_class == plugin_class: + name = reg_name + break + if name is None: + # Fallback to class name if not found + name = plugin_class.__name__ + result.append((i, name)) + return tuple(result) + def cleanup(self) -> None: """Clean up - must be called after instance use.""" # disconnect nxscope if connected @@ -133,16 +331,64 @@ def cleanup(self) -> None: def cb_get(self) -> PluginDataCb: """Get callbacks for plugins.""" + return PluginDataCb(self.stream_sub, self.stream_unsub) + + def service_set(self, name: str, service: Any) -> None: + """Register named service for extensions.""" + self._services[name] = service + + def service_get(self, name: str) -> Any: + """Get named service for extensions.""" + return self._services.get(name) + + def stream_provider_add(self, provider: "IStreamProvider") -> None: + """Register stream provider.""" + self._providers.append(provider) + if self._nxs is not None: + provider.on_connect(self._nxs) + + def channel_get(self, channel: ChannelRef) -> "DeviceChannel | None": + """Get channel from device or registered providers.""" + if self._nxs is not None and channel.is_physical: + ch = self._nxs.dev_channel_get(channel.physical_id()) + if ch is not None: + return ch + for provider in self._providers: + ch = provider.channel_get(channel) + if ch is not None: + return ch + return None + + def stream_sub(self, channel: ChannelRef) -> "queue.Queue[Any]": + """Subscribe queue for device/provider channel.""" assert self._nxs - return PluginDataCb(self._nxs.stream_sub, self._nxs.stream_unsub) + for provider in self._providers: + subq = provider.stream_sub(channel) + if subq is not None: + return subq + if not 
channel.is_physical: + raise ValueError(f"Unknown channel: {channel}") + return self._nxs.stream_sub(channel.physical_id()) + + def stream_unsub(self, subq: "queue.Queue[Any]") -> None: + """Unsubscribe queue from device/provider channel.""" + for provider in self._providers: + if provider.stream_unsub(subq): + return + if self._nxs is not None: + self._nxs.stream_unsub(subq) def stream_start(self) -> None: """Start stream.""" assert self._nxs self._nxs.stream_start() + for provider in self._providers: + provider.on_stream_start() def stream_stop(self) -> None: """Stop stream.""" + for provider in self._providers: + provider.on_stream_stop() assert self._nxs self._nxs.stream_stop() @@ -150,6 +396,8 @@ def nxscope_disconnect(self) -> None: """Disconnect from NxScope device.""" if self._nxs: logger.info("disconnecting from nxs device...") + for provider in self._providers: + provider.on_disconnect() # connect nxscope device self._nxs.disconnect() logger.info("disconnected!") @@ -164,6 +412,8 @@ def nxscope_connect(self, nxs: "NxscopeHandler") -> None: logger.info("connecting to nxs device...") # connect nxscope device self._nxs.connect() + for provider in self._providers: + provider.on_connect(self._nxs) logger.info("connected!") def plugin_add(self, cls: tuple[str, type["IPlugin"]]) -> None: @@ -313,27 +563,61 @@ def trigger_get( trg = DTriggerConfigReq("on", None) return trg - def chanlist_plugin(self, channels: list[int]) -> list["DeviceChannel"]: + def chanlist_plugin( # noqa: C901 + self, channels: Sequence[ChannelRef] | None + ) -> list["DeviceChannel"]: """Prepare channels list for a plugin. 
:param chanlist: a list with plugin channels """ + assert self.dev + + refs = self._channel_refs(channels, default_all=True) chanlist = [] - if channels and channels[0] != -1: - # plugin specific channels configuration - for chan in self.chanlist: # pragma: no cover - if chan.data.chan in channels: - chanlist.append(chan) - else: # pragma: no cover - pass + + # If no channels configured in phandler (dynamic mode), + # get them directly from device + if not self._chanlist: + if refs and refs[0].is_all: + # All channels + for chid in range(self.dev.data.chmax): + ch = self.channel_get(ChannelRef.physical(chid)) + if ch and ch.data.is_valid: + chanlist.append(ch) + for ch in self._provider_channels(): + if ch.data.is_valid: + chanlist.append(ch) + else: + # Specific channels + for ref in refs: + ch = self.channel_get(ref) + if ch and ch.data.is_valid: + chanlist.append(ch) else: - chanlist = self.chanlist + # Normal mode: filter from configured chanlist + if refs and not refs[0].is_all: + # plugin specific channels configuration + for chan in self.chanlist: # pragma: no cover + if any( + ref.is_physical and ref.value == chan.data.chan + for ref in refs + ): + chanlist.append(chan) + else: # pragma: no cover + pass + for ref in refs: + if ref.is_virtual: + ch = self.channel_get(ref) + if ch and ch.data.is_valid and ch not in chanlist: + chanlist.append(ch) + else: + chanlist = self.chanlist return chanlist def channels_configure( self, - channels: list[int], + channels: Sequence[ChannelRef] | None, div: int | list[int] = 0, writenow: bool = False, ) -> None: @@ -349,9 +633,13 @@ def channels_configure( """ assert self._nxs - logger.info("configure channels = %s divider = %s", str(channels), str(div)) + logger.info( + "configure channels = %s divider = %s", str(channels), str(div) + ) - self._chanlist = self._chanlist_gen(channels) + refs = self._channel_refs(channels, default_all=False) + physical_channels = [x for x in refs if x.is_physical] + self._chanlist = 
self._chanlist_gen(physical_channels) if not self._chanlist: return diff --git a/src/nxscli/plugins/csv.py b/src/nxscli/plugins/csv.py index ce93bf9..d6b990d 100644 --- a/src/nxscli/plugins/csv.py +++ b/src/nxscli/plugins/csv.py @@ -1,16 +1,14 @@ -"""Module containing CSV plugin.""" +"""Module containing CSV plugin.""" # noqa: A005 import csv -from typing import TYPE_CHECKING, Any +from typing import Any + +import numpy as np from nxscli.idata import PluginData, PluginQueueData from nxscli.iplugin import IPluginFile from nxscli.logger import logger -from nxscli.pluginthr import PluginThread - -if TYPE_CHECKING: - from nxslib.nxscope import DNxscopeStream - +from nxscli.pluginthr import PluginThread, StreamBlocks ############################################################################### # Class: PluginCsv @@ -47,13 +45,6 @@ def _csvwriters_open(self) -> list[Any]: return csvwriters - def _sample_row_get(self, sample: "DNxscopeStream") -> tuple[Any, Any]: - # covert to string - if self._meta_string: - return (sample.data, bytes(list(sample.meta)).decode()) - else: - return sample.data, sample.meta - def _init(self) -> None: assert self._phandler # open writers @@ -66,21 +57,39 @@ def _final(self) -> None: logger.info("csv capture DONE") - def _handle_samples( - self, data: list["DNxscopeStream"], pdata: "PluginQueueData", j: int + def _handle_blocks( + self, data: StreamBlocks, pdata: "PluginQueueData", j: int ) -> None: - # store data - for sample in data: + writer = self._csvwriters[j][0] + for block in data: + block_data = block.data + assert isinstance(block_data, np.ndarray) + rows = int(block_data.shape[0]) + if rows == 0: + continue + if not self._nostop: # pragma: no cover - # ignore data if capture done for channel - if self._datalen[j] >= self._samples: + remaining = self._samples - self._datalen[j] + if remaining <= 0: + break + rows = min(rows, remaining) + if rows <= 0: # pragma: no cover break - # write row - 
self._csvwriters[j][0].writerow(self._sample_row_get(sample)) - - # one sample - self._datalen[j] += 1 + data_rows = (tuple(row) for row in block_data[:rows]) + meta_rows: Any + if block.meta is None: + meta_rows = (() for _ in range(rows)) + elif self._meta_string: + meta_rows = ( + bytes(np.asarray(mrow, dtype=np.uint8)).decode() + for mrow in block.meta[:rows] + ) + else: + meta_rows = (tuple(mrow) for mrow in block.meta[:rows]) + + writer.writerows(zip(data_rows, meta_rows)) + self._datalen[j] += rows def start(self, kwargs: Any) -> bool: """Start CSV plugin. diff --git a/src/nxscli/plugins/devinfo.py b/src/nxscli/plugins/devinfo.py index fc58b05..69fcdce 100644 --- a/src/nxscli/plugins/devinfo.py +++ b/src/nxscli/plugins/devinfo.py @@ -38,23 +38,27 @@ def start(self, _: Any) -> bool: assert self._phandler assert self._phandler.dev - dev = self._phandler.dev - ret: Any = {} - ret["cmn"] = {} - ret["cmn"]["chmax"] = dev.data.chmax - ret["cmn"]["flags"] = dev.data.flags - ret["cmn"]["rxpadding"] = dev.data.rxpadding + ret["cmn"] = vars(self._phandler.get_device_capabilities()) + ret["stream"] = vars(self._phandler.get_stream_stats()) + ret["channels_state_applied"] = vars( + self._phandler.get_channels_state(applied=True) + ) + ret["channels_state_buffered"] = vars( + self._phandler.get_channels_state(applied=False) + ) tmp = [] - for chid in range(dev.data.chmax): - chinfo = dev.channel_get(chid) + for chid in range(ret["cmn"]["chmax"]): + chinfo = self._phandler.nxscope.dev_channel_get(chid) assert chinfo chan: Any = {} chan["chan"] = chinfo.data.chan chan["type"] = chinfo.data._type chan["vdim"] = chinfo.data.vdim chan["name"] = chinfo.data.name + chan["enabled"] = chinfo.data.en + chan["divider"] = self._phandler.get_channel_divider(chid) tmp.append(chan) @@ -69,6 +73,12 @@ def result(self) -> str: assert self._return s = "\nDevice common:\n" s += pprint.pformat(self._return["cmn"]) + s += "\nStream stats:\n" + s += pprint.pformat(self._return["stream"]) + s 
+= "\nChannels state (applied):\n" + s += pprint.pformat(self._return["channels_state_applied"]) + s += "\nChannels state (buffered):\n" + s += pprint.pformat(self._return["channels_state_buffered"]) s += "\nDevice channels:\n" s += pprint.pformat(self._return["channels"]) s += "\n" diff --git a/src/nxscli/plugins/none.py b/src/nxscli/plugins/none.py index a1fcd20..c2aa459 100644 --- a/src/nxscli/plugins/none.py +++ b/src/nxscli/plugins/none.py @@ -1,15 +1,13 @@ """Module containing dummy capture plugin.""" -from typing import TYPE_CHECKING, Any +from typing import Any + +import numpy as np from nxscli.idata import PluginData, PluginQueueData from nxscli.iplugin import IPluginNone from nxscli.logger import logger -from nxscli.pluginthr import PluginThread - -if TYPE_CHECKING: - from nxslib.nxscope import DNxscopeStream - +from nxscli.pluginthr import PluginThread, StreamBlocks ############################################################################### # Class: PluginNone @@ -32,12 +30,21 @@ def _init(self) -> None: def _final(self) -> None: logger.info("None DONE") - def _handle_samples( - self, data: list["DNxscopeStream"], pdata: "PluginQueueData", j: int + def _handle_blocks( + self, data: StreamBlocks, pdata: "PluginQueueData", j: int ) -> None: - for _ in data: - # get data len - self._datalen[j] += 1 + for block in data: + block_data = block.data + assert isinstance(block_data, np.ndarray) + rows = int(block_data.shape[0]) + if rows == 0: + continue + if not self._nostop: + remaining = self._samples - self._datalen[j] + if remaining <= 0: + break + rows = min(rows, remaining) + self._datalen[j] += rows def start(self, kwargs: Any) -> bool: """Start none plugin. 
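The plugin diffs above and below all migrate from the per-sample `_handle_samples` callback to the block-based `_handle_blocks` contract, where each queue payload is a list of stream blocks carrying a `(rows, vdim)` ndarray. A minimal standalone sketch of that row-accounting pattern (the `Block` dataclass here is a hypothetical stand-in for nxslib's `DNxscopeStreamBlock`, which additionally carries an optional per-row `meta` array):

```python
# Sketch of the block-handling loop shared by the none/csv/udp plugins:
# count consumed rows, capped at ``samples`` unless ``nostop`` is set.
from dataclasses import dataclass

import numpy as np


@dataclass
class Block:
    """Hypothetical stand-in for DNxscopeStreamBlock."""

    data: np.ndarray  # shape (rows, vdim)


def consume_blocks(
    blocks: list[Block], datalen: int, samples: int, nostop: bool = False
) -> int:
    """Return updated row count after consuming ``blocks``."""
    for block in blocks:
        rows = int(block.data.shape[0])
        if rows == 0:
            continue
        if not nostop:
            remaining = samples - datalen
            if remaining <= 0:
                break
            rows = min(rows, remaining)
        datalen += rows
    return datalen


blocks = [Block(np.zeros((4, 2))), Block(np.zeros((4, 2)))]
print(consume_blocks(blocks, datalen=0, samples=6))  # capped: 4 + 2 rows
```

The real handlers do the same bookkeeping per channel index `j` (`self._datalen[j]`) and write the truncated `block_data[:rows]` slice to their respective sinks.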
diff --git a/src/nxscli/plugins/npmem.py b/src/nxscli/plugins/npmem.py
new file mode 100644
index 0000000..1b63574
--- /dev/null
+++ b/src/nxscli/plugins/npmem.py
@@ -0,0 +1,117 @@
+"""Module containing Numpy memmap plugin."""
+
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+
+from nxscli.idata import PluginData, PluginQueueData
+from nxscli.iplugin import IPluginFile
+from nxscli.logger import logger
+from nxscli.pluginthr import PluginThread, StreamBlocks
+
+if TYPE_CHECKING:
+    from nxslib.nxscope import DNxscopeStream
+
+
+class PluginNpmem(PluginThread, IPluginFile):
+    """Plugin that captures data to Numpy memmap files."""
+
+    def __init__(self) -> None:
+        """Initialize a Numpy memmap capture plugin."""
+        IPluginFile.__init__(self)
+        PluginThread.__init__(self)
+
+        self._data: "PluginData"
+        self._path: str
+        self._npfiles: list[Any] = []
+
+        self._npshape: int
+        self._npdata: list[np.ndarray[Any, Any]] = []
+
+    def _init(self) -> None:
+        assert self._phandler
+
+        self._npdata = []
+
+        for pdata in self._data.qdlist:
+            chanpath = self._path + "_chan" + str(pdata.chan) + ".dat"
+            npf = np.memmap(
+                chanpath,
+                dtype="float32",
+                mode="w+",
+                shape=(pdata.vdim, self._npshape),
+            )
+            self._npfiles.append(npf)
+            self._npdata.append(np.empty((0, pdata.vdim), dtype=np.float64))
+
+    def _final(self) -> None:
+        logger.info("numpy memmap captures DONE")
+
+        # no API to close memmap
+
+    def _flush_ready(self, pdata: "PluginQueueData", j: int) -> None:
+        pending = self._npdata[j]
+        while pending.shape[0] >= self._npshape:
+            chunk = pending[: self._npshape, :]
+            self._npfiles[j][:] = chunk.T.astype(np.float32, copy=False)
+            self._npfiles[j].flush()
+            self._datalen[j] += self._npshape
+            pending = pending[self._npshape :, :]
+        self._npdata[j] = pending
+
+    def _handle_blocks(
+        self, data: StreamBlocks, pdata: "PluginQueueData", j: int
+    ) -> None:
+        chunks: list[np.ndarray[Any, Any]] = [self._npdata[j]]
+        for block in data:
+            block_data = np.asarray(block.data, dtype=np.float64)
+            if int(block_data.shape[0]) == 0:  # pragma: no cover
+                continue
+            chunks.append(block_data)
+        if len(chunks) > 1:  # pragma: no branch
+            self._npdata[j] = np.concatenate(chunks, axis=0)
+            self._flush_ready(pdata, j)
+
+    def _handle_samples(
+        self, data: list["DNxscopeStream"], pdata: "PluginQueueData", j: int
+    ) -> None:
+        if not data:  # pragma: no cover
+            return
+        block = np.empty((len(data), pdata.vdim), dtype=np.float64)
+        for row, sample in enumerate(data):
+            for col in range(pdata.vdim):
+                # TODO: metadata not supported for now
+                block[row, col] = sample.data[col]
+        self._npdata[j] = np.concatenate((self._npdata[j], block), axis=0)
+        self._flush_ready(pdata, j)
+
+    def start(self, kwargs: Any) -> bool:  # pragma: no cover
+        """Start capture plugin.
+
+        :param kwargs: implementation specific arguments
+        """
+        assert self._phandler
+
+        logger.info("start capture %s", str(kwargs))
+
+        self._samples = kwargs["samples"]
+        self._path = kwargs["path"]
+        self._nostop = kwargs["nostop"]
+        self._npshape = kwargs["shape"]
+
+        chanlist = self._phandler.chanlist_plugin(kwargs["channels"])
+        trig = self._phandler.triggers_plugin(chanlist, kwargs["trig"])
+
+        cb = self._phandler.cb_get()
+        self._data = PluginData(chanlist, trig, cb)
+
+        if not self._data.qdlist:  # pragma: no cover
+            return False
+
+        self.thread_start(self._data)
+
+        return True
+
+    def result(self) -> None:
+        """Get npmem plugin result."""
+        return  # pragma: no cover
diff --git a/src/nxscli/plugins/npsave.py b/src/nxscli/plugins/npsave.py
new file mode 100644
index 0000000..a921786
--- /dev/null
+++ b/src/nxscli/plugins/npsave.py
@@ -0,0 +1,98 @@
+"""Module containing Numpy capture plugin."""
+
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+
+from nxscli.idata import PluginData, PluginQueueData
+from nxscli.iplugin import IPluginFile
+from nxscli.logger import logger
+from nxscli.pluginthr import PluginThread, StreamBlocks
+
+if TYPE_CHECKING:
+    from nxslib.nxscope import DNxscopeStream
+
+
+class PluginNpsave(PluginThread, IPluginFile):
+    """Plugin that captures data to a Numpy file."""
+
+    def __init__(self) -> None:
+        """Initialize a Numpy capture plugin."""
+        IPluginFile.__init__(self)
+        PluginThread.__init__(self)
+
+        self._data: "PluginData"
+        self._path: str
+        self._npdata: list[list[np.ndarray[Any, Any]]] = []
+
+    def _init(self) -> None:
+        assert self._phandler
+
+        self._npdata = [[] for _ in range(len(self._data.qdlist))]
+
+    def _final(self) -> None:
+        logger.info("numpy save captures DONE")
+
+        for i, pdata in enumerate(self._data.qdlist):
+            chanpath = self._path + "_chan" + str(pdata.chan) + ".npy"
+            chunks = self._npdata[i]
+            if chunks:
+                npdata = np.concatenate(chunks, axis=0).T
+            else:
+                npdata = np.empty((pdata.vdim, 0), dtype=np.float64)
+            np.save(chanpath, npdata)
+
+    def _handle_blocks(
+        self, data: StreamBlocks, pdata: "PluginQueueData", j: int
+    ) -> None:
+        rows = 0
+        for block in data:
+            block_data = np.asarray(block.data, dtype=np.float64)
+            if int(block_data.shape[0]) == 0:  # pragma: no cover
+                continue
+            self._npdata[j].append(block_data)
+            rows += int(block_data.shape[0])
+        self._datalen[j] += rows
+
+    def _handle_samples(
+        self, data: list["DNxscopeStream"], pdata: "PluginQueueData", j: int
+    ) -> None:
+        if not data:  # pragma: no cover
+            return
+        block = np.empty((len(data), pdata.vdim), dtype=np.float64)
+        for row, sample in enumerate(data):
+            for col in range(pdata.vdim):
+                # TODO: metadata not supported for now
+                block[row, col] = sample.data[col]
+        self._npdata[j].append(block)
+        self._datalen[j] += int(block.shape[0])
+
+    def start(self, kwargs: Any) -> bool:  # pragma: no cover
+        """Start capture plugin.
+ + :param kwargs: implementation specific arguments + """ + assert self._phandler + + logger.info("start capture %s", str(kwargs)) + + self._samples = kwargs["samples"] + self._path = kwargs["path"] + self._nostop = kwargs["nostop"] + + chanlist = self._phandler.chanlist_plugin(kwargs["channels"]) + trig = self._phandler.triggers_plugin(chanlist, kwargs["trig"]) + + cb = self._phandler.cb_get() + self._data = PluginData(chanlist, trig, cb) + + if not self._data.qdlist: # pragma: no cover + return False + + self.thread_start(self._data) + + return True + + def result(self) -> None: + """Get npsave plugin result.""" + return # pragma: no cover diff --git a/src/nxscli/plugins/printer.py b/src/nxscli/plugins/printer.py index 3cc9d32..439bf0e 100644 --- a/src/nxscli/plugins/printer.py +++ b/src/nxscli/plugins/printer.py @@ -2,16 +2,12 @@ import queue from threading import Event -from typing import TYPE_CHECKING, Any +from typing import Any from nxscli.idata import PluginData, PluginQueueData from nxscli.iplugin import IPluginText from nxscli.logger import logger -from nxscli.pluginthr import PluginThread - -if TYPE_CHECKING: - from nxslib.nxscope import DNxscopeStream - +from nxscli.pluginthr import PluginThread, StreamBlocks ############################################################################### # Class: PluginPrinter @@ -39,18 +35,18 @@ def _init(self) -> None: def _final(self) -> None: logger.info("printer DONE") - def _handle_samples( - self, data: list["DNxscopeStream"], pdata: "PluginQueueData", j: int + def _handle_blocks( + self, data: StreamBlocks, pdata: "PluginQueueData", j: int ) -> None: - for sample in data: + for data_t, meta_t in self._block_rows(data, pdata, j): if self._datalen[j] < self._samples: d: dict[str, Any] = dict() d["chan"] = self._data.qdlist[j].chan - d["data"] = sample.data + d["data"] = data_t if self._meta_string: - d["meta"] = bytes(list(sample.meta)).decode() + d["meta"] = bytes(list(meta_t)).decode() else: - d["meta"] = 
sample.meta + d["meta"] = meta_t self._q.put(d) self._datalen[j] += 1 diff --git a/src/nxscli/plugins/udp.py b/src/nxscli/plugins/udp.py index 79bd3cb..f17d122 100644 --- a/src/nxscli/plugins/udp.py +++ b/src/nxscli/plugins/udp.py @@ -2,16 +2,14 @@ import json import socket -from typing import TYPE_CHECKING, Any +from typing import Any + +import numpy as np from nxscli.idata import PluginData, PluginQueueData from nxscli.iplugin import IPluginFile from nxscli.logger import logger -from nxscli.pluginthr import PluginThread - -if TYPE_CHECKING: - from nxslib.nxscope import DNxscopeStream - +from nxscli.pluginthr import PluginThread, StreamBlocks ############################################################################### # Class: PluginUdp @@ -41,39 +39,48 @@ def _final(self) -> None: self._sock.close() logger.info("UDP capture DONE") - def _handle_samples( - self, data: list["DNxscopeStream"], pdata: "PluginQueueData", j: int + def _handle_blocks( + self, data: StreamBlocks, pdata: "PluginQueueData", j: int ) -> None: - # store data - for sample in data: + if pdata.vdim > 1: + keys = [pdata.channame + "_" + str(i) for i in range(pdata.vdim)] + else: + keys = [pdata.channame] + + dumps = json.dumps + sendto = self._sock.sendto + endpoint = (self._address, self._port) + + for block in data: + block_data = block.data + assert isinstance(block_data, np.ndarray) + rows = int(block_data.shape[0]) + if rows == 0: + continue + if not self._nostop: # pragma: no cover - # ignore data if capture done for channel - if self._datalen[j] >= self._samples: + remaining = self._samples - self._datalen[j] + if remaining <= 0: + break + rows = min(rows, remaining) + if rows <= 0: # pragma: no cover break - # get data - temp: Any = {} - temp["timestamp"] = self._datalen[j] - - # TODO: optimise - for i, val in enumerate(sample.data): - if pdata.vdim > 1: - s = pdata.channame + "_" + str(i) - else: - s = pdata.channame - - temp[s] = val + start = self._datalen[j] + for offs, row in 
enumerate(block_data[:rows]): + temp: dict[str, Any] = {"timestamp": start + offs} + for key, val in zip(keys, row): + temp[key] = float(val) - # encode data - if self._data_format == "json": - encoded = json.dumps(temp).encode() - else: # pragma: no cover - raise ValueError("not supported data format") + # encode data + if self._data_format == "json": + encoded = dumps(temp).encode() + else: # pragma: no cover + raise ValueError("not supported data format") - self._sock.sendto(encoded, (self._address, self._port)) + sendto(encoded, endpoint) - # one sample - self._datalen[j] += 1 + self._datalen[j] += rows def start(self, kwargs: Any) -> bool: """Start UDP plugin. diff --git a/src/nxscli/pluginthr.py b/src/nxscli/pluginthr.py index fccede6..d9817c4 100644 --- a/src/nxscli/pluginthr.py +++ b/src/nxscli/pluginthr.py @@ -2,15 +2,17 @@ import threading from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterator +import numpy as np +from nxslib.nxscope import DNxscopeStreamBlock from nxslib.thread import ThreadCommon if TYPE_CHECKING: - from nxslib.nxscope import DNxscopeStream - from nxscli.idata import PluginData, PluginQueueData +StreamBlocks = list[DNxscopeStreamBlock] + ############################################################################### # Class: PluginThread @@ -36,21 +38,34 @@ def _thread_common(self) -> None: assert self._plugindata # get samples for j, pdata in enumerate(self._plugindata.qdlist): - # get data from queue - data = pdata.queue_get(block=True, timeout=1) - if not self._nostop: # pragma: no cover - # ignore data if capture done for channel if self._datalen[j] >= self._samples: continue - # handle samples - self._handle_samples(data, pdata, j) + data = self._queue_data_get(pdata) + + if self._is_block_payload(data): + self._handle_blocks(data, pdata, j) + elif data: + raise RuntimeError( + "nxscli requires numpy block stream payloads; " + "non-block payloads are not supported" + ) 
# break loop if done if self._is_done(self._datalen): self._thread.stop_set() + def _queue_data_get(self, pdata: "PluginQueueData") -> Any: + return pdata.queue_get(block=True, timeout=1) + + def _is_block_payload(self, data: Any) -> bool: + return ( + bool(data) + and isinstance(data, list) + and isinstance(data[0], DNxscopeStreamBlock) + ) + def _init_common(self) -> None: self._init() self._datalen = [0 for _ in range(len(self._plugindata.qdlist))] @@ -93,11 +108,27 @@ def _is_done(self, datalen: list[int]) -> bool: done = False return done + def _block_rows( + self, data: StreamBlocks, pdata: "PluginQueueData", j: int + ) -> Iterator[tuple[tuple[Any, ...], tuple[Any, ...]]]: + for block in data: + arr = block.data + meta_arr = block.meta + for idx in range(int(arr.shape[0])): + row = np.asarray(arr[idx]).reshape(-1) + data_t = tuple(row.tolist()) + if meta_arr is None: + meta_t: tuple[Any, ...] = () + else: + mrow = np.asarray(meta_arr[idx]).reshape(-1) + meta_t = tuple(mrow.tolist()) + yield data_t, meta_t + @abstractmethod - def _handle_samples( - self, data: list["DNxscopeStream"], pdata: "PluginQueueData", j: int + def _handle_blocks( + self, data: StreamBlocks, pdata: "PluginQueueData", j: int ) -> None: - """Handle samples from queue and update datalen.""" + """Handle block payload from queue and update datalen.""" @abstractmethod def _init(self) -> None: diff --git a/src/nxscli/stream_hub.py b/src/nxscli/stream_hub.py new file mode 100644 index 0000000..6771d20 --- /dev/null +++ b/src/nxscli/stream_hub.py @@ -0,0 +1,150 @@ +"""Shared physical stream fan-out provider for all plugins.""" + +import queue +from threading import Lock +from time import sleep +from typing import TYPE_CHECKING + +from nxslib.thread import ThreadCommon + +if TYPE_CHECKING: + from nxslib.dev import DeviceChannel + from nxslib.nxscope import DNxscopeStreamBlock, NxscopeHandler + + from nxscli.channelref import ChannelRef + + +class SharedStreamProvider: + """Provide shared physical 
stream queues to many consumers.""" + + def __init__(self) -> None: + """Initialize provider state.""" + self._lock = Lock() + self._nxscope: "NxscopeHandler | None" = None + self._started = False + self._source_subs: dict[ + int, queue.Queue[list["DNxscopeStreamBlock"]] + ] = {} + self._subscribers: dict[ + int, list[queue.Queue[list["DNxscopeStreamBlock"]]] + ] = {} + self._queue_to_channel: dict[int, int] = {} + self._thread = ThreadCommon(self._thread_common, name="streamhub") + self._poll_idx = 0 + + def on_connect(self, nxscope: "NxscopeHandler") -> None: + """Attach provider to active Nxscope handler.""" + with self._lock: + self._nxscope = nxscope + + def on_disconnect(self) -> None: + """Detach provider from Nxscope handler.""" + self.on_stream_stop() + with self._lock: + self._nxscope = None + self._subscribers = {} + self._queue_to_channel = {} + + def on_stream_start(self) -> None: + """Start fan-out thread and source subscriptions.""" + with self._lock: + if self._started: + return + if self._nxscope is None: + return + self._started = True + for chid in self._subscribers: + self._ensure_source_sub_locked(chid) + self._thread.thread_start() + + def on_stream_stop(self) -> None: + """Stop fan-out thread and upstream subscriptions.""" + with self._lock: + if not self._started: + return + self._started = False + self._thread.thread_stop() + with self._lock: + if self._nxscope is not None: + for subq in self._source_subs.values(): + self._nxscope.stream_unsub(subq) + self._source_subs = {} + + def channel_get(self, channel: "ChannelRef") -> "DeviceChannel | None": + """Return physical channel metadata from Nxscope.""" + with self._lock: + if self._nxscope is None: + return None + if not channel.is_physical: + return None + return self._nxscope.dev_channel_get(channel.physical_id()) + + def channel_list(self) -> tuple["DeviceChannel", ...]: + """No extra virtual channels are provided by this hub.""" + return () + + def stream_sub( + self, channel: 
"ChannelRef" + ) -> "queue.Queue[list[DNxscopeStreamBlock]] | None": + """Subscribe to physical channel fan-out queue.""" + if not channel.is_physical: + return None + chid = channel.physical_id() + with self._lock: + subq: queue.Queue[list["DNxscopeStreamBlock"]] = queue.Queue() + self._subscribers.setdefault(chid, []).append(subq) + self._queue_to_channel[id(subq)] = chid + if self._started: + self._ensure_source_sub_locked(chid) + return subq + + def stream_unsub( + self, subq: "queue.Queue[list[DNxscopeStreamBlock]]" + ) -> bool: + """Unsubscribe consumer queue from fan-out.""" + with self._lock: + qid = id(subq) + chid = self._queue_to_channel.get(qid) + if chid is None: + return False + self._queue_to_channel.pop(qid, None) + subs = self._subscribers.get(chid, []) + if subq in subs: + subs.remove(subq) + if not subs: + self._subscribers.pop(chid, None) + if self._nxscope is not None and chid in self._source_subs: + self._nxscope.stream_unsub(self._source_subs[chid]) + self._source_subs.pop(chid, None) + return True + + def _ensure_source_sub_locked(self, chid: int) -> None: + if self._nxscope is None: + return + if chid in self._source_subs: + return + self._source_subs[chid] = self._nxscope.stream_sub(chid) + + def _thread_common(self) -> None: + with self._lock: + if not self._started: + return + if not self._source_subs: + sleep(0.005) + return + items = list(self._source_subs.items()) + idx = self._poll_idx % len(items) + self._poll_idx += 1 + chid, srcq = items[idx] + dstq = list(self._subscribers.get(chid, [])) + + try: + blocks = srcq.get(block=True, timeout=0.02) + except queue.Empty: + return + + if not blocks: + return + + for subq in dstq: + subq.put(blocks) diff --git a/src/nxscli/transforms/models.py b/src/nxscli/transforms/models.py new file mode 100644 index 0000000..b2488fc --- /dev/null +++ b/src/nxscli/transforms/models.py @@ -0,0 +1,51 @@ +"""Data models used by shared transform operators.""" + +from dataclasses import dataclass +from 
typing import Any + + +@dataclass(frozen=True) +class WindowConfig: + """Windowed transform configuration.""" + + window: int + hop: int + + +@dataclass +class WindowCursor: + """Incremental window-processing cursor.""" + + last_count: int = 0 + + +@dataclass(frozen=True) +class FftResult: + """FFT result model.""" + + freq: Any + amplitude: Any + + +@dataclass(frozen=True) +class HistogramResult: + """Histogram result model.""" + + counts: Any + edges: Any + + +@dataclass(frozen=True) +class XyResult: + """XY relation result model.""" + + x: Any + y: Any + + +@dataclass(frozen=True) +class PolarResult: + """Polar relation result model.""" + + theta: Any + radius: Any diff --git a/src/nxscli/transforms/operators_sample.py b/src/nxscli/transforms/operators_sample.py new file mode 100644 index 0000000..e5105a7 --- /dev/null +++ b/src/nxscli/transforms/operators_sample.py @@ -0,0 +1,27 @@ +"""Sample-domain transform operators.""" + +from typing import Callable, Sequence + +import numpy as np + + +def apply_scale_offset( + samples: Sequence[float], scale: float = 1.0, offset: float = 0.0 +) -> np.ndarray: + """Apply affine transform to scalar samples.""" + arr = np.asarray(samples, dtype=np.float64) + return arr * float(scale) + float(offset) + + +def binary_op( + left: Sequence[float], + right: Sequence[float], + op: Callable[[np.ndarray, np.ndarray], np.ndarray], +) -> np.ndarray: + """Apply binary operation to aligned sample sequences.""" + la = np.asarray(left, dtype=np.float64) + ra = np.asarray(right, dtype=np.float64) + size = min(int(la.size), int(ra.size)) + if size <= 0: + return np.asarray([], dtype=np.float64) + return op(la[:size], ra[:size]) diff --git a/src/nxscli/transforms/operators_window.py b/src/nxscli/transforms/operators_window.py new file mode 100644 index 0000000..c7bf738 --- /dev/null +++ b/src/nxscli/transforms/operators_window.py @@ -0,0 +1,228 @@ +"""Window-domain transform operators.""" + +from typing import Sequence + +import numpy as np + 
+from nxscli.transforms.models import ( + FftResult, + HistogramResult, + PolarResult, + WindowCursor, + XyResult, +) +from nxscli.transforms.window_engine import ( + latest_window, + normalize_window_config, + should_recompute, +) + + +def _weights(window_fn: str, size: int) -> np.ndarray: + name = window_fn.lower() + if name == "hann": + return np.hanning(size) + if name == "hamming": + return np.hamming(size) + if name == "blackman": + return np.blackman(size) + return np.ones(size, dtype=np.float64) + + +def fft_spectrum( + samples: Sequence[float] | np.ndarray, + *, + sample_period: float = 1.0, + window_fn: str = "hann", +) -> FftResult: + """Compute one-sided FFT magnitude spectrum.""" + arr = np.asarray(samples, dtype=np.float64) + if arr.size < 2: + return FftResult( + freq=np.asarray([], dtype=np.float64), + amplitude=np.asarray([], dtype=np.float64), + ) + weighted = arr * _weights(window_fn, int(arr.size)) + freq = np.fft.rfftfreq(int(weighted.size), d=float(sample_period)) + amp = np.abs(np.fft.rfft(weighted)) + return FftResult( + freq=freq.astype(np.float64), + amplitude=amp.astype(np.float64), + ) + + +def histogram_counts( + samples: Sequence[float] | np.ndarray, + *, + bins: int, + range_mode: str = "auto", + value_range: tuple[float, float] | None = None, +) -> HistogramResult: + """Compute histogram bin counts.""" + arr = np.asarray(samples, dtype=np.float64) + if arr.size == 0: + return HistogramResult( + counts=np.asarray([], dtype=np.float64), + edges=np.asarray([], dtype=np.float64), + ) + + hist_range = None + if range_mode == "fixed": + if value_range is None: + raise ValueError( + "value_range must be provided for fixed range_mode" + ) + hist_range = (float(value_range[0]), float(value_range[1])) + + counts, edges = np.histogram( + arr, + bins=max(1, int(bins)), + range=hist_range, + ) + return HistogramResult( + counts=counts.astype(np.float64), + edges=edges.astype(np.float64), + ) + + +def xy_relation( + x_samples: Sequence[float], + 
y_samples: Sequence[float], + *, + window: int, + align_policy: str = "truncate", +) -> XyResult: + """Build XY relation from two sample series.""" + xa = np.asarray(x_samples, dtype=np.float64) + ya = np.asarray(y_samples, dtype=np.float64) + if align_policy != "truncate": + raise ValueError("unsupported align_policy") + size = min(int(xa.size), int(ya.size), max(2, int(window))) + if size <= 0: + return XyResult( + x=np.asarray([], dtype=np.float64), + y=np.asarray([], dtype=np.float64), + ) + return XyResult(x=xa[-size:], y=ya[-size:]) + + +def polar_relation( + x_samples: Sequence[float], + y_samples: Sequence[float], + *, + window: int, + align_policy: str = "truncate", +) -> PolarResult: + """Build polar relation from two sample series.""" + rel = xy_relation( + x_samples, + y_samples, + window=window, + align_policy=align_policy, + ) + if int(rel.x.size) == 0 or int(rel.y.size) == 0: + return PolarResult( + theta=np.asarray([], dtype=np.float64), + radius=np.asarray([], dtype=np.float64), + ) + theta = np.arctan2(rel.y, rel.x) + radius = np.hypot(rel.x, rel.y) + return PolarResult( + theta=theta.astype(np.float64), + radius=radius.astype(np.float64), + ) + + +def windowed_fft( + series: Sequence[float], + *, + window: int, + hop: int | None, + cursor: WindowCursor, + window_fn: str = "hann", + total_count: int | None = None, +) -> FftResult | None: + """Compute FFT only when hop criteria is satisfied.""" + cfg = normalize_window_config(window, hop) + current = len(series) if total_count is None else int(total_count) + if not should_recompute(current, cfg, cursor): + return None + arr = latest_window(series, cfg) + return fft_spectrum(arr, window_fn=window_fn) + + +def windowed_histogram( + series: Sequence[float], + *, + window: int, + hop: int | None, + bins: int, + range_mode: str, + cursor: WindowCursor, + total_count: int | None = None, + value_range: tuple[float, float] | None = None, +) -> HistogramResult | None: + """Compute histogram only when hop 
criteria is satisfied.""" + cfg = normalize_window_config(window, hop) + current = len(series) if total_count is None else int(total_count) + if not should_recompute(current, cfg, cursor): + return None + arr = latest_window(series, cfg) + return histogram_counts( + arr, + bins=bins, + range_mode=range_mode, + value_range=value_range, + ) + + +def windowed_xy( + x_series: Sequence[float], + y_series: Sequence[float], + *, + window: int, + hop: int | None, + align_policy: str, + cursor: WindowCursor, + total_count: int | None = None, +) -> XyResult | None: + """Compute XY relation only when hop criteria is satisfied.""" + cfg = normalize_window_config(window, hop) + if total_count is None: + total = min(len(x_series), len(y_series)) + else: + total = int(total_count) + if not should_recompute(total, cfg, cursor): + return None + return xy_relation( + x_series, + y_series, + window=cfg.window, + align_policy=align_policy, + ) + + +def windowed_polar( + x_series: Sequence[float], + y_series: Sequence[float], + *, + window: int, + hop: int | None, + align_policy: str, + cursor: WindowCursor, + total_count: int | None = None, +) -> PolarResult | None: + """Compute polar relation only when hop criteria is satisfied.""" + cfg = normalize_window_config(window, hop) + if total_count is None: + total = min(len(x_series), len(y_series)) + else: + total = int(total_count) + if not should_recompute(total, cfg, cursor): + return None + return polar_relation( + x_series, + y_series, + window=cfg.window, + align_policy=align_policy, + ) diff --git a/src/nxscli/transforms/pipeline.py b/src/nxscli/transforms/pipeline.py new file mode 100644 index 0000000..ec282b5 --- /dev/null +++ b/src/nxscli/transforms/pipeline.py @@ -0,0 +1,185 @@ +"""Shared transform pipeline for fan-out processing over one sample stream.""" + +from collections import deque +from dataclasses import dataclass +from typing import Callable, Mapping, Protocol, Sequence + +import numpy as np + +from 
nxscli.transforms.window_engine import ( + latest_window, + normalize_window_config, +) + + +class TransformProcessor(Protocol): + """Protocol for processors used by ``TransformPipeline``.""" + + @property + def name(self) -> str: + """Processor output name.""" + + def process(self, store: "SampleStore") -> object | None: + """Return transformed output or ``None`` when not ready.""" + + +class SampleStore: + """In-memory sample storage shared by all processors.""" + + def __init__(self, max_points: int | None = None) -> None: + """Initialize store. + + :param max_points: max points kept per channel, unbounded if ``None``. + """ + self._max_points = max_points + self._series: dict[str, deque[float]] = {} + self._count: dict[str, int] = {} + + def ingest(self, batch: Mapping[str, Sequence[float]]) -> None: + """Append a sample batch into the store.""" + for channel, values in batch.items(): + data = self._series.get(channel) + if data is None: + data = deque(maxlen=self._max_points) + self._series[channel] = data + self._count[channel] = 0 + for value in values: + data.append(float(value)) + self._count[channel] += 1 + + def count(self, channel: str) -> int: + """Return number of ingested samples for channel.""" + return self._count.get(channel, 0) + + def series(self, channel: str) -> np.ndarray: + """Return current channel series as float64 array.""" + values = self._series.get(channel) + if values is None: + return np.asarray([], dtype=np.float64) + return np.asarray(list(values), dtype=np.float64) + + +class TransformPipeline: + """Shared sample pipeline dispatching to many processors.""" + + def __init__(self, *, max_points: int | None = None) -> None: + """Initialize empty pipeline.""" + self._store = SampleStore(max_points=max_points) + self._processors: list[TransformProcessor] = [] + + @property + def store(self) -> SampleStore: + """Get shared sample store.""" + return self._store + + def register(self, processor: TransformProcessor) -> None: + """Register 
one processor.""" + self._processors.append(processor) + + def ingest( + self, batch: Mapping[str, Sequence[float]] + ) -> dict[str, object]: + """Ingest data and return outputs from ready processors.""" + self._store.ingest(batch) + ret: dict[str, object] = {} + for processor in self._processors: + value = processor.process(self._store) + if value is not None: + ret[processor.name] = value + return ret + + +@dataclass +class HopGate: + """Per-processor hop gate used by windowed processors.""" + + hop: int + last_count: int = 0 + + def ready(self, total_count: int) -> bool: + """Return ``True`` when processor should run for current count.""" + if total_count <= 0: + return False + if self.last_count == 0: + self.last_count = total_count + return True + if total_count - self.last_count >= self.hop: + self.last_count = total_count + return True + return False + + +class WindowUnaryProcessor: + """Windowed processor based on one source channel.""" + + def __init__( + self, + *, + name: str, + channel: str, + window: int, + hop: int | None, + fn: Callable[[np.ndarray], object], + ) -> None: + """Initialize unary processor.""" + self._name = name + self._channel = channel + self._cfg = normalize_window_config(window, hop) + self._gate = HopGate(hop=self._cfg.hop) + self._fn = fn + + @property + def name(self) -> str: + """Processor output name.""" + return self._name + + def process(self, store: SampleStore) -> object | None: + """Process latest channel window when hop gate allows it.""" + total = store.count(self._channel) + if not self._gate.ready(total): + return None + series = store.series(self._channel) + window = latest_window(series.tolist(), self._cfg) + return self._fn(window) + + +class WindowBinaryProcessor: + """Windowed processor based on two source channels.""" + + def __init__( + self, + *, + name: str, + left_channel: str, + right_channel: str, + window: int, + hop: int | None, + fn: Callable[[np.ndarray, np.ndarray], object], + ) -> None: + 
"""Initialize binary processor.""" + self._name = name + self._left_channel = left_channel + self._right_channel = right_channel + self._cfg = normalize_window_config(window, hop) + self._gate = HopGate(hop=self._cfg.hop) + self._fn = fn + + @property + def name(self) -> str: + """Processor output name.""" + return self._name + + def process(self, store: SampleStore) -> object | None: + """Process latest pair of windows when hop gate allows it.""" + total = min( + store.count(self._left_channel), + store.count(self._right_channel), + ) + if not self._gate.ready(total): + return None + left_series = store.series(self._left_channel) + right_series = store.series(self._right_channel) + left = latest_window(left_series.tolist(), self._cfg) + right = latest_window(right_series.tolist(), self._cfg) + size = min(int(left.size), int(right.size)) + return self._fn(left[-size:], right[-size:]) diff --git a/src/nxscli/transforms/window_engine.py b/src/nxscli/transforms/window_engine.py new file mode 100644 index 0000000..1faebfc --- /dev/null +++ b/src/nxscli/transforms/window_engine.py @@ -0,0 +1,42 @@ +"""Shared window extraction and incremental scheduling logic.""" + +from typing import Sequence + +import numpy as np + +from nxscli.transforms.models import WindowConfig, WindowCursor + + +def normalize_window_config( + window: int, hop: int | None = None +) -> WindowConfig: + """Normalize window/hop values into a valid configuration.""" + win = max(2, int(window)) + if hop is None or int(hop) <= 0: + hop_n = max(1, win // 4) + else: + hop_n = max(1, int(hop)) + return WindowConfig(window=win, hop=hop_n) + + +def should_recompute( + total_count: int, cfg: WindowConfig, cursor: WindowCursor +) -> bool: + """Return True when enough new samples arrived for next window step.""" + if total_count <= 0: + return False + if cursor.last_count == 0: + cursor.last_count = total_count + return True + if total_count - cursor.last_count >= cfg.hop: + cursor.last_count = total_count + 
return True + return False + + +def latest_window(series: Sequence[float], cfg: WindowConfig) -> np.ndarray: + """Get latest signal window from an in-memory series.""" + arr = np.asarray(series, dtype=np.float64) + if arr.size <= cfg.window: + return arr + return arr[-cfg.window :] diff --git a/src/nxscli/trigger.py b/src/nxscli/trigger.py index 9b0488a..5e52b64 100644 --- a/src/nxscli/trigger.py +++ b/src/nxscli/trigger.py @@ -1,17 +1,16 @@ """Module containing Nxscli stream data trigger logic.""" -import itertools import weakref from copy import deepcopy from dataclasses import dataclass from enum import Enum from threading import Lock -from typing import TYPE_CHECKING, Any, Iterator +from typing import Any -from nxscli.logger import logger +import numpy as np +from nxslib.nxscope import DNxscopeStreamBlock -if TYPE_CHECKING: - from nxslib.nxscope import DNxscopeStream +from nxscli.logger import logger ############################################################################### # Enum: ETriggerType @@ -132,13 +131,13 @@ def __new__(cls, chan: int, config: DTriggerConfig) -> "TriggerHandler": def __init__(self, chan: int, config: DTriggerConfig) -> None: """Initialize a stream data trigger handler.""" self._config = config - self._cache: list["DNxscopeStream"] = [] + self._cache: list[Any] = [] self._chan: int = chan self._trigger: DTriggerState = DTriggerState(False, 0) self._triger_done = False # trigger source channel reference - self._src: "TriggerHandler" | None = None + self._src: "TriggerHandler" | None = None # noqa: TC010 # connected cross channels self._cross: list["TriggerHandler"] = [] self._src_configured = False @@ -194,57 +193,71 @@ def _pending_crosschan(self) -> None: for inst in tmp: TriggerHandler._wait_for_src.remove(inst) - def _pairwise( - self, iterable: list["DNxscopeStream"] - ) -> Iterator[tuple["DNxscopeStream", "DNxscopeStream"]]: - (a, b) = itertools.tee(iterable) - next(b, None) - return zip(a, b) + @staticmethod + def 
_is_block_payload(data: list[Any]) -> bool: + return bool(data) and isinstance(data[0], DNxscopeStreamBlock) - def _alwaysoff(self, _: list["DNxscopeStream"]) -> DTriggerState: + def _alwaysoff(self, _: list[Any]) -> DTriggerState: # reset cache self._cache = [] return DTriggerState(False, 0) - def _alwayson(self, _: list["DNxscopeStream"]) -> DTriggerState: + def _alwayson(self, _: list[Any]) -> DTriggerState: return DTriggerState(True, 0) - def _edgerising( - self, combined: list["DNxscopeStream"], vect: int, level: float - ) -> DTriggerState: - ret = False - idx = 0 - for a, b in self._pairwise(combined): - if a.data[vect] <= level < b.data[vect]: - ret = True - idx = idx - break - idx += 1 + def _combined_vector_np( + self, combined: list[Any], vect: int + ) -> np.ndarray[Any, Any]: + if not combined: + return np.empty((0,), dtype=np.float64) + + if self._is_block_payload(combined): + parts: list[np.ndarray[Any, Any]] = [] + for block in combined: + parts.append( + np.asarray(block.data[:, vect], dtype=np.float64).reshape( + -1 + ) + ) + if not parts: # pragma: no cover + return np.empty((0,), dtype=np.float64) + if len(parts) == 1: + return parts[0] + return np.concatenate(parts) + + return np.fromiter( + (float(sample.data[vect]) for sample in combined), + dtype=np.float64, + count=len(combined), + ) - if not ret: - idx = 0 + def _combined_vector(self, combined: list[Any], vect: int) -> list[float]: + return [float(v) for v in self._combined_vector_np(combined, vect)] - return DTriggerState(ret, idx) + def _edgerising( + self, combined: list[Any], vect: int, level: float + ) -> DTriggerState: + vec = self._combined_vector_np(combined, vect) + if vec.size < 2: + return DTriggerState(False, 0) + hits = np.flatnonzero((vec[:-1] <= level) & (vec[1:] > level)) + if hits.size > 0: + return DTriggerState(True, int(hits[0])) + return DTriggerState(False, 0) def _edgefalling( - self, combined: list["DNxscopeStream"], vect: int, level: float + self, combined: 
list[Any], vect: int, level: float ) -> DTriggerState: - ret = False - idx = 0 - for a, b in self._pairwise(combined): - if a.data[vect] >= level > b.data[vect]: - ret = True - idx = idx - break - idx += 1 - - if not ret: - idx = 0 - - return DTriggerState(ret, idx) + vec = self._combined_vector_np(combined, vect) + if vec.size < 2: + return DTriggerState(False, 0) + hits = np.flatnonzero((vec[:-1] >= level) & (vec[1:] < level)) + if hits.size > 0: + return DTriggerState(True, int(hits[0])) + return DTriggerState(False, 0) def _is_self_trigger( - self, combined: list["DNxscopeStream"], config: DTriggerConfig + self, combined: list[Any], config: DTriggerConfig ) -> DTriggerState: if config.ttype is ETriggerType.ALWAYS_OFF: return self._alwaysoff(combined) @@ -259,7 +272,7 @@ def _is_self_trigger( else: raise AssertionError - def _is_triggered(self, combined: list["DNxscopeStream"]) -> DTriggerState: + def _is_triggered(self, combined: list[Any]) -> DTriggerState: if self._trigger.state: # make sure that idx is 0 return DTriggerState(True, 0) @@ -274,7 +287,7 @@ def _is_triggered(self, combined: list["DNxscopeStream"]) -> DTriggerState: # self-triggered return self._is_self_trigger(combined, self._config) - def _cross_channel_handle(self, combined: list["DNxscopeStream"]) -> None: + def _cross_channel_handle(self, combined: list[Any]) -> None: for cross in self._cross: if cross.cross_trigger.state is True: continue @@ -354,9 +367,42 @@ def unsubscribe_cross(self, inst: "TriggerHandler") -> None: if cross is inst: # pragma: no cover self._cross.pop(i) - def data_triggered( - self, data: list["DNxscopeStream"] - ) -> list["DNxscopeStream"]: + def _slice_from(self, combined: list[Any], start: int) -> list[Any]: + if start <= 0: + return combined + if not combined: + return [] + if not self._is_block_payload(combined): + return combined[start:] + + ret: list[Any] = [] + offset = start + for block in combined: + rows = int(block.data.shape[0]) + if offset >= rows: + 
offset -= rows + continue + if offset == 0: + ret.append(block) + else: + data = block.data[offset:, :] + meta = None if block.meta is None else block.meta[offset:, :] + ret.append(DNxscopeStreamBlock(data=data, meta=meta)) + offset = 0 + return ret + + def _cache_tail(self, combined: list[Any], hoffset: int) -> list[Any]: + if hoffset <= 0: + return combined + if not combined: + return [] + if not self._is_block_payload(combined): + return combined[-hoffset:] + total = sum(int(block.data.shape[0]) for block in combined) + start = max(total - hoffset, 0) + return self._slice_from(combined, start) + + def data_triggered(self, data: list[Any]) -> list[Any]: """Get triggered data. :param data: stream data @@ -372,8 +418,17 @@ def data_triggered( # not triggered yet ret = [] # update cache - clen = len(self._cache) - self._cache = combined[clen - self._config.hoffset :] + if self._is_block_payload(combined): + if self._config.hoffset <= 0: + # keep only current block batch when no history is needed + self._cache = data + else: + self._cache = self._cache_tail( + combined, self._config.hoffset + ) + else: + clen = len(self._cache) + self._cache = combined[clen - self._config.hoffset :] else: # one time hoffset for trigger if not self._triger_done: @@ -383,7 +438,7 @@ def data_triggered( hoffset = 0 # return data with a configured horisontal offset - ret = combined[self._trigger.idx - hoffset :] + ret = self._slice_from(combined, self._trigger.idx - hoffset) # reset cache self._cache = [] diff --git a/src/nxscli/virtual/errors.py b/src/nxscli/virtual/errors.py new file mode 100644 index 0000000..a1b0d31 --- /dev/null +++ b/src/nxscli/virtual/errors.py @@ -0,0 +1,5 @@ +"""Error types for virtual channels.""" + + +class VirtualChannelError(ValueError): + """Error raised for invalid virtual channel configuration/runtime.""" diff --git a/src/nxscli/virtual/manager.py b/src/nxscli/virtual/manager.py new file mode 100644 index 0000000..5274935 --- /dev/null +++ 
b/src/nxscli/virtual/manager.py @@ -0,0 +1,227 @@ +"""Virtual channel manager and execution graph.""" + +from dataclasses import dataclass +from typing import Callable + +from nxslib.dev import DeviceChannel + +from nxscli.virtual.errors import VirtualChannelError +from nxscli.virtual.models import ( + ChannelSpec, + SampleValue, + VirtualChannelSpec, +) +from nxscli.virtual.operators import ( + VirtualOperator, + default_operator_registry, +) + + +@dataclass +class _CompiledVirtualChannel: + """Internal node state for one virtual channel declaration.""" + + spec: VirtualChannelSpec + outputs: tuple[ChannelSpec, ...] + output_ids: tuple[str, ...] + operator: VirtualOperator + + +class VirtualChannelManager: + """Manage virtual channel declarations and execution graph.""" + + def __init__( + self, + operators: dict[str, Callable[[], VirtualOperator]] | None = None, + ) -> None: + """Initialize manager with operator registry.""" + self._operators = operators or default_operator_registry() + self._channels: dict[str, ChannelSpec] = {} + self._compiled: dict[str, _CompiledVirtualChannel] = {} + self._output_owner: dict[str, str] = {} + self._last_values: dict[str, SampleValue] = {} + self._order: list[str] = [] + + def add_physical_channel(self, spec: ChannelSpec | DeviceChannel) -> None: + """Add physical channel metadata.""" + if isinstance(spec, DeviceChannel): + spec = ChannelSpec.from_device_channel(spec) + if spec.channel_id in self._channels: + raise VirtualChannelError( + f"Channel already exists: {spec.channel_id}" + ) + self._channels[spec.channel_id] = spec + + def add_virtual_channel( + self, spec: VirtualChannelSpec + ) -> tuple[ChannelSpec, ...]: + """Register one virtual channel and rebuild execution order.""" + if spec.channel_id in self._compiled: + raise VirtualChannelError( + f"Virtual channel already exists: {spec.channel_id}" + ) + if spec.operator not in self._operators: + raise VirtualChannelError(f"Unknown operator: {spec.operator}") + if 
len(spec.inputs) == 0: + raise VirtualChannelError( + "Virtual channel requires at least one input" + ) + + input_specs: list[ChannelSpec] = [] + for input_id in spec.inputs: + input_spec = self.channel_spec(input_id) + if input_spec is None: + raise VirtualChannelError(f"Unknown input channel: {input_id}") + input_specs.append(input_spec) + + operator = self._operators[spec.operator]() + operator.configure(spec, tuple(input_specs)) + outputs = operator.describe_outputs(spec) + if len(outputs) == 0: + raise VirtualChannelError( + "Operator must provide at least one output" + ) + + output_ids = tuple(out.channel_id for out in outputs) + for output in outputs: + if output.channel_id in self._channels: + raise VirtualChannelError( + f"Output channel already exists: {output.channel_id}" + ) + + self._compiled[spec.channel_id] = _CompiledVirtualChannel( + spec=spec, + outputs=outputs, + output_ids=output_ids, + operator=operator, + ) + for output in outputs: + self._channels[output.channel_id] = output + self._output_owner[output.channel_id] = spec.channel_id + self._rebuild_order() + return outputs + + def channel_spec(self, channel_id: str) -> ChannelSpec | None: + """Return channel metadata by ID.""" + return self._channels.get(channel_id) + + def channel_specs(self) -> tuple[ChannelSpec, ...]: + """Return all physical and virtual channel specs.""" + return tuple(self._channels.values()) + + def physical_channel_ids(self) -> tuple[str, ...]: + """Return IDs of physical channels only.""" + return tuple( + channel_id + for channel_id in self._channels + if channel_id not in self._output_owner + ) + + def required_physical_channel_ids(self) -> tuple[str, ...]: + """Return physical channel IDs required by declared virtual graph.""" + required: dict[str, None] = {} + for node_id in self._order: + compiled = self._compiled[node_id] + for input_id in compiled.spec.inputs: + if input_id not in self._output_owner: + required[input_id] = None + return tuple(required.keys()) + 
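The execution model implemented by `VirtualChannelManager` (physical samples arrive as a map, virtual nodes run in dependency order, and each node's outputs are merged back into the sample map) can be illustrated with a small standalone sketch. Note this is a simplified mirror of the semantics only: the names `Node` and `run_tick` are illustrative and are not part of the nxscli API, and the real manager additionally validates vdims, detects cycles, and dispatches through operator objects.

```python
# Standalone sketch of the virtual-channel tick semantics: run nodes in
# (assumed) topological order, feeding each node the already-computed
# values of its inputs.  Illustrative only -- not the nxscli API.
from dataclasses import dataclass
from typing import Callable

SampleValue = tuple[float, ...]


@dataclass(frozen=True)
class Node:
    out_id: str
    inputs: tuple[str, ...]
    fn: Callable[..., SampleValue]


def run_tick(
    nodes: list[Node], physical: dict[str, SampleValue]
) -> dict[str, SampleValue]:
    # Start from the physical values and fold in each virtual output.
    result = dict(physical)
    for node in nodes:  # nodes assumed already in dependency order
        args = [result[input_id] for input_id in node.inputs]
        result[node.out_id] = node.fn(*args)
    return result


# v0 behaves like scale_offset (out = in * 2.0 + 1.0),
# v1 behaves like math_binary "add" on v0 and channel "1".
nodes = [
    Node("v0", ("0",), lambda a: tuple(x * 2.0 + 1.0 for x in a)),
    Node("v1", ("v0", "1"), lambda a, b: tuple(x + y for x, y in zip(a, b))),
]
tick = run_tick(nodes, {"0": (1.0, 2.0), "1": (10.0, 20.0)})
print(tick["v0"])  # (3.0, 5.0)
print(tick["v1"])  # (13.0, 25.0)
```

Because `v1` consumes `v0`, this also shows why the manager rebuilds a topological order on every `add_virtual_channel`: a virtual channel may depend on another virtual channel's output, not only on physical channels.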
+ def process_sample( + self, physical_values: dict[str, SampleValue] + ) -> dict[str, SampleValue]: + """Process one sample tick and return full channel map.""" + result: dict[str, SampleValue] = dict(physical_values) + + for channel_id in self._channels: + if channel_id not in self._output_owner: + if channel_id not in result: + raise VirtualChannelError( + f"Missing physical channel value: {channel_id}" + ) + + for node_id in self._order: + compiled = self._compiled[node_id] + if not compiled.spec.enabled: + continue + inputs: list[SampleValue] = [] + for input_id in compiled.spec.inputs: + if input_id not in result: + raise VirtualChannelError( + f"Input value not available: {input_id}" + ) + inputs.append(result[input_id]) + outputs = compiled.operator.process(tuple(inputs)) + if len(outputs) != len(compiled.output_ids): + raise VirtualChannelError("Operator returned invalid outputs") + for output_id, value in zip(compiled.output_ids, outputs): + result[output_id] = value + return result + + def process_update( + self, channel_id: str, value: SampleValue + ) -> dict[str, SampleValue]: + """Process one physical channel update and return changed virtuals.""" + spec = self._channels.get(channel_id) + if spec is None or channel_id in self._output_owner: + raise VirtualChannelError( + f"Unknown physical channel: {channel_id}" + ) + + if len(value) != spec.vdim: + raise VirtualChannelError( + f"Invalid sample vdim for {channel_id}: " + f"expected {spec.vdim}, got {len(value)}" + ) + + self._last_values[channel_id] = value + changed: dict[str, SampleValue] = {} + for node_id in self._order: + compiled = self._compiled[node_id] + if not compiled.spec.enabled: + continue + if not all( + inp in self._last_values for inp in compiled.spec.inputs + ): + continue + inputs = tuple( + self._last_values[inp] for inp in compiled.spec.inputs + ) + outputs = compiled.operator.process(inputs) + if len(outputs) != len(compiled.output_ids): + raise VirtualChannelError("Operator 
returned invalid outputs") + for output_id, out_value in zip(compiled.output_ids, outputs): + self._last_values[output_id] = out_value + changed[output_id] = out_value + return changed + + def reset(self) -> None: + """Reset all virtual operators.""" + self._last_values.clear() + for compiled in self._compiled.values(): + compiled.operator.reset() + + def _rebuild_order(self) -> None: + """Rebuild topological execution order with cycle detection.""" + visiting: set[str] = set() + visited: set[str] = set() + order: list[str] = [] + + def visit(node: str) -> None: + if node in visited: + return + if node in visiting: + raise VirtualChannelError("Virtual channel graph has a cycle") + visiting.add(node) + compiled = self._compiled[node] + for input_id in compiled.spec.inputs: + owner = self._output_owner.get(input_id) + if owner is not None: + visit(owner) + visiting.remove(node) + visited.add(node) + order.append(node) + + for node in self._compiled: + visit(node) + self._order = order diff --git a/src/nxscli/virtual/models.py b/src/nxscli/virtual/models.py new file mode 100644 index 0000000..e302233 --- /dev/null +++ b/src/nxscli/virtual/models.py @@ -0,0 +1,96 @@ +"""Data models for virtual channels.""" + +from dataclasses import dataclass, field + +from nxslib.dev import DeviceChannel, EDeviceChannelType + + +@dataclass(frozen=True) +class ChannelSpec: + """Channel metadata used by the virtual-channel graph.""" + + channel_id: str + name: str + dtype: int + vdim: int + data_kind: str = "timeseries" + device_channel: DeviceChannel = field(init=False, repr=False) + + def __post_init__(self) -> None: + """Back this spec with nxslib ``DeviceChannel``.""" + dtype = _to_dtype_code(self.dtype) + object.__setattr__(self, "dtype", dtype) + object.__setattr__( + self, + "device_channel", + DeviceChannel( + chan=_parse_channel_number(self.channel_id), + _type=dtype, + vdim=self.vdim, + name=self.name, + ), + ) + + @classmethod + def from_device_channel( + cls, + channel: 
DeviceChannel, + channel_id: str | None = None, + data_kind: str = "timeseries", + ) -> "ChannelSpec": + """Create spec from an existing nxslib ``DeviceChannel``.""" + return cls( + channel_id=channel_id or str(channel.data.chan), + name=channel.data.name, + dtype=channel.data.dtype, + vdim=channel.data.vdim, + data_kind=data_kind, + ) + + +@dataclass(frozen=True) +class VirtualChannelSpec: + """Virtual channel declaration.""" + + channel_id: str + name: str + operator: str + inputs: tuple[str, ...] + params: dict[str, object] = field(default_factory=dict) + enabled: bool = True + + +SampleValue = tuple[float, ...] + + +def to_float(value: object, fallback: float) -> float: + """Convert parameter value to float with fallback.""" + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return fallback + return fallback + + +def _parse_channel_number(channel_id: str) -> int: + """Parse stream ID to numeric channel number when possible.""" + try: + return int(channel_id) + except ValueError: + return -1 + + +def _to_dtype_code(value: object) -> int: + """Normalize virtual dtype into nxslib channel type code.""" + if isinstance(value, int): + return value + if isinstance(value, str): + key = value.strip().upper() + if key == "INT": + key = "INT32" + if hasattr(EDeviceChannelType, key): + return int(getattr(EDeviceChannelType, key).value) + return int(EDeviceChannelType.FLOAT.value) diff --git a/src/nxscli/virtual/operators.py b/src/nxscli/virtual/operators.py new file mode 100644 index 0000000..909b04c --- /dev/null +++ b/src/nxscli/virtual/operators.py @@ -0,0 +1,253 @@ +"""Built-in virtual operators.""" + +import math +from typing import Callable, Protocol + +from nxslib.dev import EDeviceChannelType + +from nxscli.virtual.errors import VirtualChannelError +from nxscli.virtual.models import ( + ChannelSpec, + SampleValue, + VirtualChannelSpec, + to_float, +) + + +class 
VirtualOperator(Protocol): + """Protocol implemented by virtual-channel operators.""" + + def configure( + self, + spec: VirtualChannelSpec, + inputs: tuple[ChannelSpec, ...], + ) -> None: + """Validate and store operator configuration.""" + + def describe_outputs( + self, spec: VirtualChannelSpec + ) -> tuple[ChannelSpec, ...]: + """Return output channel metadata.""" + + def process( + self, inputs: tuple[SampleValue, ...] + ) -> tuple[SampleValue, ...]: + """Process one sample tick.""" + + def reset(self) -> None: + """Reset internal operator state.""" + + +class ScaleOffsetOperator: + """Apply ``out = in * scale + offset`` element-wise.""" + + def __init__(self) -> None: + """Initialize defaults.""" + self._scale = 1.0 + self._offset = 0.0 + self._vdim = 1 + + def configure( + self, + spec: VirtualChannelSpec, + inputs: tuple[ChannelSpec, ...], + ) -> None: + """Validate inputs and parse parameters.""" + if len(inputs) != 1: + raise VirtualChannelError("scale_offset expects exactly one input") + self._vdim = inputs[0].vdim + self._scale = to_float(spec.params.get("scale", 1.0), 1.0) + self._offset = to_float(spec.params.get("offset", 0.0), 0.0) + + def describe_outputs( + self, spec: VirtualChannelSpec + ) -> tuple[ChannelSpec, ...]: + """Describe single transformed output.""" + return ( + ChannelSpec( + channel_id=spec.channel_id, + name=spec.name, + dtype=EDeviceChannelType.FLOAT.value, + vdim=self._vdim, + ), + ) + + def process( + self, inputs: tuple[SampleValue, ...] 
+ ) -> tuple[SampleValue, ...]: + """Apply scale/offset for one sample tick.""" + src = inputs[0] + return (tuple((x * self._scale) + self._offset for x in src),) + + def reset(self) -> None: + """Stateless operator reset.""" + return + + +class MathBinaryOperator: + """Element-wise binary math operation.""" + + _OPS: dict[str, Callable[[float, float], float]] = { + "add": lambda a, b: a + b, + "sub": lambda a, b: a - b, + "mul": lambda a, b: a * b, + "div": lambda a, b: a / b, + "min": lambda a, b: a if a < b else b, + "max": lambda a, b: a if a > b else b, + } + + def __init__(self) -> None: + """Initialize defaults.""" + self._vdim = 1 + self._op: Callable[[float, float], float] = self._OPS["add"] + + def configure( + self, + spec: VirtualChannelSpec, + inputs: tuple[ChannelSpec, ...], + ) -> None: + """Validate inputs and operation kind.""" + if len(inputs) != 2: + raise VirtualChannelError("math_binary expects exactly two inputs") + if inputs[0].vdim != inputs[1].vdim: + raise VirtualChannelError("math_binary inputs must have same vdim") + op_name = str(spec.params.get("op", "add")) + if op_name not in self._OPS: + raise VirtualChannelError(f"Unsupported math operation: {op_name}") + self._op = self._OPS[op_name] + self._vdim = inputs[0].vdim + + def describe_outputs( + self, spec: VirtualChannelSpec + ) -> tuple[ChannelSpec, ...]: + """Describe single math output.""" + return ( + ChannelSpec( + channel_id=spec.channel_id, + name=spec.name, + dtype=EDeviceChannelType.FLOAT.value, + vdim=self._vdim, + ), + ) + + def process( + self, inputs: tuple[SampleValue, ...] 
+ ) -> tuple[SampleValue, ...]: + """Apply selected binary operation for one sample tick.""" + left, right = inputs + out = tuple(self._op(a, b) for a, b in zip(left, right)) + return (out,) + + def reset(self) -> None: + """Stateless operator reset.""" + return + + +class RunningStatsOperator: + """Track running ``min,max,avg,rms`` and emit separate output streams.""" + + def __init__(self) -> None: + """Initialize running stats state.""" + self.reset() + + def configure( + self, + spec: VirtualChannelSpec, + inputs: tuple[ChannelSpec, ...], + ) -> None: + """Validate one input and initialize per-dimension state.""" + if len(inputs) != 1: + raise VirtualChannelError( + "stats_running expects exactly one input" + ) + if inputs[0].vdim <= 0: + raise VirtualChannelError("stats_running requires non-empty input") + if bool(spec.params): + raise VirtualChannelError("stats_running does not accept params") + vdim = inputs[0].vdim + self.reset() + self._vdim = vdim + self._sum = [0.0 for _ in range(vdim)] + self._sum_sq = [0.0 for _ in range(vdim)] + self._min = [0.0 for _ in range(vdim)] + self._max = [0.0 for _ in range(vdim)] + + def describe_outputs( + self, spec: VirtualChannelSpec + ) -> tuple[ChannelSpec, ...]: + """Describe four stat output streams.""" + return ( + ChannelSpec( + channel_id=f"{spec.channel_id}.min", + name=f"{spec.name}_min", + dtype=EDeviceChannelType.FLOAT.value, + vdim=self._vdim, + data_kind="stats", + ), + ChannelSpec( + channel_id=f"{spec.channel_id}.max", + name=f"{spec.name}_max", + dtype=EDeviceChannelType.FLOAT.value, + vdim=self._vdim, + data_kind="stats", + ), + ChannelSpec( + channel_id=f"{spec.channel_id}.avg", + name=f"{spec.name}_avg", + dtype=EDeviceChannelType.FLOAT.value, + vdim=self._vdim, + data_kind="stats", + ), + ChannelSpec( + channel_id=f"{spec.channel_id}.rms", + name=f"{spec.name}_rms", + dtype=EDeviceChannelType.FLOAT.value, + vdim=self._vdim, + data_kind="stats", + ), + ) + + def process( + self, inputs: 
tuple[SampleValue, ...] + ) -> tuple[SampleValue, ...]: + """Update and return min/max/avg/rms outputs.""" + values = inputs[0] + if len(values) != self._vdim: + raise VirtualChannelError("stats_running input vdim mismatch") + + if self._count == 0: + self._min = list(values) + self._max = list(values) + else: + for i, value in enumerate(values): + if value < self._min[i]: + self._min[i] = value + if value > self._max[i]: + self._max[i] = value + + self._count += 1 + for i, value in enumerate(values): + self._sum[i] += value + self._sum_sq[i] += value * value + + avg = tuple(total / self._count for total in self._sum) + rms = tuple(math.sqrt(total / self._count) for total in self._sum_sq) + return (tuple(self._min), tuple(self._max), avg, rms) + + def reset(self) -> None: + """Reset running counters and accumulators.""" + self._vdim = 1 + self._count = 0 + self._sum = [0.0] + self._sum_sq = [0.0] + self._min = [0.0] + self._max = [0.0] + + +def default_operator_registry() -> dict[str, Callable[[], VirtualOperator]]: + """Return built-in virtual-channel operator factories.""" + return { + "scale_offset": ScaleOffsetOperator, + "math_binary": MathBinaryOperator, + "stats_running": RunningStatsOperator, + } diff --git a/src/nxscli/virtual/runtime.py b/src/nxscli/virtual/runtime.py new file mode 100644 index 0000000..1a6c0a9 --- /dev/null +++ b/src/nxscli/virtual/runtime.py @@ -0,0 +1,368 @@ +"""Shared virtual-channel runtime for nxscli stream pipeline.""" + +import queue +from dataclasses import dataclass +from threading import Lock +from time import sleep +from typing import TYPE_CHECKING + +import numpy as np +from nxslib.dev import DeviceChannel +from nxslib.nxscope import DNxscopeStreamBlock, NxscopeHandler +from nxslib.thread import ThreadCommon + +from nxscli.virtual.errors import VirtualChannelError +from nxscli.virtual.manager import VirtualChannelManager +from nxscli.virtual.models import SampleValue, VirtualChannelSpec + +if TYPE_CHECKING: + from 
nxscli.channelref import ChannelRef
+
+
+@dataclass(frozen=True)
+class DeclaredVirtualChannel:
+    """One declared virtual channel with fixed output aliases."""
+
+    spec: VirtualChannelSpec
+    output_ids: tuple[str, ...]
+    aliases: tuple[str, ...]
+    aliased_names: tuple[str, ...]
+
+
+class VirtualStreamRuntime:
+    """Expose virtual channels as normal stream channels."""
+
+    def __init__(self) -> None:
+        """Initialize runtime state."""
+        self._lock = Lock()
+        self._nxscope: NxscopeHandler | None = None
+        self._manager = VirtualChannelManager()
+        self._declared: list[DeclaredVirtualChannel] = []
+        self._alias_to_output_id: dict[str, str] = {}
+        self._output_id_to_alias: dict[str, str] = {}
+        self._channels: dict[str, DeviceChannel] = {}
+        self._subscribers: dict[
+            str, list[queue.Queue[list[DNxscopeStreamBlock]]]
+        ] = {}
+        self._physical_subs: dict[
+            int, queue.Queue[list[DNxscopeStreamBlock]]
+        ] = {}
+        self._thread = ThreadCommon(self._thread_common, name="virtstream")
+        self._poll_idx = 0
+        self._started = False
+
+    def add_virtual_channel(
+        self,
+        channel_id: int,
+        name: str,
+        operator: str,
+        inputs: tuple[str, ...],
+        params: dict[str, object],
+    ) -> tuple[tuple[str, str], ...]:
+        """Add one virtual channel declaration.
+
+        :return: tuple of ``(alias_channel_id, internal_output_id)``
+        """
+        if channel_id < 0:
+            raise VirtualChannelError("virtual channel id must be >= 0")
+
+        with self._lock:
+            base = f"v{channel_id}"
+            output_ids: tuple[str, ...]
+            aliases: tuple[str, ...]
+            alias_names: tuple[str, ...]
+            if operator == "stats_running":
+                output_ids = (
+                    f"{base}.min",
+                    f"{base}.max",
+                    f"{base}.avg",
+                    f"{base}.rms",
+                )
+                aliases = (
+                    f"v{channel_id}",
+                    f"v{channel_id + 1}",
+                    f"v{channel_id + 2}",
+                    f"v{channel_id + 3}",
+                )
+                alias_names = (
+                    f"v{channel_id}",
+                    f"v{channel_id + 1}",
+                    f"v{channel_id + 2}",
+                    f"v{channel_id + 3}",
+                )
+            else:
+                output_ids = (base,)
+                aliases = (f"v{channel_id}",)
+                alias_names = (f"v{channel_id}",)
+
+            for out_id in output_ids:
+                if out_id in self._output_id_to_alias:
+                    raise VirtualChannelError(
+                        f"virtual output already exists: {out_id}"
+                    )
+            for alias_name in alias_names:
+                if alias_name in self._alias_to_output_id:
+                    raise VirtualChannelError(
+                        f"virtual alias already exists: {alias_name}"
+                    )
+
+            resolved_inputs = tuple(
+                self._alias_to_output_id.get(
+                    self._normalize_input_token(token), token
+                )
+                for token in inputs
+            )
+            spec = VirtualChannelSpec(
+                channel_id=base,
+                name=name,
+                operator=operator,
+                inputs=resolved_inputs,
+                params=dict(params),
+            )
+            self._declared.append(
+                DeclaredVirtualChannel(
+                    spec=spec,
+                    output_ids=output_ids,
+                    aliases=aliases,
+                    aliased_names=alias_names,
+                )
+            )
+            for out_id, alias in zip(output_ids, aliases):
+                self._output_id_to_alias[out_id] = alias
+            for out_id, alias_name in zip(output_ids, alias_names):
+                self._alias_to_output_id[alias_name] = out_id
+
+            if self._nxscope is not None:
+                self._rebuild_locked()
+
+            return tuple(
+                (alias, out_id) for alias, out_id in zip(aliases, output_ids)
+            )
+
+    def clear(self) -> None:
+        """Remove all declared virtual channels."""
+        with self._lock:
+            self._declared = []
+            self._alias_to_output_id = {}
+            self._output_id_to_alias = {}
+            self._channels = {}
+            self._subscribers = {}
+            self._manager = VirtualChannelManager()
+
+    def declared(self) -> tuple[DeclaredVirtualChannel, ...]:
+        """Return current declarations."""
+        with self._lock:
+            return tuple(self._declared)
+
+    def channel_get(self, channel: "ChannelRef") -> DeviceChannel | None:
+        """Return runtime channel metadata."""
+        if not channel.is_virtual:
+            return None
+        chid = channel.virtual_name()
+        with self._lock:
+            return self._channels.get(chid)
+
+    def channel_list(self) -> tuple[DeviceChannel, ...]:
+        """Return all runtime channels."""
+        with self._lock:
+            return tuple(self._channels.values())
+
+    def stream_sub(
+        self, channel: "ChannelRef"
+    ) -> queue.Queue[list[DNxscopeStreamBlock]] | None:
+        """Subscribe queue for virtual channel."""
+        if not channel.is_virtual:
+            return None
+        chan = channel.virtual_name()
+        with self._lock:
+            if chan not in self._channels:
+                return None
+            subq: queue.Queue[list[DNxscopeStreamBlock]] = queue.Queue()
+            self._subscribers.setdefault(chan, []).append(subq)
+            return subq
+
+    def stream_unsub(
+        self, subq: queue.Queue[list[DNxscopeStreamBlock]]
+    ) -> bool:
+        """Unsubscribe queue. Returns ``True`` if removed."""
+        with self._lock:
+            for chan in list(self._subscribers.keys()):
+                subs = self._subscribers[chan]
+                if subq in subs:
+                    subs.remove(subq)
+                    if not subs:
+                        del self._subscribers[chan]
+                    return True
+        return False
+
+    def on_connect(self, nxscope: NxscopeHandler) -> None:
+        """Attach runtime to connected Nxscope handler."""
+        with self._lock:
+            self._nxscope = nxscope
+            self._rebuild_locked()
+
+    def on_disconnect(self) -> None:
+        """Detach runtime from Nxscope handler."""
+        self.on_stream_stop()
+        with self._lock:
+            self._nxscope = None
+            self._manager = VirtualChannelManager()
+            self._channels = {}
+
+    def on_stream_start(self) -> None:
+        """Start runtime streaming thread."""
+        with self._lock:
+            if self._started:
+                return
+            if self._nxscope is None:
+                return
+            if not self._declared:
+                return
+
+            self._physical_subs = {}
+            for channel_id in self._manager.required_physical_channel_ids():
+                chid = int(channel_id)
+                self._physical_subs[chid] = self._nxscope.stream_sub(chid)
+            self._started = True
+
+        self._thread.thread_start()
+
+    def on_stream_stop(self) -> None:
+        """Stop runtime streaming thread."""
+        with self._lock:
+            if not self._started:
+                return
+            self._started = False
+        self._thread.thread_stop()
+        with self._lock:
+            if self._nxscope is not None:
+                for subq in self._physical_subs.values():
+                    self._nxscope.stream_unsub(subq)
+            self._physical_subs = {}
+            self._manager.reset()
+
+    def _rebuild_locked(self) -> None:
+        assert self._nxscope is not None
+        assert self._nxscope.dev is not None
+
+        manager = VirtualChannelManager()
+        for chid in range(self._nxscope.dev.data.chmax):
+            channel = self._nxscope.dev_channel_get(chid)
+            if channel is not None and channel.data.is_valid:
+                manager.add_physical_channel(channel)
+
+        channels: dict[str, DeviceChannel] = {}
+        chan_idx = -1
+        for declared in self._declared:
+            outputs = manager.add_virtual_channel(declared.spec)
+            got_ids = tuple(out.channel_id for out in outputs)
+            if got_ids != declared.output_ids:
+                raise VirtualChannelError(
+                    "virtual output ids mismatch for declared channel "
+                    f"{declared.spec.channel_id}"
+                )
+            for output, alias in zip(outputs, declared.aliases):
+                channels[alias] = DeviceChannel(
+                    chan=chan_idx,
+                    _type=output.dtype,
+                    vdim=output.vdim,
+                    name=alias,
+                )
+                chan_idx -= 1
+
+        self._manager = manager
+        self._channels = channels
+        self._subscribers = {
+            chan: self._subscribers.get(chan, []) for chan in channels
+        }
+
+    def _to_sample(self, sample: object) -> SampleValue | None:
+        try:
+            arr = np.asarray(sample, dtype=np.float64).reshape(-1)
+            return tuple(float(x) for x in arr.tolist())
+        except (TypeError, ValueError):
+            return None
+
+    def _normalize_input_token(self, token: str) -> str:
+        tok = token.strip()
+        if tok.startswith("v"):
+            vid = tok[1:]
+            if not vid.isnumeric():
+                raise VirtualChannelError(
+                    f"invalid virtual channel input: {token}"
+                )
+            return f"v{int(vid)}"
+        return tok
+
+    def _thread_common(self) -> None:
+        with self._lock:
+            subs = list(self._physical_subs.items())
+
+        if not subs:
+            sleep(0.05)
+            return
+
+        chid, subq = subs[self._poll_idx % len(subs)]
+        self._poll_idx += 1
+        try:
+            batch = subq.get(block=True, timeout=0.05)
+        except queue.Empty:
+            return
+
+        out_batches = self._process_batch(chid, batch)
+
+        if not out_batches:
+            return
+
+        with self._lock:
+            for alias, samples in out_batches.items():
+                for qsub in self._subscribers.get(alias, []):
+                    qsub.put(samples)
+
+    def _process_batch(
+        self,
+        chid: int,
+        batch: list[DNxscopeStreamBlock],
+    ) -> dict[str, list[DNxscopeStreamBlock]]:
+        out_rows = self._collect_output_rows(chid, batch)
+        return self._build_output_blocks(out_rows)
+
+    def _collect_output_rows(
+        self,
+        chid: int,
+        batch: list[DNxscopeStreamBlock],
+    ) -> dict[str, list[SampleValue]]:
+        out_rows: dict[str, list[SampleValue]] = {}
+        for item in batch:
+            arr = item.data
+            if int(arr.shape[0]) == 0:
+                continue
+            for row in arr:
+                sample = self._to_sample(row)
+                if sample is None:
+                    continue
+                try:
+                    changed = self._manager.process_update(str(chid), sample)
+                except VirtualChannelError:
+                    continue
+                for out_id, out_value in changed.items():
+                    alias = self._output_id_to_alias.get(out_id)
+                    if alias is None:
+                        continue
+                    out_rows.setdefault(alias, []).append(out_value)
+        return out_rows
+
+    def _build_output_blocks(
+        self,
+        out_rows: dict[str, list[SampleValue]],
+    ) -> dict[str, list[DNxscopeStreamBlock]]:
+        out_batches: dict[str, list[DNxscopeStreamBlock]] = {}
+        for alias, rows in out_rows.items():
+            chan = self._channels.get(alias)
+            if chan is None or not rows:
+                continue
+            vdim = int(chan.data.vdim)
+            arr = np.asarray(rows, dtype=np.float64).reshape(-1, vdim)
+            out_batches[alias] = [DNxscopeStreamBlock(data=arr, meta=None)]
+
+        return out_batches
diff --git a/src/nxscli/virtual/services.py b/src/nxscli/virtual/services.py
new file mode 100644
index 0000000..5f0b8c7
--- /dev/null
+++ b/src/nxscli/virtual/services.py
@@ -0,0 +1,23 @@
+"""Helpers for integrating virtual runtime into nxscli."""
+
+from typing import TYPE_CHECKING
+
+from nxscli.virtual.runtime import VirtualStreamRuntime
+
+if TYPE_CHECKING:
+    from nxscli.istream import IServiceRegistry
+
+SERVICE_KEY = "nxscli.virtual"
+
+
+def get_runtime(registry: "IServiceRegistry") -> VirtualStreamRuntime:
+    """Get or create shared virtual runtime service."""
+    runtime = registry.service_get(SERVICE_KEY)
+    if runtime is not None:
+        assert isinstance(runtime, VirtualStreamRuntime)
+        return runtime
+
+    runtime = VirtualStreamRuntime()
+    registry.service_set(SERVICE_KEY, runtime)
+    registry.stream_provider_add(runtime)
+    return runtime
diff --git a/tests/cli/test_main.py b/tests/cli/test_main.py
index 118baa9..501d3e8 100644
--- a/tests/cli/test_main.py
+++ b/tests/cli/test_main.py
@@ -1,13 +1,21 @@
+import socket
+
 import pytest  # type: ignore
 from click.testing import CliRunner
 
 import nxscli
 from nxscli.cli.main import main
+from tests.fake_nxscope import FakeNxscope
 
 
 @pytest.fixture
 def runner(mocker):
     mocker.patch.object(nxscli.cli.main, "wait_for_plugins", autospec=True)
+    mocker.patch.object(
+        nxscli.commands.interface.cmd_dummy,
+        "NxscopeHandler",
+        FakeNxscope,
+    )
     return CliRunner()
@@ -21,6 +29,61 @@ def test_main_dummy(runner):
     assert result.exit_code == 2
 
 
+@pytest.mark.parametrize("_has_af_unix", [True, False])
+def test_main_control_server_enabled(runner, monkeypatch, _has_af_unix):
+    if _has_af_unix:
+        monkeypatch.setattr(socket, "AF_UNIX", 1, raising=False)
+    else:
+        monkeypatch.delattr(socket, "AF_UNIX", raising=False)
+    if hasattr(socket, "AF_UNIX"):
+        endpoint = "unix-abstract://nxscli-test-control"
+    else:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+            sock.bind(("127.0.0.1", 0))
+            host, port = sock.getsockname()
+        endpoint = f"tcp://{host}:{port}"
+
+    args = [
+        "--control-server",
+        "--control-endpoint",
+        endpoint,
+        "dummy",
+        "pdevinfo",
+    ]
+    result = runner.invoke(main, args)
+    assert result.exit_code == 0
+
+
+def test_main_control_server_enabled_init_failure(runner, monkeypatch):
+    cleanup_called = {"flag": False}
+
+    def fail_init(*_, **__):
+        raise RuntimeError("boom")
+
+    orig_cleanup = nxscli.cli.main.PluginHandler.cleanup
+
+    def cleanup(self):
+        cleanup_called["flag"] = True
+        orig_cleanup(self)
+
+    monkeypatch.setattr("nxscli.cli.main.ControlServerPlugin", fail_init)
+    monkeypatch.setattr(
+        "nxscli.cli.main.PluginHandler.cleanup", cleanup, raising=False
+    )
+    result = runner.invoke(
+        main,
+        [
+            "--control-server",
+            "--control-endpoint",
+            "tcp://127.0.0.1:12345",
+            "dummy",
+            "pdevinfo",
+        ],
+    )
+    assert result.exit_code != 0
+    assert cleanup_called["flag"] is True
+
+
 def test_main_pdevinfo(runner):
     args = ["dummy", "pdevinfo"]
     result = runner.invoke(main, args)
@@ -96,6 +159,48 @@ def test_main_pcsv(runner):
     assert result.exit_code == 0
 
 
+def test_main_pnpsave(runner):
+    args = ["chan", "1", "pnpsave", "1", "./test"]
+    result = runner.invoke(main, args)
+    assert result.exit_code == 2
+
+    with runner.isolated_filesystem():
+        args = ["dummy", "chan", "1", "pnpsave", "1", "./test"]
+        result = runner.invoke(main, args)
+        assert result.exit_code == 0
+
+    with runner.isolated_filesystem():
+        args = ["dummy", "chan", "1", "pnpsave", "1000", "./test"]
+        result = runner.invoke(main, args)
+        assert result.exit_code == 0
+
+    with runner.isolated_filesystem():
+        args = ["dummy", "chan", "5", "pnpsave", "1", "./test"]
+        result = runner.invoke(main, args)
+        assert result.exit_code == 0
+
+
+def test_main_pnpmem(runner):
+    args = ["chan", "1", "pnpmem", "1", "./test", "100"]
+    result = runner.invoke(main, args)
+    assert result.exit_code == 2
+
+    with runner.isolated_filesystem():
+        args = ["dummy", "chan", "1", "pnpmem", "10", "./test", "100"]
+        result = runner.invoke(main, args)
+        assert result.exit_code == 0
+
+    with runner.isolated_filesystem():
+        args = ["dummy", "chan", "1", "pnpmem", "10", "./test", "100"]
+        result = runner.invoke(main, args)
+        assert result.exit_code == 0
+
+    with runner.isolated_filesystem():
+        args = ["dummy", "chan", "5", "pnpmem", "400", "./test", "200"]
+        result = runner.invoke(main, args)
+        assert result.exit_code == 0
+
+
 def test_main_pnone(runner):
     args = ["chan", "1", "pnone", "1"]
     result = runner.invoke(main, args)
@@ -196,3 +301,37 @@ def test_main_help(runner):
     args = ["dummy", "trig", "--help"]
     result = runner.invoke(main, args)
     assert result.exit_code == 0
+
+
+def test_main_vadd(runner):
+    args = ["dummy", "vadd", "--operator", "scale_offset", "0", "0"]
+    result = runner.invoke(main, args)
+    assert result.exit_code == 0
+
+
+def test_main_vadd_pprinter_virtual_output(runner):
+    args = [
+        "dummy",
+        "vadd",
+        "--operator",
+        "scale_offset",
+        "--params",
+        "scale=2,offset=1",
+        "100",
+        "0",
+        "pprinter",
+        "--chan",
+        "v100",
+        "3",
+    ]
+    result = runner.invoke(main, args)
+    assert result.exit_code == 0
+    assert "virtual output v100 -> channel v100" in result.output
+    assert "'chan': -1" in result.output
+    assert "1:" in result.output
+
+
+def test_main_udp_help(runner):
+    args = ["udp", "--help"]
+    result = runner.invoke(main, args)
+    assert result.exit_code == 0
diff --git a/tests/cli/test_types.py b/tests/cli/test_types.py
index e0f3d9d..6dd4240 100644
--- a/tests/cli/test_types.py
+++ b/tests/cli/test_types.py
@@ -2,6 +2,7 @@ import pytest
 
 from nxscli.cli.types import (
+    Channels,
     Samples,
     StringList,
     StringList2,
@@ -32,6 +33,25 @@ def test_samples():
         s.convert("a", None, None)
 
 
+def test_channels() -> None:
+    ch = Channels()
+
+    all_ch = ch.convert("all", None, None)
+    assert len(all_ch) == 1 and all_ch[0].is_all
+
+    phy = ch.convert("0,1", None, None)
+    assert [x.physical_id() for x in phy] == [0, 1]
+
+    virt = ch.convert("v0,v12", None, None)
+    assert [x.virtual_name() for x in virt] == ["v0", "v12"]
+
+    with pytest.raises(click.BadParameter):
+        ch.convert("256", None, None)
+
+    with pytest.raises(AssertionError):
+        ch.convert("vA", None, None)
+
+
 def test_trigger():
     t = Trigger()
diff --git a/tests/fake_nxscope.py b/tests/fake_nxscope.py
new file mode 100644
index 0000000..5a6185f
--- /dev/null
+++ b/tests/fake_nxscope.py
@@ -0,0 +1,120 @@
+import queue
+from types import SimpleNamespace
+from typing import Any
+
+import numpy as np
+from nxslib.dev import DeviceChannel
+from nxslib.nxscope import DNxscopeStreamBlock
+
+
+class FakeNxscope:
+    """Fast Nxscope test double used by CLI/PluginHandler tests."""
+
+    def __init__(self, *_: Any, **__: Any) -> None:
+        self.connected = False
+        self._stream_started = False
+        self._registered_plugins: dict[str, Any] = {}
+        self._channels = []
+        for i in range(10):
+            if i == 8:
+                vdim = 0
+            elif i == 9:
+                vdim = 3
+            else:
+                vdim = 1
+            self._channels.append(DeviceChannel(i, 10, vdim, f"chan{i}"))
+        self.dev = SimpleNamespace(
+            data=SimpleNamespace(chmax=len(self._channels))
+        )
+        self._enabled = [False for _ in self._channels]
+        self._dividers = [0 for _ in self._channels]
+
+    def __enter__(self) -> "FakeNxscope":
+        self.connect()
+        return self
+
+    def __exit__(self, *_: object) -> None:
+        self.disconnect()
+
+    def connect(self) -> None:
+        self.connected = True
+
+    def disconnect(self) -> None:
+        self.connected = False
+        self._stream_started = False
+
+    def register_plugin(self, plugin: Any, frame_ids: Any = None) -> str:
+        del frame_ids
+        name = getattr(plugin, "name", plugin.__class__.__name__)
+        self._registered_plugins[name] = plugin
+        return str(name)
+
+    def unregister_plugin(self, name: str) -> bool:
+        if name in self._registered_plugins:
+            self._registered_plugins.pop(name)
+            return True
+        return False
+
+    def dev_channel_get(self, chid: int) -> DeviceChannel | None:
+        if 0 <= chid < len(self._channels):
+            return self._channels[chid]
+        return None
+
+    def channels_default_cfg(self) -> None:
+        self._enabled = [False for _ in self._channels]
+        self._dividers = [0 for _ in self._channels]
+
+    def ch_enable(self, chid: int) -> None:
+        self._enabled[chid] = True
+
+    def ch_divider(self, chid: int, div: int) -> None:
+        self._dividers[chid] = div
+
+    def channels_write(self) -> None:
+        return
+
+    def stream_start(self) -> None:
+        self._stream_started = True
+
+    def stream_stop(self) -> None:
+        self._stream_started = False
+
+    def stream_sub(self, chid: int) -> queue.Queue[Any]:
+        q: queue.Queue[Any] = queue.Queue()
+        data = np.full((1200, 1), float(chid))
+        meta = np.zeros((1200, 1), dtype=np.uint32)
+        block = DNxscopeStreamBlock(data=data, meta=meta)
+        # One large block keeps plugin loops deterministic and fast.
+        q.put([block])
+        return q
+
+    def stream_unsub(self, _: queue.Queue[Any]) -> None:
+        return
+
+    def get_enabled_channels(self, applied: bool = True) -> tuple[int, ...]:
+        del applied
+        return tuple(i for i, enabled in enumerate(self._enabled) if enabled)
+
+    def get_channel_divider(self, chid: int, applied: bool = True) -> int:
+        del applied
+        return self._dividers[chid]
+
+    def get_channel_dividers(self, applied: bool = True) -> tuple[int, ...]:
+        del applied
+        return tuple(self._dividers)
+
+    def get_channels_state(self, applied: bool = True) -> Any:
+        del applied
+        return SimpleNamespace(
+            enabled_channels=self.get_enabled_channels(),
+            dividers=self.get_channel_dividers(),
+        )
+
+    def get_device_capabilities(self) -> Any:
+        return SimpleNamespace(chmax=len(self._channels))
+
+    def get_stream_stats(self) -> Any:
+        return SimpleNamespace(
+            connected=self.connected,
+            stream_started=self._stream_started,
+        )
diff --git a/tests/plugins/test_csv.py b/tests/plugins/test_csv.py
index 00a609a..b45b9e6 100644
--- a/tests/plugins/test_csv.py
+++ b/tests/plugins/test_csv.py
@@ -1,3 +1,9 @@
+import csv
+import io
+
+import numpy as np
+from nxslib.nxscope import DNxscopeStreamBlock
+
 from nxscli.plugins.csv import PluginCsv
@@ -7,3 +13,54 @@ def test_plugincsv_init():
     assert plugin.stream is True
 
     # TODO:
+
+
+def test_plugincsv_handle_blocks_none_meta_and_empty_block() -> None:
+    plugin = PluginCsv()
+    out = io.StringIO()
+    writer = csv.writer(
+        out,
+        delimiter=" ",
+        quotechar="|",
+        escapechar="\\",
+        quoting=csv.QUOTE_MINIMAL,
+    )
+    plugin._csvwriters = [[writer, out]]
+    plugin._samples = 10
+    plugin._nostop = False
+    plugin._datalen = [0]
+    plugin._meta_string = False
+    pdata = type("Q", (), {"vdim": 1})()
+
+    block0 = DNxscopeStreamBlock(data=np.empty((0, 1)), meta=None)
+    block1 = DNxscopeStreamBlock(data=np.array([[1.0], [2.0]]), meta=None)
+    plugin._handle_blocks([block0, block1], pdata, 0)
+
+    assert plugin._datalen == [2]
+
+
+def test_plugincsv_handle_blocks_meta_string() -> None:
+    plugin = PluginCsv()
+    out = io.StringIO()
+    writer = csv.writer(
+        out,
+        delimiter=" ",
+        quotechar="|",
+        escapechar="\\",
+        quoting=csv.QUOTE_MINIMAL,
+    )
+    plugin._csvwriters = [[writer, out]]
+    plugin._samples = 10
+    plugin._nostop = False
+    plugin._datalen = [0]
+    plugin._meta_string = True
+    pdata = type("Q", (), {"vdim": 1})()
+
+    block = DNxscopeStreamBlock(
+        data=np.array([[1.0], [2.0]]),
+        meta=np.array([[65], [66]], dtype=np.uint8),
+    )
+    plugin._handle_blocks([block], pdata, 0)
+
+    assert plugin._datalen == [2]
+    assert "A" in out.getvalue()
diff --git a/tests/plugins/test_devinfo.py b/tests/plugins/test_devinfo.py
index f6afa64..9bf7988 100644
--- a/tests/plugins/test_devinfo.py
+++ b/tests/plugins/test_devinfo.py
@@ -1,3 +1,9 @@
+from nxslib.intf.dummy import DummyDev
+from nxslib.nxscope import NxscopeHandler
+from nxslib.proto.parse import Parser
+
+from nxscli.iplugin import DPluginDescription
+from nxscli.phandler import PluginHandler
 from nxscli.plugins.devinfo import PluginDevinfo
@@ -7,3 +13,25 @@ def test_plugindevinfo_init():
     assert plugin.stream is False
 
     # TODO:
+
+
+def test_plugindevinfo_content():
+    intf = DummyDev()
+    parse = Parser()
+    with NxscopeHandler(intf, parse, enable_bitrate_tracking=True) as nxscope:
+        with PluginHandler(
+            [DPluginDescription("pdevinfo", PluginDevinfo)]
+        ) as phandler:
+            phandler.nxscope_connect(nxscope)
+
+            plugin = PluginDevinfo()
+            plugin.connect_phandler(phandler)
+
+            assert plugin.start({}) is True
+            out = plugin.result()
+
+            assert "Device common" in out
+            assert "Channels state (applied)" in out
+            assert "Channels state (buffered)" in out
+            assert "stream_started" in out
+            assert "bitrate" in out
diff --git a/tests/plugins/test_npmem.py b/tests/plugins/test_npmem.py
new file mode 100644
index 0000000..feae998
--- /dev/null
+++ b/tests/plugins/test_npmem.py
@@ -0,0 +1,48 @@
+import numpy as np
+
+from nxscli.plugins.npmem import PluginNpmem
+
+
+def test_pluginnpmem_init() -> None:
+    plugin = PluginNpmem()
+    assert plugin.stream is True
+
+
+def test_pluginnpmem_handle_blocks_and_samples(tmp_path) -> None:
+    class QData:
+        def __init__(self) -> None:
+            self.chan = 2
+            self.vdim = 2
+
+    class Data:
+        def __init__(self) -> None:
+            self.qdlist = [QData()]
+
+    class Block:
+        def __init__(self, data):  # noqa: ANN001
+            self.data = data
+
+    class Sample:
+        def __init__(self, data):  # noqa: ANN001
+            self.data = data
+
+    plugin = PluginNpmem()
+    plugin._phandler = object()
+    plugin._data = Data()
+    plugin._path = str(tmp_path / "capture")
+    plugin._npshape = 2
+    plugin._init()
+    plugin._datalen = [0]
+    pdata = plugin._data.qdlist[0]
+
+    plugin._handle_samples([Sample((1.0, 2.0))], pdata, 0)
+    assert plugin._datalen[0] == 0
+
+    plugin._handle_blocks(
+        [Block(np.array([[3.0, 4.0], [5.0, 6.0]]))], pdata, 0
+    )
+    assert plugin._datalen[0] == 2
+    assert np.array_equal(
+        np.asarray(plugin._npfiles[0]),
+        np.array([[1.0, 3.0], [2.0, 4.0]], dtype=np.float32),
+    )
diff --git a/tests/plugins/test_npsave.py b/tests/plugins/test_npsave.py
new file mode 100644
index 0000000..3af70e0
--- /dev/null
+++ b/tests/plugins/test_npsave.py
@@ -0,0 +1,85 @@
+import numpy as np
+
+from nxscli.plugins.npsave import PluginNpsave
+
+
+def test_pluginnpsave_init() -> None:
+    plugin = PluginNpsave()
+    assert plugin.stream is True
+
+
+def test_pluginnpsave_handle_blocks_and_final(tmp_path) -> None:
+    class QData:
+        def __init__(self) -> None:
+            self.chan = 1
+            self.vdim = 2
+
+    class Data:
+        def __init__(self) -> None:
+            self.qdlist = [QData()]
+
+    class Block:
+        def __init__(self, data):  # noqa: ANN001
+            self.data = data
+
+    plugin = PluginNpsave()
+    plugin._phandler = object()
+    plugin._data = Data()
+    plugin._path = str(tmp_path / "capture")
+    plugin._init()
+    plugin._datalen = [0]
+
+    pdata = plugin._data.qdlist[0]
+    plugin._handle_blocks(
+        [Block(np.array([[1.0, 2.0], [3.0, 4.0]]))], pdata, 0
+    )
+    plugin._final()
+
+    arr = np.load(str(tmp_path / "capture_chan1.npy"))
+    assert arr.shape == (2, 2)
+    assert np.array_equal(arr, np.array([[1.0, 3.0], [2.0, 4.0]]))
+    assert plugin._datalen[0] == 2
+
+
+def test_pluginnpsave_handle_samples() -> None:
+    class QData:
+        def __init__(self) -> None:
+            self.chan = 1
+            self.vdim = 2
+
+    class Data:
+        def __init__(self) -> None:
+            self.qdlist = [QData()]
+
+    class Sample:
+        def __init__(self, data):  # noqa: ANN001
+            self.data = data
+
+    plugin = PluginNpsave()
+    plugin._phandler = object()
+    plugin._data = Data()
+    plugin._init()
+    plugin._datalen = [0]
+    pdata = plugin._data.qdlist[0]
+    plugin._handle_samples([Sample((1.0, 2.0)), Sample((3.0, 4.0))], pdata, 0)
+    assert plugin._datalen[0] == 2
+
+
+def test_pluginnpsave_final_empty_chunks(tmp_path) -> None:
+    class QData:
+        def __init__(self) -> None:
+            self.chan = 3
+            self.vdim = 2
+
+    class Data:
+        def __init__(self) -> None:
+            self.qdlist = [QData()]
+
+    plugin = PluginNpsave()
+    plugin._phandler = object()
+    plugin._data = Data()
+    plugin._path = str(tmp_path / "capture_empty")
+    plugin._init()
+    plugin._final()
+    arr = np.load(str(tmp_path / "capture_empty_chan3.npy"))
+    assert arr.shape == (2, 0)
diff --git a/tests/plugins/test_udp.py b/tests/plugins/test_udp.py
index 267ae71..039a67e 100644
--- a/tests/plugins/test_udp.py
+++ b/tests/plugins/test_udp.py
@@ -1,3 +1,8 @@
+import json
+
+import numpy as np
+from nxslib.nxscope import DNxscopeStreamBlock
+
 from nxscli.plugins.udp import PluginUdp
@@ -7,3 +12,34 @@ def test_pluginudp_init():
     assert plugin.stream is True
 
     # TODO:
+
+
+def test_pluginudp_handle_blocks_skips_empty_and_sends_json() -> None:
+    class Sock:
+        def __init__(self) -> None:
+            self.sent: list[tuple[bytes, tuple[str, int]]] = []
+
+        def sendto(self, payload: bytes, endpoint: tuple[str, int]) -> None:
+            self.sent.append((payload, endpoint))
+
+    plugin = PluginUdp()
+    plugin._sock = Sock()
+    plugin._address = "127.0.0.1"
+    plugin._port = 1234
+    plugin._data_format = "json"
+    plugin._samples = 10
+    plugin._nostop = False
+    plugin._datalen = [0]
+    pdata = type("Q", (), {"vdim": 1, "channame": "chan0"})()
+
+    block0 = DNxscopeStreamBlock(data=np.empty((0, 1)), meta=None)
+    block1 = DNxscopeStreamBlock(data=np.array([[3.0]]), meta=None)
+    plugin._handle_blocks([block0, block1], pdata, 0)
+
+    assert plugin._datalen == [1]
+    assert len(plugin._sock.sent) == 1
+    payload, endpoint = plugin._sock.sent[0]
+    assert endpoint == ("127.0.0.1", 1234)
+    decoded = json.loads(payload.decode())
+    assert decoded["timestamp"] == 0
+    assert decoded["chan0"] == 3.0
diff --git a/tests/test_channelref.py b/tests/test_channelref.py
new file mode 100644
index 0000000..09a77fa
--- /dev/null
+++ b/tests/test_channelref.py
@@ -0,0 +1,17 @@
+import pytest
+
+from nxscli.channelref import ChannelRef
+
+
+def test_channelref_accessors() -> None:
+    p = ChannelRef.physical(3)
+    v = ChannelRef.virtual(5)
+
+    assert p.physical_id() == 3
+    assert v.virtual_name() == "v5"
+
+    with pytest.raises(ValueError):
+        ChannelRef.all_channels().physical_id()
+
+    with pytest.raises(ValueError):
+        ChannelRef.all_channels().virtual_name()
diff --git a/tests/test_control_server.py b/tests/test_control_server.py
new file mode 100644
index 0000000..e275e5d
--- /dev/null
+++ b/tests/test_control_server.py
@@ -0,0 +1,444 @@
+import base64
+import socket
+from dataclasses import dataclass
+from pathlib import Path
+
+import pytest  # type: ignore
+from nxslib.proto.iparse import ParseAck
+
+from nxscli.control_server import (
+    ControlClient,
+    ControlServerPlugin,
+    _parse_endpoint,
+)
+from tests.fake_nxscope import FakeNxscope
+
+
+def _get_test_endpoint(name: str) -> str:
+    if hasattr(socket, "AF_UNIX"):
+        return f"unix-abstract://{name}"
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        sock.bind(("127.0.0.1", 0))
+        host, port = sock.getsockname()
+    return f"tcp://{host}:{port}"
+
+
+def _ensure_af_unix(monkeypatch: pytest.MonkeyPatch) -> None:
+    desired = getattr(socket, "AF_UNIX", 1)
+    monkeypatch.setattr(socket, "AF_UNIX", desired, raising=False)
+
+
+class _DummySocket:
+    def __init__(self, family, sock_type):
+        self.family = family
+        self.sock_type = sock_type
+        self.closed = False
+
+    def bind(self, addr):
+        self.addr = addr
+
+    def listen(self, backlog):
+        self.backlog = backlog
+
+    def settimeout(self, timeout):
+        self.timeout = timeout
+
+    def close(self):
+        self.closed = True
+
+
+class _DummyThread:
+    def __init__(self, target, name, daemon):
+        self.target = target
+        self.name = name
+        self.daemon = daemon
+        self.started = False
+
+    def start(self):
+        self.started = True
+
+    def is_alive(self):
+        return False
+
+    def join(self, timeout=None):
+        del timeout
+
+
+class _SockCloseErr(_DummySocket):
+    def close(self):
+        raise OSError("close failed")
+
+
+def _patch_control_server_start(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr("nxscli.control_server.socket.socket", _DummySocket)
+    monkeypatch.setattr("nxscli.control_server.threading.Thread", _DummyThread)
+    monkeypatch.setattr(
+        "nxscli.control_server.ControlServerPlugin._serve_loop",
+        lambda self: None,
+        raising=False,
+    )
+
+
+@dataclass(frozen=True)
+class _Resp:
+    ext_id: int
+    cmd_id: int
+    req_id: int
+    status: int
+    payload: bytes
+    fid: int
+    is_error: bool
+
+
+class _ControlStub:
+    def __init__(self):
+        self.calls = []
+
+    def send_user_frame(self, fid, payload, ack_mode="auto", ack_timeout=1.0):
+        self.calls.append(("send", int(fid), payload, ack_mode, ack_timeout))
+        return ParseAck(True, 11)
+
+    def ext_notify(
+        self,
+        ext_id,
+        cmd_id,
+        payload,
+        fid=8,
+        ack_mode="auto",
+        ack_timeout=1.0,
+    ):
+        self.calls.append(
+            (
+                "notify",
+                int(ext_id),
+                int(cmd_id),
+                payload,
+                int(fid),
+                ack_mode,
+                ack_timeout,
+            )
+        )
+        return ParseAck(True, 22)
+
+    def ext_request(
+        self,
+        ext_id,
+        cmd_id,
+        payload,
+        fid=8,
+        timeout=1.0,
+        ack_mode="auto",
+        ack_timeout=1.0,
+    ):
+        self.calls.append(
+            (
+                "request",
+                int(ext_id),
+                int(cmd_id),
+                payload,
+                int(fid),
+                timeout,
+                ack_mode,
+                ack_timeout,
+            )
+        )
+        return _Resp(
+            ext_id=int(ext_id),
+            cmd_id=int(cmd_id),
+            req_id=31,
+            status=0,
+            payload=b"ok",
+            fid=int(fid),
+            is_error=False,
+        )
+
+
+def test_control_server_roundtrip_unix_abstract():
+    endpoint = _get_test_endpoint("nxscli-test-control")
+    plugin = ControlServerPlugin(endpoint)
+    control = _ControlStub()
+
+    plugin.on_register(control)
+    try:
+        client = ControlClient(endpoint, timeout=0.5)
+
+        ack = client.send_user_frame(8, b"\x01", ack_mode="disabled")
+        assert ack.state is True
+        assert ack.retcode == 11
+
+        ack = client.ext_notify(0x21, 2, b"abc")
+        assert ack.state is True
+        assert ack.retcode == 22
+
+        ret = client.ext_request(0x21, 1, b"xyz", timeout=0.2)
+        assert ret.ok is True
+        assert ret.data["ext_id"] == 0x21
+        assert ret.data["cmd_id"] == 1
+        assert ret.data["req_id"] == 31
+        assert base64.b64decode(ret.data["payload_b64"]) == b"ok"
+    finally:
+        plugin.on_unregister()
+
+
+def test_control_server_parse_endpoint_tcp():
+    ep = _parse_endpoint("tcp://127.0.0.1:55000")
+    assert ep.connect_addr == ("127.0.0.1", 55000)
+
+
+def test_control_client_connection_error_returns_failed_ack(
+    monkeypatch,
+):  # type: ignore
+    class _DummySock:
+        def __enter__(self):
+            return self
+
+        def __exit__(self, *_):
+            return None
+
+        def settimeout(self, timeout):
+            del timeout
+
+        def connect(self, addr):
+            del addr
+            raise ConnectionRefusedError(111, "Connection refused")
+
+    monkeypatch.setattr("socket.socket", lambda *a, **k: _DummySock())
+
+    client = ControlClient("tcp://127.0.0.1:55001", timeout=0.1)
+    ack = client.send_user_frame(8, b"\x01")
+    assert ack.state is False
+    assert ack.retcode == -1
+
+
+def test_control_server_parse_unix_and_invalid_endpoints(monkeypatch):
+    _ensure_af_unix(monkeypatch)
+
+    ep = _parse_endpoint("unix-abstract://nxscli.sock")
+    assert ep.cleanup_path is None
+    assert ep.bind_addr == "\x00nxscli.sock"
+
+    ep = _parse_endpoint("unix:///tmp/nxscli.sock")
+    assert ep.cleanup_path == "/tmp/nxscli.sock"
+    assert ep.bind_addr == "/tmp/nxscli.sock"
+
+    ep = _parse_endpoint("/tmp/nxscli2.sock")
+    assert ep.cleanup_path == "/tmp/nxscli2.sock"
+    assert ep.connect_addr == "/tmp/nxscli2.sock"
+
+    with pytest.raises(ValueError):
+        _parse_endpoint("unix://")
+    with pytest.raises(ValueError):
+        _parse_endpoint("unix-abstract://")
+
+
+def test_control_server_parse_unix_endpoints_without_af_unix(monkeypatch):
+    monkeypatch.delattr(socket, "AF_UNIX", raising=False)
+    with pytest.raises(ValueError, match="not supported"):
+        _parse_endpoint("unix:///tmp/nxscli.sock")
+    with pytest.raises(ValueError, match="not supported"):
+        _parse_endpoint("unix-abstract://nxscli")
+    with pytest.raises(ValueError, match="not supported"):
+        _parse_endpoint("/tmp/nxscli2.sock")
+
+
+def test_control_server_enabled_parse_requires_af_unix(monkeypatch):
+    monkeypatch.delattr(socket, "AF_UNIX", raising=False)
+    with pytest.raises(ValueError, match="not supported"):
+        _parse_endpoint("unix:///tmp/nxscli.sock")
+    with pytest.raises(ValueError, match="not supported"):
+        _parse_endpoint("unix-abstract://nxscli")
+    with pytest.raises(ValueError, match="not supported"):
+        _parse_endpoint("/tmp/nxscli2.sock")
+
+
+def test_control_server_enabled_endpoint_tcp(monkeypatch):
+    monkeypatch.delattr(socket, "AF_UNIX", raising=False)
+    endpoint = _get_test_endpoint("nxscli-test-endpoint")
+    assert endpoint.startswith("tcp://")
+
+
+def test_control_server_enabled_endpoint_unix(monkeypatch):
+    _ensure_af_unix(monkeypatch)
+    endpoint = _get_test_endpoint("nxscli-test-endpoint")
+    assert endpoint == "unix-abstract://nxscli-test-endpoint"
+
+
+def test_control_server_handle_validation_paths():
+    plugin = ControlServerPlugin(_get_test_endpoint("nxscli-test-handle"))
+
+    with pytest.raises(RuntimeError):
+        plugin._handle({"method": "send_user_frame", "params": {}})
+
+    plugin._control = _ControlStub()
+    with pytest.raises(ValueError):
+        plugin._handle({"method": "unknown", "params": {}})
+
+
+def test_control_server_start_noop_when_thread_alive():
+    plugin = ControlServerPlugin(_get_test_endpoint("nxscli-start-alive"))
+
+    class _AliveThread:
+        def is_alive(self):
+            return True
+
+    plugin._thread = _AliveThread()
+    plugin._start()
+    assert isinstance(plugin._thread, _AliveThread)
+
+
+def test_control_server_recv_json_paths():
+    plugin = ControlServerPlugin(_get_test_endpoint("nxscli-test-recv"))
+
+    class _Conn:
+        def __init__(self, chunks):
+            self._chunks = list(chunks)
+
+        def recv(self, size):
+            del size
+            if not self._chunks:
+                return b""
+            return self._chunks.pop(0)
+
+    conn_ok = _Conn([b'{"a":1', b',"b":2}\n'])
+    req = plugin._recv_json(conn_ok)
+    assert req["a"] == 1
+    assert req["b"] == 2
+    assert conn_ok.recv(1) == b""
+
+    with pytest.raises(RuntimeError):
+        plugin._recv_json(_Conn([b""]))
+
+    with pytest.raises(ValueError):
+        plugin._recv_json(_Conn([b"[]\n"]))
+
+
+def test_control_server_serve_loop_timeout_and_exception_paths():
+    plugin = ControlServerPlugin(_get_test_endpoint("nxscli-test-loop"))
+    plugin._control = _ControlStub()
+
+    class _ConnOK:
+        def __enter__(self):
+            return self
+
+        def __exit__(self, *_):
+            return None
+
+        def settimeout(self, timeout):
+            del timeout
+
+    class _Sock:
+        def __init__(self):
+            self.calls = 0
+
+        def accept(self):
+            self.calls += 1
+            if self.calls == 1:
+                raise TimeoutError
+            if self.calls == 2:
+                return (_ConnOK(), None)
+            raise OSError("stop")
+
+    sent = []
+
+    def _recv_json(_):
+        raise RuntimeError("bad request")
+
+    def _send_json(_, resp):
+        sent.append(resp)
+
+    plugin._recv_json = _recv_json
+    plugin._send_json = _send_json
+    plugin._sock = _Sock()
+    plugin._serve_loop()
+    assert sent
+    assert sent[0]["ok"] is False
+
+
+def test_control_server_serve_loop_timeout_to_exit_branch():
+    plugin = ControlServerPlugin(_get_test_endpoint("nxscli-test-loop-exit"))
+    plugin._control = _ControlStub()
+
+    class _Sock:
+        def accept(self):
+            plugin._stop.set()
+            raise TimeoutError
+
+    plugin._sock = _Sock()
+    plugin._serve_loop()
+
+
+def test_control_server_start_stop_unix_cleanup_and_sock_close_error(
+    tmp_path, monkeypatch
+):
+    _ensure_af_unix(monkeypatch)
+    _patch_control_server_start(monkeypatch)
+
+    sock_path = Path(tmp_path) / "ctrl.sock"
+    plugin = ControlServerPlugin(f"unix://{sock_path}")
+    plugin._control = _ControlStub()
+
+    plugin._start()
+    try:
+        assert plugin._thread is not None
+        plugin._start()
+    finally:
+        plugin._stop_server()
+
+    plugin._stop_server()
+
+    plugin._sock = _SockCloseErr(socket.AF_UNIX, socket.SOCK_STREAM)
+    plugin._stop_server()
+
+
+def test_control_client_response_edge_paths(monkeypatch):  # type: ignore
+    class _DummySock:
+        def __init__(self, chunks):
+            self._chunks = list(chunks)
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, *_):
+            return None
+
+        def settimeout(self, timeout):
+            del timeout
+
+        def connect(self, addr):
+            del addr
+
+        def sendall(self, data):
+            del data
+
+        def recv(self, size):
+            del size
+            if not self._chunks:
+                return b""
+            return self._chunks.pop(0)
+
+    assert _DummySock([]).recv(1) == b""
+
+    sock1 = _DummySock([b""])
+    monkeypatch.setattr("socket.socket", lambda *a, **k: sock1)
+    client = ControlClient("tcp://127.0.0.1:55002", timeout=0.1)
+    ret = client.ext_request(1, 2, b"x")
+    assert ret.ok is False
+    assert client.last_error == "empty response"
+
+    sock2 = _DummySock([b'{"ok":false', b',"error":"x"}', b""])
+    monkeypatch.setattr("socket.socket", lambda *a, **k: sock2)
+    ack = client.ext_notify(1, 2, b"x")
+    assert ack.state is False
+    assert ack.retcode == -1
+
+    sock3 = _DummySock([b"not-json\n"])
+    monkeypatch.setattr("socket.socket", lambda *a, **k: sock3)
+    ack = client.send_user_frame(8, b"a")
+    assert ack.state is False
+    assert ack.retcode == -1
+    assert client.last_error is not None
+
+
+def test_fake_nxscope_unregister_missing_returns_false():
+    nxscope = FakeNxscope()
+    assert nxscope.unregister_plugin("missing") is False
diff --git a/tests/test_idata.py b/tests/test_idata.py
index dc9c1c5..65df0f3 100644
--- a/tests/test_idata.py
+++ b/tests/test_idata.py
@@ -1,11 +1,16 @@
 import queue
+from typing import TYPE_CHECKING
 
+import numpy as np
 import pytest  # type: ignore
 from nxslib.dev import DeviceChannel
 
 from nxscli.idata import PluginData, PluginDataCb, PluginQueueData
 from nxscli.trigger import DTriggerConfig, ETriggerType, TriggerHandler
 
+if TYPE_CHECKING:
+    from nxscli.channelref import ChannelRef
+
 g_queue: queue.Queue[list] = queue.Queue()
@@ -40,6 +45,45 @@ def test_pluginqueuedata():
     TriggerHandler.cls_cleanup()
 
 
+def test_pluginqueuedata_queue_get_returns_triggered_block_payload() -> None:
+    class Block:
+        def __init__(self) -> None:
+            self.data = np.array([[1.0, 2.0], [3.0, 4.0]])
+            self.meta = np.array([[10], [20]])
+
+    q = queue.Queue()
+    q.put([Block()])
+    chan = DeviceChannel(0, 2, 2, "chan0")
+    dtc = DTriggerConfig(ETriggerType.ALWAYS_ON)
+    trig = TriggerHandler(0, dtc)
+    qdata = PluginQueueData(q, chan, trig)
+
+    ret = qdata.queue_get(block=False)
+    assert len(ret) == 1
+    assert ret[0].data.shape == (2, 2)
+    assert ret[0].meta.shape == (2, 1)
+    TriggerHandler.cls_cleanup()
+
+
+def test_pluginqueuedata_queue_get_handles_block_without_meta() -> None:
+    class Block:
+        def __init__(self) -> None:
+            self.data = np.array([[1.0], [2.0]])
+            self.meta = None
+
+    q = queue.Queue()
+    q.put([Block()])
+    chan = DeviceChannel(0, 2, 1, "chan0")
+    dtc = DTriggerConfig(ETriggerType.ALWAYS_ON)
+    trig = TriggerHandler(0, dtc)
+    qdata = PluginQueueData(q, chan, trig)
+
+    ret = qdata.queue_get(block=False)
+    assert len(ret) == 1
+    assert ret[0].meta is None
+    TriggerHandler.cls_cleanup()
+
+
 def test_nxsclipdata_init():
     channels = [DeviceChannel(0, 1, 2, "chan0")]
     dtc = DTriggerConfig(ETriggerType.ALWAYS_OFF)
@@ -58,3 +102,66 @@
     assert gdata.qdlist[0].mlen == 0
 
     TriggerHandler.cls_cleanup()
+
+
+def test_nxsclipdata_queue_deinit_unsubscribes_all():
+    unsubscribed: list[queue.Queue[list]] = []
+    queues = [queue.Queue(), queue.Queue(), queue.Queue()]
+
+    def stream_sub(channel: "ChannelRef"):  # noqa: ANN001
+        return queues[channel.physical_id()]
+
+    def stream_unsub(q):  # noqa: ANN001
+        unsubscribed.append(q)
+
+    channels = [
+        DeviceChannel(0, 0, 1, "chan0"),
+        DeviceChannel(1, 0, 1, "chan1"),
+        DeviceChannel(2, 0, 1, "chan2"),
+    ]
+    dtc = DTriggerConfig(ETriggerType.ALWAYS_OFF)
+    trig = [
+        TriggerHandler(0, dtc),
+        TriggerHandler(1, dtc),
+        TriggerHandler(2, dtc),
+    ]
+    cb = PluginDataCb(stream_sub, stream_unsub)
+    pdata = PluginData(channels, trig, cb)
+
+    assert len(pdata.qdlist) == 3
+    pdata._queue_deinit()
+
+    assert len(pdata.qdlist) == 0
+    assert unsubscribed == queues
+    TriggerHandler.cls_cleanup()
+
+
+def test_nxsclipdata_virtual_channel_invalid_name_raises() -> None:
+    channels = [DeviceChannel(-2, 0, 1, "virt_bad_name")]
+    dtc = DTriggerConfig(ETriggerType.ALWAYS_OFF)
+    trig = [TriggerHandler(-2, dtc)]
+    cb = PluginDataCb(dummy_stream_sub, dummy_stream_unsub)
+
+    with pytest.raises(ValueError):
+        _ = PluginData(channels, trig, cb)
+
+    TriggerHandler.cls_cleanup()
+
+
+def test_nxsclipdata_virtual_channel_name_is_supported() -> None:
+    channels = [DeviceChannel(-2, 0, 1, "v2")]
+    dtc = DTriggerConfig(ETriggerType.ALWAYS_OFF)
+    trig = [TriggerHandler(-2, dtc)]
+
+    got = {"name": ""}
+
+    def stream_sub(channel: "ChannelRef"):  # noqa: ANN001
+        got["name"] = channel.virtual_name()
+        return g_queue
+
+    cb = PluginDataCb(stream_sub, dummy_stream_unsub)
+    pdata = PluginData(channels, trig, cb)
+    assert
got["name"] == "v2" + pdata._queue_deinit() + + TriggerHandler.cls_cleanup() diff --git a/tests/test_iplugin.py b/tests/test_iplugin.py index 069dafd..e9aeaa9 100644 --- a/tests/test_iplugin.py +++ b/tests/test_iplugin.py @@ -48,3 +48,6 @@ def test_nxscliplugin_init(): # at default plugins dont need to wait assert p1.wait_for_plugin() is True + + # at default plugins return None for inputhook + assert p1.get_inputhook() is None diff --git a/tests/test_istream.py b/tests/test_istream.py new file mode 100644 index 0000000..55cb142 --- /dev/null +++ b/tests/test_istream.py @@ -0,0 +1,6 @@ +from nxscli.istream import IServiceRegistry, IStreamProvider + + +def test_istream_protocols_import() -> None: + assert IStreamProvider is not None + assert IServiceRegistry is not None diff --git a/tests/test_phandler.py b/tests/test_phandler.py index 05b4c36..3997305 100644 --- a/tests/test_phandler.py +++ b/tests/test_phandler.py @@ -1,8 +1,12 @@ +import queue +from types import SimpleNamespace + +import numpy as np import pytest # type: ignore -from nxslib.intf.dummy import DummyDev -from nxslib.nxscope import NxscopeHandler -from nxslib.proto.parse import Parser +from nxslib.dev import DeviceChannel +from nxslib.nxscope import DNxscopeStreamBlock +from nxscli.channelref import ChannelRef from nxscli.iplugin import ( DPluginDescription, EPluginType, @@ -12,7 +16,9 @@ IPluginText, ) from nxscli.phandler import PluginHandler +from nxscli.plugins.none import PluginNone from nxscli.trigger import DTriggerConfigReq +from tests.fake_nxscope import FakeNxscope class MockPlugin1(IPlugin): @@ -107,42 +113,37 @@ def result(self): # pragma: no cover def test_phandler_init(): # no plugins at the beginning - p = PluginHandler([]) - assert isinstance(p, PluginHandler) - assert p.names == [] - p.cleanup() + with PluginHandler([]) as p: + assert isinstance(p, PluginHandler) + assert p.names == [] # valid data plugins = [ DPluginDescription("plugin1", MockPlugin1), DPluginDescription("plugin2", 
MockPlugin2), ] - p = PluginHandler(plugins=plugins) - assert isinstance(p, PluginHandler) - assert p.names == ["plugin1", "plugin2"] - assert p.plugin_get("plugin1") == MockPlugin1 - assert p.plugin_get("plugin2") == MockPlugin2 - assert p.names == ["plugin1", "plugin2"] - - # add plugin = valid data - p.plugin_add(("plugin3", MockPlugin3)) - assert p.names == ["plugin1", "plugin2", "plugin3"] + with PluginHandler(plugins=plugins) as p: + assert isinstance(p, PluginHandler) + assert p.names == ["plugin1", "plugin2"] + assert p.plugin_get("plugin1") == MockPlugin1 + assert p.plugin_get("plugin2") == MockPlugin2 + assert p.names == ["plugin1", "plugin2"] - # plugins instances - assert isinstance(p.plugins["plugin1"](), MockPlugin1) - assert isinstance(p.plugins["plugin2"](), MockPlugin2) - assert isinstance(p.plugins["plugin3"](), MockPlugin3) + # add plugin = valid data + p.plugin_add(("plugin3", MockPlugin3)) + assert p.names == ["plugin1", "plugin2", "plugin3"] - # clean up - p.cleanup() + # plugins instances + assert isinstance(p.plugins["plugin1"](), MockPlugin1) + assert isinstance(p.plugins["plugin2"](), MockPlugin2) + assert isinstance(p.plugins["plugin3"](), MockPlugin3) @pytest.fixture def nxscope(): - intf = DummyDev() - parse = Parser() - nxscope = NxscopeHandler(intf, parse) - return nxscope + nxscope = FakeNxscope() + yield nxscope + nxscope.disconnect() def test_phandler_connect(nxscope): @@ -150,31 +151,67 @@ def test_phandler_connect(nxscope): DPluginDescription("plugin1", MockPlugin1), DPluginDescription("plugin2", MockPlugin2), ] - p = PluginHandler(plugins=plugins) + with PluginHandler(plugins=plugins) as p: + # nxs not connected + with pytest.raises(AssertionError): + _ = p.dev + with pytest.raises(AssertionError): + _ = p.stream_start() + with pytest.raises(AssertionError): + _ = p.stream_stop() + with pytest.raises(AssertionError): + _ = p.channels_configure([]) + + # connect nxslib instance + p.nxscope_connect(nxscope) + + # nxscope should be 
connected + assert p.dev is not None + + # chanlist + p.channels_configure([]) + p.channels_configure([-1], 1) + p.channels_configure([1, 2], [1, 2], writenow=True) + + +def test_phandler_nxscope_property(nxscope): + """Test nxscope property access.""" + plugins = [ + DPluginDescription("plugin1", MockPlugin1), + ] + with PluginHandler(plugins=plugins) as p: + # Test assertion when not connected + with pytest.raises(AssertionError): + _ = p.nxscope + + # Connect and test successful access + p.nxscope_connect(nxscope) + assert p.nxscope is not None + assert p.nxscope == nxscope - # nxs not connected - with pytest.raises(AssertionError): - _ = p.dev - with pytest.raises(AssertionError): - _ = p.stream_start() - with pytest.raises(AssertionError): - _ = p.stream_stop() - with pytest.raises(AssertionError): - _ = p.channels_configure([]) - # connect nxslib instance - p.nxscope_connect(nxscope) +def test_phandler_nxscope_status_interfaces(nxscope): + """Test status and capabilities interfaces from PluginHandler.""" + plugins = [DPluginDescription("plugin1", MockPlugin1)] + with PluginHandler(plugins=plugins) as p: + p.nxscope_connect(nxscope) - # nxscope should be connected - assert p.dev is not None + caps = p.get_device_capabilities() + assert caps.chmax > 0 - # chanlist - p.channels_configure([]) - p.channels_configure([-1], 1) - p.channels_configure([1, 2], [1, 2], writenow=True) + enabled = p.get_enabled_channels() + assert enabled == () - # clean up - p.cleanup() + dividers = p.get_channel_dividers() + assert len(dividers) == caps.chmax + + state = p.get_channels_state() + assert state.enabled_channels == () + assert len(state.dividers) == caps.chmax + + stats = p.get_stream_stats() + assert stats.connected is True + assert stats.stream_started is False def test_phandler_enable(): @@ -182,75 +219,71 @@ def test_phandler_enable(): DPluginDescription("plugin1", MockPlugin1), DPluginDescription("plugin2", MockPlugin2), ] - p = PluginHandler(plugins=plugins) - - # no 
plugins enabled at default - assert len(p.enabled) == 0 - - pid1 = p.enable("plugin2", **{}) - assert len(p.enabled) == 1 - assert pid1 == 0 - assert p.enabled[0][0] == 0 - assert p.enabled[0][1] == MockPlugin2 - assert p.enabled[0][2] == {} - - pid2 = p.enable("plugin1", **{"arg1": 1}) - assert len(p.enabled) == 2 - assert pid2 == 1 - assert p.enabled[0][0] == 0 - assert p.enabled[0][1] == MockPlugin2 - assert p.enabled[0][2] == {} - - assert p.enabled[0][2] == {} - assert p.enabled[1][0] == 1 - assert p.enabled[1][1] == MockPlugin1 - assert p.enabled[1][2]["arg1"] == 1 - - # we can enable plugins multiple times - pid3 = p.enable("plugin1", **{"arg1": "test"}) - assert len(p.enabled) == 3 - assert pid3 == 2 - assert p.enabled[0][0] == 0 - assert p.enabled[0][1] == MockPlugin2 - assert p.enabled[0][2] == {} - - assert p.enabled[1][0] == 1 - assert p.enabled[1][1] == MockPlugin1 - assert p.enabled[1][2]["arg1"] == 1 - - assert p.enabled[2][0] == 2 - assert p.enabled[2][1] == MockPlugin1 - assert p.enabled[2][2]["arg1"] == "test" - - # disable plugin1 - p.disable(pid1) - assert len(p.enabled) == 2 - assert p.enabled[0][0] == 1 - assert p.enabled[1][0] == 2 - - # once again disable plugin1 - with pytest.raises(AttributeError): + with PluginHandler(plugins=plugins) as p: + # no plugins enabled at default + assert len(p.enabled) == 0 + + pid1 = p.enable("plugin2", **{}) + assert len(p.enabled) == 1 + assert pid1 == 0 + assert p.enabled[0][0] == 0 + assert p.enabled[0][1] == MockPlugin2 + assert p.enabled[0][2] == {} + + pid2 = p.enable("plugin1", **{"arg1": 1}) + assert len(p.enabled) == 2 + assert pid2 == 1 + assert p.enabled[0][0] == 0 + assert p.enabled[0][1] == MockPlugin2 + assert p.enabled[0][2] == {} + + assert p.enabled[0][2] == {} + assert p.enabled[1][0] == 1 + assert p.enabled[1][1] == MockPlugin1 + assert p.enabled[1][2]["arg1"] == 1 + + # we can enable plugins multiple times + pid3 = p.enable("plugin1", **{"arg1": "test"}) + assert len(p.enabled) == 3 + 
assert pid3 == 2 + assert p.enabled[0][0] == 0 + assert p.enabled[0][1] == MockPlugin2 + assert p.enabled[0][2] == {} + + assert p.enabled[1][0] == 1 + assert p.enabled[1][1] == MockPlugin1 + assert p.enabled[1][2]["arg1"] == 1 + + assert p.enabled[2][0] == 2 + assert p.enabled[2][1] == MockPlugin1 + assert p.enabled[2][2]["arg1"] == "test" + + # disable plugin1 p.disable(pid1) + assert len(p.enabled) == 2 + assert p.enabled[0][0] == 1 + assert p.enabled[1][0] == 2 - # disable plugin2 - p.disable(pid2) - assert len(p.enabled) == 1 - assert p.enabled[0][0] == 2 + # once again disable plugin1 + with pytest.raises(AttributeError): + p.disable(pid1) - # once again disable plugin2 - with pytest.raises(AttributeError): + # disable plugin2 p.disable(pid2) + assert len(p.enabled) == 1 + assert p.enabled[0][0] == 2 - # disable plugin3 - p.disable(pid3) - assert len(p.enabled) == 0 + # once again disable plugin2 + with pytest.raises(AttributeError): + p.disable(pid2) - # once again disable plugin3 - with pytest.raises(AttributeError): + # disable plugin3 p.disable(pid3) + assert len(p.enabled) == 0 - # clean up - p.cleanup() + # once again disable plugin3 + with pytest.raises(AttributeError): + p.disable(pid3) def test_phandler_start_ready(nxscope): @@ -259,31 +292,28 @@ def test_phandler_start_ready(nxscope): DPluginDescription("plugin2", MockPlugin2), DPluginDescription("plugin3", MockPlugin3), ] - p = PluginHandler(plugins=plugins) - p.nxscope_connect(nxscope) - - # enable all plugins - p.enable("plugin1", **{}) - p.enable("plugin2", **{}) - p.enable("plugin3", **{}) - assert len(p.enabled) == 3 + with PluginHandler(plugins=plugins) as p: + p.nxscope_connect(nxscope) - # start plugins - p.start() + # enable all plugins + p.enable("plugin1", **{}) + p.enable("plugin2", **{}) + p.enable("plugin3", **{}) + assert len(p.enabled) == 3 - ret = p.ready() - assert ret[0].result() == "1" - assert ret[1].result() == "2" - assert ret[2].result() is None + # start plugins + p.start() 
- # plugins not need to wait - assert p.wait_for_plugins() is None + ret = p.ready() + assert ret[0].result() == "1" + assert ret[1].result() == "2" + assert ret[2].result() is None - # stop plugins - p.stop() + # plugins not need to wait + assert p.wait_for_plugins() is None - # clean up - p.cleanup() + # stop plugins + p.stop() def test_phandler_start_poll(nxscope): @@ -292,138 +322,607 @@ def test_phandler_start_poll(nxscope): DPluginDescription("plugin2", MockPlugin2), DPluginDescription("plugin3", MockPlugin3), ] - p = PluginHandler(plugins=plugins) - p.nxscope_connect(nxscope) + with PluginHandler(plugins=plugins) as p: + p.nxscope_connect(nxscope) - # enable all plugins - p.enable("plugin1", **{}) - p.enable("plugin2", **{}) - p.enable("plugin3", **{}) - assert len(p.enabled) == 3 + # enable all plugins + p.enable("plugin1", **{}) + p.enable("plugin2", **{}) + p.enable("plugin3", **{}) + assert len(p.enabled) == 3 - # start plugins - p.start() + # start plugins + p.start() - # poll - ret = p.poll() - assert ret[0].result() == "1" - assert ret[1].result() == "2" - assert ret[2].result() is None + # poll + ret = p.poll() + assert ret[0].result() == "1" + assert ret[1].result() == "2" + assert ret[2].result() is None - # poll once again but all handled - ret = p.poll() - assert ret is None + # poll once again but all handled + ret = p.poll() + assert ret is None - # plugins not need to wait - assert p.wait_for_plugins() is None + # plugins not need to wait + assert p.wait_for_plugins() is None - # stop plugins - p.stop() - - # clean up - p.cleanup() + # stop plugins + p.stop() def test_phandler_start_nostream(nxscope): plugins = [DPluginDescription("plugin1", MockPlugin1)] - p = PluginHandler(plugins=plugins) - p.nxscope_connect(nxscope) - - # enable all plugins - p.enable("plugin1", **{}) - assert len(p.enabled) == 1 + with PluginHandler(plugins=plugins) as p: + p.nxscope_connect(nxscope) - # start plugins - p.start() + # enable all plugins + 
p.enable("plugin1", **{}) + assert len(p.enabled) == 1 - ret = p.ready() - assert ret[0].result() == "1" + # start plugins + p.start() - # plugins not need to wait - assert p.wait_for_plugins() is None + ret = p.ready() + assert ret[0].result() == "1" - # stop plugins - p.stop() + # plugins not need to wait + assert p.wait_for_plugins() is None - # clean up - p.cleanup() + # stop plugins + p.stop() def test_phandler_start_noready(nxscope): plugins = [DPluginDescription("plugin4", MockPlugin4)] - p = PluginHandler(plugins=plugins) - p.nxscope_connect(nxscope) - - # enable all plugins - p.enable("plugin4", **{}) - assert len(p.enabled) == 1 + with PluginHandler(plugins=plugins) as p: + p.nxscope_connect(nxscope) - # start plugins - p.start() + # enable all plugins + p.enable("plugin4", **{}) + assert len(p.enabled) == 1 - # always not ready - ret = p.ready() - assert ret == [] + # start plugins + p.start() - # always not ready - ret = p.poll() - assert ret == [] + # always not ready + ret = p.ready() + assert ret == [] - # plugins not need to wait - assert p.wait_for_plugins() is None + # always not ready + ret = p.poll() + assert ret == [] - # stop plugins - p.stop() + # plugins not need to wait + assert p.wait_for_plugins() is None - # clean up - p.cleanup() + # stop plugins + p.stop() def test_phandler_trigger(): - p = PluginHandler() - - # default all on - dt = p.trigger_get(0) - assert dt.ttype == "on" - assert dt.srcchan is None - assert dt.params is None - dt = p.trigger_get(1) - assert dt.ttype == "on" - assert dt.srcchan is None - assert dt.params is None - - trg = {-1: DTriggerConfigReq("off", None)} - p.triggers_configure(trg) - dt = p.trigger_get(0) - assert dt.ttype == "off" - assert dt.srcchan is None - assert dt.params is None - dt = p.trigger_get(1) - assert dt.ttype == "off" - assert dt.srcchan is None - assert dt.params is None - - trg = {-1: DTriggerConfigReq("on", None)} - p.triggers_configure(trg) - dt = p.trigger_get(0) - assert dt.ttype == "on" 
- assert dt.srcchan is None - assert dt.params is None - dt = p.trigger_get(1) - assert dt.ttype == "on" - assert dt.srcchan is None - assert dt.params is None - - trg = {0: DTriggerConfigReq("on", None), 1: DTriggerConfigReq("off", None)} - p.triggers_configure(trg) - dt = p.trigger_get(0) - assert dt.ttype == "on" - assert dt.srcchan is None - assert dt.params is None - dt = p.trigger_get(1) - assert dt.ttype == "off" - assert dt.srcchan is None - assert dt.params is None - - # clean up - p.cleanup() + with PluginHandler() as p: + # default all on + dt = p.trigger_get(0) + assert dt.ttype == "on" + assert dt.srcchan is None + assert dt.params is None + dt = p.trigger_get(1) + assert dt.ttype == "on" + assert dt.srcchan is None + assert dt.params is None + + trg = {-1: DTriggerConfigReq("off", None)} + p.triggers_configure(trg) + dt = p.trigger_get(0) + assert dt.ttype == "off" + assert dt.srcchan is None + assert dt.params is None + dt = p.trigger_get(1) + assert dt.ttype == "off" + assert dt.srcchan is None + assert dt.params is None + + trg = {-1: DTriggerConfigReq("on", None)} + p.triggers_configure(trg) + dt = p.trigger_get(0) + assert dt.ttype == "on" + assert dt.srcchan is None + assert dt.params is None + dt = p.trigger_get(1) + assert dt.ttype == "on" + assert dt.srcchan is None + assert dt.params is None + + trg = { + 0: DTriggerConfigReq("on", None), + 1: DTriggerConfigReq("off", None), + } + p.triggers_configure(trg) + dt = p.trigger_get(0) + assert dt.ttype == "on" + assert dt.srcchan is None + assert dt.params is None + dt = p.trigger_get(1) + assert dt.ttype == "off" + assert dt.srcchan is None + assert dt.params is None + + +def test_phandler_collect_inputhooks(): # noqa: C901 + """Test collect_inputhooks method.""" + + class PluginWithHook(IPlugin): # pragma: no cover + def __init__(self): + super().__init__(EPluginType.ANIMATION) + + @property + def stream(self) -> bool: + return True + + def stop(self) -> None: + pass + + def data_wait(self, 
timeout=None) -> bool: + return True + + def start(self, kwargs) -> bool: + return True + + def result(self): + return None + + @classmethod + def get_inputhook(cls): + def hook(context): + pass + + return hook + + class PluginWithoutHook(IPlugin): # pragma: no cover + def __init__(self): + super().__init__(EPluginType.TEXT) + + @property + def stream(self) -> bool: + return False + + def stop(self) -> None: + pass + + def data_wait(self, timeout=None) -> bool: + return True + + def start(self, kwargs) -> bool: + return True + + def result(self): + return None + + # Create plugin handler with plugins + plugins = [ + DPluginDescription("with_hook", PluginWithHook), + DPluginDescription("without_hook", PluginWithoutHook), + ] + with PluginHandler(plugins) as p: + # Collect inputhooks + hooks = p.collect_inputhooks() + + # Should find one hook (from PluginWithHook) + assert len(hooks) == 1 + assert callable(hooks[0]) + + +def test_phandler_plugin_start_stop_dynamic(nxscope): + """Test plugin_start_dynamic and plugin_stop_dynamic methods.""" + with nxscope: + plugins = [DPluginDescription("plugin1", MockPlugin1)] + with PluginHandler(plugins=plugins) as p: + p.nxscope_connect(nxscope) + + # Start plugin dynamically (no chanlist configured) + pid = p.plugin_start_dynamic("plugin1", channels=[0, 1]) + assert pid == 0 + assert len(p._started) == 1 + + # Get started plugins + started = p.get_started_plugins() + assert len(started) == 1 + assert started[0] == (0, "plugin1") + + # Stop plugin dynamically + p.plugin_stop_dynamic(pid) + assert len(p._started) == 0 + + # Test invalid PID + pytest.raises(IndexError, p.plugin_stop_dynamic, 99) + + +def test_phandler_plugin_get_instance(nxscope): + """Test plugin_get_instance returns instance or None.""" + with nxscope: + plugins = [DPluginDescription("plugin1", MockPlugin1)] + with PluginHandler(plugins=plugins) as p: + p.nxscope_connect(nxscope) + + # Out-of-range before any plugin started + assert p.plugin_get_instance(-1) is 
None + assert p.plugin_get_instance(0) is None + + pid = p.plugin_start_dynamic("plugin1", channels=[0]) + assert pid == 0 + + # Valid pid returns the IPlugin instance + instance = p.plugin_get_instance(pid) + assert isinstance(instance, MockPlugin1) + + # Out-of-range pid returns None + assert p.plugin_get_instance(99) is None + + p.plugin_stop_dynamic(pid) + + +def test_phandler_get_started_plugins_unregistered(): + """Test get_started_plugins with unregistered plugin class.""" + plugins = [DPluginDescription("plugin1", MockPlugin1)] + with PluginHandler(plugins=plugins) as p: + # Manually create a plugin instance not in _plugins + plugin = MockPlugin2() + p._started.append((plugin, {})) + + # Should fall back to class name + started = p.get_started_plugins() + assert len(started) == 1 + assert started[0] == (0, "MockPlugin2") + + +def test_phandler_plugin_start_dynamic_all_channels(nxscope): + """Test plugin_start_dynamic with -1 (all channels).""" + with nxscope: + plugins = [DPluginDescription("plugin1", MockPlugin1)] + with PluginHandler(plugins=plugins) as p: + p.nxscope_connect(nxscope) + + # Start plugin with -1 to select all channels + pid = p.plugin_start_dynamic("plugin1", channels=[-1]) + assert pid == 0 + + # Stop plugin + p.plugin_stop_dynamic(pid) + + +def test_phandler_plugin_start_dynamic_plot_plugin(nxscope): + """Test plugin_start_dynamic for plot plugin branch.""" + with nxscope: + plugins = [DPluginDescription("plugin3", MockPlugin3)] + with PluginHandler(plugins=plugins) as p: + p.nxscope_connect(nxscope) + + pid = p.plugin_start_dynamic("plugin3", channels=[0]) + assert pid == 0 + assert len(p._started) == 1 + + p.plugin_stop_dynamic(pid) + + +def test_phandler_chanlist_plugin_dynamic_mode(nxscope): + """Test chanlist_plugin in dynamic mode (no chanlist configured).""" + with nxscope: + with PluginHandler() as p: + p.nxscope_connect(nxscope) + + # Don't configure channels - this puts us in dynamic mode + # _chanlist should be empty + + # 
Test with -1 (all channels) + chanlist = p.chanlist_plugin([ChannelRef.all_channels()]) + assert len(chanlist) > 0 + assert all(ch.data.is_valid for ch in chanlist) + + # Test with specific channels + chanlist = p.chanlist_plugin( + [ + ChannelRef.physical(0), + ChannelRef.physical(1), + ChannelRef.physical(2), + ] + ) + assert len(chanlist) <= 3 + assert all(ch.data.is_valid for ch in chanlist) + + # Test with a channel that might not exist (high channel ID) + # This tests the branch where ch might be None or not valid + chanlist = p.chanlist_plugin( + [ChannelRef.physical(0), ChannelRef.physical(999)] + ) + assert len(chanlist) <= 2 + + +class _MockProvider: + def __init__(self) -> None: + self.connected = False + self.started = False + self.channels = {} + self.subs = [] + + def on_connect(self, nxscope) -> None: + del nxscope + self.connected = True + + def on_disconnect(self) -> None: + self.connected = False + + def on_stream_start(self) -> None: + self.started = True + + def on_stream_stop(self) -> None: + self.started = False + + def channel_get(self, channel): + if channel.is_virtual: + return self.channels.get(channel.virtual_name()) + return None + + def channel_list(self): + return tuple(self.channels.values()) + + def stream_sub(self, channel): + if not channel.is_virtual: + return None + chan = channel.virtual_name() + if chan not in self.channels: + return None + q = queue.Queue() + self.subs.append(q) + return q + + def stream_unsub(self, subq) -> bool: + if subq in self.subs: + self.subs.remove(subq) + return True + return False + + +def test_phandler_stream_provider(nxscope): + with PluginHandler() as p: + provider = _MockProvider() + provider.channels["v0"] = DeviceChannel(-2, 10, 1, "v0") + provider.channels["vinvalid"] = DeviceChannel(-4, 0, 1, "vinvalid") + + p.stream_provider_add(provider) + p.service_set("k", "v") + assert p.service_get("k") == "v" + assert p.channel_get(ChannelRef.virtual(0)) is not None + + p.nxscope_connect(nxscope) + 
provider2 = _MockProvider() + provider2.channels["v1"] = DeviceChannel(-3, 10, 1, "v1") + p.stream_provider_add(provider2) + assert p.channel_get(ChannelRef.virtual(1)) is not None + assert provider.connected is True + all_channels = p.chanlist_plugin([ChannelRef.all_channels()]) + assert any(ch.data.chan == -2 for ch in all_channels) + p.channels_configure([ChannelRef.virtual(0)], div=1, writenow=True) + p.channels_configure([ChannelRef.virtual(0)], div=[1], writenow=True) + p._chanlist_enable() + p.stream_start() + assert provider.started is True + + sub = p.stream_sub(ChannelRef.virtual(0)) + assert isinstance(sub, queue.Queue) + p.stream_unsub(sub) + + p.stream_stop() + assert provider.started is False + + +def test_phandler_stream_unsub_fallback(nxscope): + with PluginHandler() as p: + provider = _MockProvider() + p.stream_provider_add(provider) + p.nxscope_connect(nxscope) + subq = p.stream_sub(ChannelRef.physical(0)) + p.stream_unsub(subq) + + +def test_phandler_provider_channel_get_fallbacks(nxscope): + with PluginHandler() as p: + p.nxscope_connect(nxscope) + + provider1 = _MockProvider() + provider2 = _MockProvider() + provider2.channels["v0"] = DeviceChannel(-2, 10, 1, "v0") + + p.stream_provider_add(provider1) + p.stream_provider_add(provider2) + + assert p.channel_get(ChannelRef.virtual(0)) is not None + assert p.channel_get(ChannelRef.virtual(42)) is None + + +def test_phandler_enable_and_div_skip_virtual(nxscope): + with PluginHandler() as p: + p.nxscope_connect(nxscope) + phys = p.channel_get(ChannelRef.physical(0)) + assert phys is not None + virt = DeviceChannel(-2, 10, 1, "v0") + p._chanlist = [virt, phys] + + p._chanlist_enable() + enabled = p.get_enabled_channels(applied=False) + assert 0 in enabled + + p._chanlist_div(1) + p._chanlist_div([0, 1]) + assert p.get_channel_divider(0, applied=False) == 1 + + +def test_phandler_channel_ref_parser_branches() -> None: + with PluginHandler() as p: + assert p._channel_ref(-1).is_all + assert 
p._channel_ref("2").physical_id() == 2 + assert p._channel_ref("v7").virtual_name() == "v7" + assert p._channel_refs(None, default_all=False) == [] + + with pytest.raises(ValueError): + p._channel_ref("vA") + + with pytest.raises(ValueError): + p._channel_ref("bad") + + +def test_phandler_stream_sub_nonphysical_raises(nxscope) -> None: + with PluginHandler() as p: + p.nxscope_connect(nxscope) + + with pytest.raises(ValueError): + p.stream_sub(ChannelRef.virtual(99)) + + +def test_phandler_enable_div_skip_nonexistent_channels(nxscope) -> None: + with PluginHandler() as p: + p.nxscope_connect(nxscope) + + missing = DeviceChannel(999, 10, 1, "missing") + p._chanlist = [missing] + p._chanlist_enable() + p._chanlist_div(1) + p._chanlist_div([1]) + + +def test_phandler_chanlist_plugin_virtual_in_configured_mode(nxscope) -> None: + with PluginHandler() as p: + p.nxscope_connect(nxscope) + + provider = _MockProvider() + provider.channels["v0"] = DeviceChannel(-2, 10, 1, "v0") + p.stream_provider_add(provider) + p.channels_configure([ChannelRef.physical(0)], div=0, writenow=False) + + chanlist = p.chanlist_plugin([ChannelRef.virtual(0)]) + assert any(ch.data.chan == -2 for ch in chanlist) + + +def test_phandler_chanlist_plugin_virtual_multi_refs(nxscope) -> None: + with PluginHandler() as p: + p.nxscope_connect(nxscope) + + provider = _MockProvider() + provider.channels["v0"] = DeviceChannel(-2, 10, 1, "v0") + p.stream_provider_add(provider) + p.channels_configure([ChannelRef.physical(0)], div=0, writenow=False) + + chanlist = p.chanlist_plugin( + [ChannelRef.virtual(42), ChannelRef.virtual(0)] + ) + assert any(ch.data.chan == -2 for ch in chanlist) + + +def test_mock_provider_non_virtual_paths() -> None: + provider = _MockProvider() + assert provider.channel_get(ChannelRef.physical(0)) is None + assert provider.stream_sub(ChannelRef.physical(0)) is None + assert provider.stream_sub(ChannelRef.virtual(0)) is None + assert provider.stream_unsub(queue.Queue()) is False + + 
+def test_phandler_chanlist_plugin_mixed_refs(nxscope) -> None:
+    with PluginHandler() as p:
+        p.nxscope_connect(nxscope)
+
+        provider = _MockProvider()
+        provider.channels["v0"] = DeviceChannel(-2, 10, 1, "v0")
+        p.stream_provider_add(provider)
+        p.channels_configure([ChannelRef.physical(0)], div=0, writenow=False)
+
+        chanlist = p.chanlist_plugin(
+            [ChannelRef.physical(0), ChannelRef.virtual(0)]
+        )
+        assert any(ch.data.chan == -2 for ch in chanlist)
+        assert any(ch.data.chan == 0 for ch in chanlist)
+
+
+def test_phandler_stream_unsub_without_nxscope() -> None:
+    with PluginHandler() as p:
+        p.stream_unsub(queue.Queue())
+
+
+def test_phandler_stream_fallback_to_nxscope(nxscope) -> None:
+    with PluginHandler() as p:
+        # Force direct Nxscope fallback path.
+        p._providers = []
+        p.nxscope_connect(nxscope)
+        subq = p.stream_sub(ChannelRef.physical(0))
+        p.stream_unsub(subq)
+
+
+def test_phandler_chanlist_plugin_all_skips_missing_channel(nxscope) -> None:
+    with PluginHandler() as p:
+        p.nxscope_connect(nxscope)
+        dev_channel_get = nxscope.dev_channel_get
+
+        def wrapped(chid: int):
+            if chid == 1:
+                return None
+            return dev_channel_get(chid)
+
+        nxscope.dev_channel_get = wrapped
+
+        chanlist = p.chanlist_plugin([ChannelRef.all_channels()])
+        assert all(ch.data.chan != 1 for ch in chanlist)
+
+
+def test_pluginthread_is_done_partial() -> None:
+    plugin = PluginNone()
+    plugin._samples = 2
+    plugin._nostop = False
+    assert plugin._is_done([1]) is False
+
+
+def test_pluginthread_common_not_done_path() -> None:
+    class _QD:
+        def queue_get(self, block, timeout=1.0):
+            del block, timeout
+            return [
+                DNxscopeStreamBlock(
+                    data=np.array([[1.0]], dtype=float),
+                    meta=np.array([[0]], dtype=np.uint32),
+                )
+            ]
+
+    plugin = PluginNone()
+    plugin._samples = 2
+    plugin._nostop = False
+    plugin._datalen = [0]
+    plugin._plugindata = SimpleNamespace(qdlist=[_QD()])
+
+    plugin._thread_common()
+    assert plugin._datalen == [1]
+
+
+def test_pluginnone_handle_blocks_empty_and_done_path() -> None:
+    pdata = type("Q", (), {"vdim": 1})()
+    plugin = PluginNone()
+    plugin._samples = 1
+    plugin._nostop = False
+    plugin._datalen = [0]
+
+    empty_block = DNxscopeStreamBlock(data=np.empty((0, 1)), meta=None)
+    plugin._handle_blocks([empty_block], pdata, 0)
+    assert plugin._datalen == [0]
+
+    full_block = DNxscopeStreamBlock(data=np.array([[1.0]]), meta=None)
+    plugin._datalen = [1]
+    plugin._handle_blocks([full_block], pdata, 0)
+    assert plugin._datalen == [1]
+
+    plugin._nostop = True
+    plugin._datalen = [0]
+    plugin._handle_blocks([full_block], pdata, 0)
+    assert plugin._datalen == [1]
+
+
+def test_iplugin_get_plot_handler_default():
+    """Test that IPlugin.get_plot_handler() returns None by default."""
+    plugin = MockPlugin1()
+    assert plugin.get_plot_handler() is None
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
new file mode 100644
index 0000000..cafafad
--- /dev/null
+++ b/tests/test_pipeline.py
@@ -0,0 +1,120 @@
+import numpy as np
+
+from nxscli.transforms.operators_window import (
+    fft_spectrum,
+    histogram_counts,
+    polar_relation,
+    xy_relation,
+)
+from nxscli.transforms.pipeline import (
+    HopGate,
+    SampleStore,
+    TransformPipeline,
+    WindowBinaryProcessor,
+    WindowUnaryProcessor,
+)
+
+
+def test_sample_store_max_points() -> None:
+    store = SampleStore(max_points=3)
+    store.ingest({"a": [1.0, 2.0, 3.0]})
+    store.ingest({"a": [4.0, 5.0]})
+    assert store.count("a") == 5
+    assert store.series("a").tolist() == [3.0, 4.0, 5.0]
+    assert store.count("missing") == 0
+    assert store.series("missing").tolist() == []
+
+
+def test_pipeline_fanout_same_source_fft_and_hist() -> None:
+    pipe = TransformPipeline(max_points=64)
+    pipe.register(
+        WindowUnaryProcessor(
+            name="fft",
+            channel="a",
+            window=8,
+            hop=4,
+            fn=lambda arr: fft_spectrum(arr, window_fn="rect"),
+        )
+    )
+    pipe.register(
+        WindowUnaryProcessor(
+            name="hist",
+            channel="a",
+            window=8,
+            hop=4,
+            fn=lambda arr: histogram_counts(arr, bins=4, range_mode="auto"),
+        )
+    )
+
+    out1 = pipe.ingest({"a": [0.0, 1.0, 0.0, -1.0]})
+    assert set(out1) == {"fft", "hist"}
+    assert int(out1["fft"].freq.size) > 0
+    assert int(out1["hist"].counts.sum()) == 4
+
+    out2 = pipe.ingest({"a": [0.0, 1.0]})
+    assert out2 == {}
+
+    out3 = pipe.ingest({"a": [0.0, -1.0]})
+    assert set(out3) == {"fft", "hist"}
+
+
+def test_pipeline_binary_xy_fanout() -> None:
+    pipe = TransformPipeline(max_points=16)
+    pipe.register(
+        WindowBinaryProcessor(
+            name="xy",
+            left_channel="x",
+            right_channel="y",
+            window=8,
+            hop=2,
+            fn=lambda x, y: xy_relation(x, y, window=8),
+        )
+    )
+
+    out1 = pipe.ingest({"x": [1.0, 2.0], "y": [3.0, 4.0]})
+    assert set(out1) == {"xy"}
+    assert out1["xy"].x.tolist() == [1.0, 2.0]
+    assert out1["xy"].y.tolist() == [3.0, 4.0]
+
+    out2 = pipe.ingest({"x": [3.0], "y": [5.0]})
+    assert out2 == {}
+
+    out3 = pipe.ingest({"x": [4.0], "y": [6.0]})
+    assert set(out3) == {"xy"}
+    assert np.allclose(out3["xy"].x, np.asarray([1.0, 2.0, 3.0, 4.0]))
+
+
+def test_pipeline_binary_xy_and_polar_fanout() -> None:
+    pipe = TransformPipeline(max_points=16)
+    pipe.register(
+        WindowBinaryProcessor(
+            name="xy",
+            left_channel="x",
+            right_channel="y",
+            window=8,
+            hop=2,
+            fn=lambda x, y: xy_relation(x, y, window=8),
+        )
+    )
+    pipe.register(
+        WindowBinaryProcessor(
+            name="polar",
+            left_channel="x",
+            right_channel="y",
+            window=8,
+            hop=2,
+            fn=lambda x, y: polar_relation(x, y, window=8),
+        )
+    )
+
+    out = pipe.ingest({"x": [1.0, 0.0], "y": [0.0, 1.0]})
+    assert set(out) == {"xy", "polar"}
+    assert np.allclose(out["polar"].theta, np.asarray([0.0, np.pi / 2.0]))
+    assert np.allclose(out["polar"].radius, np.asarray([1.0, 1.0]))
+
+
+def test_pipeline_store_property_and_hop_gate_zero_count() -> None:
+    pipe = TransformPipeline(max_points=2)
+    assert isinstance(pipe.store, SampleStore)
+    gate = HopGate(hop=2)
+    assert gate.ready(0) is False
diff --git a/tests/test_pluginthr.py b/tests/test_pluginthr.py
new file mode 100644
index 0000000..96429f2
--- /dev/null
+++ b/tests/test_pluginthr.py
@@ -0,0 +1,124 @@
+import numpy as np
+import pytest
+from nxslib.nxscope import DNxscopeStreamBlock
+
+from nxscli.pluginthr import PluginThread
+
+
+class _ThreadStub:
+    def __init__(self) -> None:
+        self.stopped = False
+
+    def stop_set(self) -> None:
+        self.stopped = True
+
+
+class _PluginThreadImpl(PluginThread):
+    def __init__(self) -> None:
+        super().__init__()
+        self.handled: list[object] = []
+
+    def _handle_blocks(self, data, pdata, j) -> None:  # noqa: ANN001
+        self.handled.extend(data)
+        self._datalen[j] += len(data)
+
+    def _init(self) -> None:  # pragma: no cover
+        return
+
+    def _final(self) -> None:  # pragma: no cover
+        return
+
+
+def test_pluginthread_block_payload_uses_queue_get() -> None:
+    class PData:
+        def queue_get(
+            self, block: bool = True, timeout: float = 1.0
+        ):  # noqa: ANN001, ARG002
+            return [DNxscopeStreamBlock(data=np.array([[1.0]]), meta=None)]
+
+    plug = _PluginThreadImpl()
+    plug._thread = _ThreadStub()
+    plug._plugindata = type("PD", (), {"qdlist": [PData()]})()
+    plug._samples = 1
+    plug._nostop = False
+    plug._datalen = [0]
+
+    plug._thread_common()
+
+    assert len(plug.handled) == 1
+    assert plug._datalen == [1]
+    assert plug._thread.stopped is True
+
+
+def test_pluginthread_block_payload_without_converter_uses_queue_get() -> None:
+    class PData:
+        def queue_get(
+            self, block: bool = True, timeout: float = 1.0
+        ):  # noqa: ANN001, ARG002
+            return [DNxscopeStreamBlock(data=np.array([[1.0]]), meta=None)]
+
+    plug = _PluginThreadImpl()
+    plug._thread = _ThreadStub()
+    plug._plugindata = type("PD", (), {"qdlist": [PData()]})()
+    plug._samples = 1
+    plug._nostop = False
+    plug._datalen = [0]
+
+    plug._thread_common()
+
+    assert len(plug.handled) == 1
+    assert plug._datalen == [1]
+    assert plug._thread.stopped is True
+
+
+def test_pluginthread_block_rows_handles_none_meta() -> None:
+    class _PData:
+        pass
+
+    plug = _PluginThreadImpl()
+    rows = list(
+        plug._block_rows(
+            [DNxscopeStreamBlock(data=np.array([[1.0], [2.0]]), meta=None)],
+            _PData(),
+            0,
+        )
+    )
+    assert rows == [((1.0,), ()), ((2.0,), ())]
+
+
+def test_pluginthread_non_block_payload_raises() -> None:
+    class PData:
+        def queue_get(
+            self, block: bool = True, timeout: float = 1.0
+        ):  # noqa: ANN001, ARG002
+            return [{"data": [1.0]}]
+
+    plug = _PluginThreadImpl()
+    plug._thread = _ThreadStub()
+    plug._plugindata = type("PD", (), {"qdlist": [PData()]})()
+    plug._samples = 1
+    plug._nostop = False
+    plug._datalen = [0]
+
+    with pytest.raises(RuntimeError):
+        plug._thread_common()
+
+
+def test_pluginthread_empty_payload_is_ignored() -> None:
+    class PData:
+        def queue_get(
+            self, block: bool = True, timeout: float = 1.0
+        ):  # noqa: ANN001, ARG002
+            return []
+
+    plug = _PluginThreadImpl()
+    plug._thread = _ThreadStub()
+    plug._plugindata = type("PD", (), {"qdlist": [PData()]})()
+    plug._samples = 1
+    plug._nostop = False
+    plug._datalen = [0]
+
+    plug._thread_common()
+
+    assert plug.handled == []
+    assert plug._datalen == [0]
diff --git a/tests/test_stream_hub.py b/tests/test_stream_hub.py
new file mode 100644
index 0000000..db2666e
--- /dev/null
+++ b/tests/test_stream_hub.py
@@ -0,0 +1,165 @@
+import queue
+from types import SimpleNamespace
+
+import numpy as np
+from nxslib.dev import DeviceChannel
+from nxslib.nxscope import DNxscopeStreamBlock
+
+from nxscli.channelref import ChannelRef
+from nxscli.stream_hub import SharedStreamProvider
+
+
+class _FakeNxscopeHub:
+    def __init__(self) -> None:
+        self._channels = {
+            0: DeviceChannel(0, 10, 1, "ch0"),
+            1: DeviceChannel(1, 10, 1, "ch1"),
+        }
+        self.dev = SimpleNamespace(data=SimpleNamespace(chmax=2))
+        self.sub_calls: dict[int, int] = {}
+        self.unsub_calls = 0
+        self.source_queues: dict[int, queue.Queue] = {}
+
+    def dev_channel_get(self, chid: int):
+        return self._channels.get(chid)
+
+    def stream_sub(self, chid: int):
+        self.sub_calls[chid] = self.sub_calls.get(chid, 0) + 1
+        q = queue.Queue()
+        self.source_queues[chid] = q
+        return q
+
+    def stream_unsub(self, subq) -> None:
+        self.unsub_calls += 1
+        for chid in list(self.source_queues.keys()):
+            if self.source_queues[chid] is subq:
+                del self.source_queues[chid]
+                return
+
+
+def _block(value: float) -> DNxscopeStreamBlock:
+    return DNxscopeStreamBlock(
+        data=np.asarray([[value]], dtype=np.float64),
+        meta=None,
+    )
+
+
+def test_stream_hub_fanout_single_upstream_subscription() -> None:
+    hub = SharedStreamProvider()
+    fake = _FakeNxscopeHub()
+    hub.on_connect(fake)
+
+    sub1 = hub.stream_sub(ChannelRef.physical(0))
+    sub2 = hub.stream_sub(ChannelRef.physical(0))
+    assert sub1 is not None
+    assert sub2 is not None
+
+    hub.on_stream_start()
+    assert fake.sub_calls[0] == 1
+
+    fake.source_queues[0].put([_block(1.0)])
+    hub._thread_common()
+    out1 = sub1.get(block=True, timeout=0.2)
+    out2 = sub2.get(block=True, timeout=0.2)
+    assert float(out1[0].data[0, 0]) == 1.0
+    assert float(out2[0].data[0, 0]) == 1.0
+
+    assert hub.stream_unsub(sub1) is True
+    assert hub.stream_unsub(sub2) is True
+    assert fake.unsub_calls >= 1
+    hub.on_stream_stop()
+    hub.on_disconnect()
+
+
+def test_stream_hub_subscribe_after_start_and_stop_paths() -> None:
+    hub = SharedStreamProvider()
+    fake = _FakeNxscopeHub()
+
+    hub.on_stream_start()  # no nxscope
+    hub._thread_common()  # not started
+
+    hub.on_connect(fake)
+    hub.on_stream_start()  # no subscribers
+    hub._thread_common()  # started but no sources
+
+    sub = hub.stream_sub(ChannelRef.physical(1))
+    assert sub is not None
+    assert fake.sub_calls[1] == 1
+
+    # empty source queue path
+    hub._thread_common()
+
+    # empty blocks path
+    fake.source_queues[1].put([])
+    hub._thread_common()
+
+    assert hub.stream_unsub(queue.Queue()) is False
+    assert hub.stream_unsub(sub) is True
+    hub.on_stream_stop()
+    hub.on_stream_stop()  # already stopped
+
+
+def test_stream_hub_nonphysical_and_channel_paths() -> None:
+    hub = SharedStreamProvider()
+    fake = _FakeNxscopeHub()
+
+    assert hub.channel_get(ChannelRef.physical(0)) is None
+    assert hub.stream_sub(ChannelRef.virtual(9)) is None
+    assert hub.channel_list() == ()
+
+    hub.on_connect(fake)
+    assert hub.channel_get(ChannelRef.virtual(0)) is None
+    ch0 = hub.channel_get(ChannelRef.physical(0))
+    assert ch0 is not None
+    assert ch0.data.chan == 0
+
+    hub.on_disconnect()
+    assert hub.stream_unsub(queue.Queue()) is False
+
+
+def test_stream_hub_connect_disconnect_reset() -> None:
+    hub = SharedStreamProvider()
+    fake = _FakeNxscopeHub()
+    hub.on_connect(fake)
+    sub = hub.stream_sub(ChannelRef.physical(0))
+    assert sub is not None
+    hub.on_stream_start()
+    hub.on_stream_start()  # already started
+    hub.on_disconnect()
+    # disconnected provider should not find previous subscription
+    assert hub.stream_unsub(sub) is False
+
+
+def test_stream_hub_internal_branch_paths() -> None:
+    hub = SharedStreamProvider()
+    fake = _FakeNxscopeHub()
+
+    # _ensure_source_sub_locked with no nxscope
+    hub._ensure_source_sub_locked(0)
+
+    hub.on_connect(fake)
+    hub._ensure_source_sub_locked(0)
+    # repeated ensure must not resubscribe same channel
+    hub._ensure_source_sub_locked(0)
+    assert fake.sub_calls[0] == 1
+
+    # mapped queue missing from subscriber list
+    sub = hub.stream_sub(ChannelRef.physical(0))
+    assert sub is not None
+    hub._subscribers[0].clear()
+    assert hub.stream_unsub(sub) is True
+
+    # started with no nxscope path in stop
+    hub._started = True
+    hub._nxscope = None
+    hub._source_subs = {0: queue.Queue()}
+    hub.on_stream_stop()
+
+
+def test_fake_nxscopehub_stream_unsub_paths() -> None:
+    fake = _FakeNxscopeHub()
+    q0 = fake.stream_sub(0)
+    fake.stream_unsub(queue.Queue())
+    assert 0 in fake.source_queues
+    fake.stream_unsub(q0)
+    assert 0 not in fake.source_queues
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
new file mode 100644
index 0000000..f54eca3
--- /dev/null
+++ b/tests/test_transforms.py
@@ -0,0 +1,233 @@
+import numpy as np
+import pytest
+
+from nxscli.transforms.models import WindowConfig, WindowCursor
+from nxscli.transforms.operators_sample import apply_scale_offset, binary_op
+from nxscli.transforms.operators_window import (
+    fft_spectrum,
+    histogram_counts,
+    polar_relation,
+    windowed_fft,
+    windowed_histogram,
+    windowed_polar,
+    windowed_xy,
+    xy_relation,
+)
+from nxscli.transforms.window_engine import (
+    latest_window,
+    normalize_window_config,
+    should_recompute,
+)
+
+
+def test_apply_scale_offset() -> None:
+    arr = apply_scale_offset([1.0, 2.0, 3.0], scale=2.0, offset=-1.0)
+    assert arr.tolist() == [1.0, 3.0, 5.0]
+
+
+def test_binary_op_truncates_shorter_input() -> None:
+    arr = binary_op([1.0, 2.0, 3.0], [10.0, 20.0], np.add)
+    assert arr.tolist() == [11.0, 22.0]
+
+
+def test_binary_op_empty_result() -> None:
+    arr = binary_op([], [1.0], np.add)
+    assert arr.tolist() == []
+
+
+def test_fft_spectrum_nonempty() -> None:
+    res = fft_spectrum([0.0, 1.0, 0.0, -1.0], window_fn="rect")
+    assert int(res.freq.size) > 0
+    assert int(res.amplitude.size) == int(res.freq.size)
+    hamming = fft_spectrum([0.0, 1.0, 0.0, -1.0], window_fn="hamming")
+    blackman = fft_spectrum([0.0, 1.0, 0.0, -1.0], window_fn="blackman")
+    assert int(hamming.freq.size) > 0
+    assert int(blackman.freq.size) > 0
+
+
+def test_fft_spectrum_empty() -> None:
+    res = fft_spectrum([1.0])
+    assert res.freq.tolist() == []
+    assert res.amplitude.tolist() == []
+
+
+def test_histogram_counts_preserves_count() -> None:
+    src = [0.0, 0.1, 0.2, 0.7, 0.9]
+    res = histogram_counts(src, bins=3, range_mode="auto")
+    assert int(res.counts.sum()) == len(src)
+    assert int(res.edges.size) == 4
+
+
+def test_histogram_counts_fixed_and_empty_and_error() -> None:
+    fixed = histogram_counts(
+        [0.0, 0.5], bins=2, range_mode="fixed", value_range=(0.0, 1.0)
+    )
+    assert int(fixed.counts.sum()) == 2
+    empty = histogram_counts([], bins=2)
+    assert empty.counts.tolist() == []
+    with pytest.raises(ValueError):
+        histogram_counts([0.0], bins=2, range_mode="fixed")
+
+
+def test_xy_relation_truncate() -> None:
+    res = xy_relation([1.0, 2.0, 3.0], [10.0, 20.0], window=64)
+    assert res.x.tolist() == [2.0, 3.0]
+    assert res.y.tolist() == [10.0, 20.0]
+
+
+def test_xy_relation_errors_and_empty() -> None:
+    res = xy_relation([], [], window=8)
+    assert res.x.tolist() == []
+    assert res.y.tolist() == []
+    with pytest.raises(ValueError):
+        xy_relation([1.0], [1.0], window=2, align_policy="pad")
+
+
+def test_polar_relation_converts_xy() -> None:
+    res = polar_relation([1.0, 0.0], [0.0, 1.0], window=16)
+    assert np.allclose(res.theta, np.asarray([0.0, np.pi / 2.0]))
+    assert np.allclose(res.radius, np.asarray([1.0, 1.0]))
+
+
+def test_polar_relation_empty_input() -> None:
+    res = polar_relation([], [], window=4)
+    assert res.theta.tolist() == []
+    assert res.radius.tolist() == []
+
+
+def test_windowed_polar_hop_gating() -> None:
+    cursor = WindowCursor()
+    assert (
+        windowed_polar(
+            [1.0, 2.0],
+            [3.0, 4.0],
+            window=16,
+            hop=2,
+            align_policy="truncate",
+            cursor=cursor,
+        )
+        is not None
+    )
+    assert (
+        windowed_polar(
+            [1.0, 2.0, 3.0],
+            [3.0, 4.0, 5.0],
+            window=16,
+            hop=2,
+            align_policy="truncate",
+            cursor=cursor,
+        )
+        is None
+    )
+
+
+def test_windowed_fft_hop_gating() -> None:
+    cursor = WindowCursor()
+    assert windowed_fft([1.0, 2.0], window=4, hop=2, cursor=cursor) is not None
+    assert (
+        windowed_fft([1.0, 2.0, 3.0], window=4, hop=2, cursor=cursor) is None
+    )
+    assert (
+        windowed_fft([1.0, 2.0, 3.0, 4.0], window=4, hop=2, cursor=cursor)
+        is not None
+    )
+
+
+def test_windowed_histogram_hop_gating() -> None:
+    cursor = WindowCursor()
+    assert (
+        windowed_histogram(
+            [0.0, 1.0],
+            window=4,
+            hop=3,
+            bins=2,
+            range_mode="auto",
+            cursor=cursor,
+        )
+        is not None
+    )
+    assert (
+        windowed_histogram(
+            [0.0, 1.0, 2.0],
+            window=4,
+            hop=3,
+            bins=2,
+            range_mode="auto",
+            cursor=cursor,
+        )
+        is None
+    )
+
+
+def test_windowed_xy_hop_gating() -> None:
+    cursor = WindowCursor()
+    assert (
+        windowed_xy(
+            [1.0, 2.0],
+            [3.0, 4.0],
+            window=16,
+            hop=2,
+            align_policy="truncate",
+            cursor=cursor,
+        )
+        is not None
+    )
+    assert (
+        windowed_xy(
+            [1.0, 2.0, 3.0],
+            [3.0, 4.0, 5.0],
+            window=16,
+            hop=2,
+            align_policy="truncate",
+            cursor=cursor,
+        )
+        is None
+    )
+
+
+def test_windowed_total_count_and_window_helpers() -> None:
+    cfg = normalize_window_config(1, None)
+    assert cfg.window >= 2
+    assert cfg.hop >= 1
+    cfg2 = normalize_window_config(8, 0)
+    assert cfg2.hop == 2
+    cfg3 = normalize_window_config(8, 3)
+    assert cfg3.hop == 3
+
+    cursor = WindowCursor()
+    assert should_recompute(0, WindowConfig(window=4, hop=2), cursor) is False
+    assert should_recompute(2, WindowConfig(window=4, hop=2), cursor) is True
+    assert should_recompute(3, WindowConfig(window=4, hop=2), cursor) is False
+
+    arr1 = latest_window([1.0, 2.0], WindowConfig(window=8, hop=2))
+    assert arr1.tolist() == [1.0, 2.0]
+    arr2 = latest_window([1.0, 2.0, 3.0], WindowConfig(window=2, hop=1))
+    assert arr2.tolist() == [2.0, 3.0]
+
+    cursor2 = WindowCursor()
+    assert (
+        windowed_xy(
+            [1.0, 2.0],
+            [3.0, 4.0],
+            window=2,
+            hop=2,
+            align_policy="truncate",
+            cursor=cursor2,
+            total_count=2,
+        )
+        is not None
+    )
+
+    cursor3 = WindowCursor()
+    assert (
+        windowed_polar(
+            [1.0, 2.0],
+            [3.0, 4.0],
+            window=2,
+            hop=2,
+            align_policy="truncate",
+            cursor=cursor3,
+            total_count=2,
+        )
+        is not None
+    )
diff --git a/tests/test_trigger.py b/tests/test_trigger.py
index 860122d..838342c 100644
--- a/tests/test_trigger.py
+++ b/tests/test_trigger.py
@@ -1,7 +1,8 @@
 from threading import Lock
 
+import numpy as np
 import pytest  # type: ignore
-from nxslib.nxscope import DNxscopeStream
+from nxslib.nxscope import DNxscopeStream, DNxscopeStreamBlock
 
 from nxscli.trigger import (
     DTriggerConfig,
@@ -981,6 +982,69 @@ def test_triggerhandle_chanxtochany_hoffset():
     TriggerHandler.cls_cleanup()
 
 
+def test_triggerhandler_block_helpers_and_cache_paths() -> None:
+    with global_lock:
+        dtc = DTriggerConfig(ETriggerType.EDGE_RISING, hoffset=2, level=5.0)
+        th = TriggerHandler(0, dtc)
+
+        assert th._combined_vector([], 0) == []
+        assert th._edgerising([], 0, 0.0).state is False
+        assert th._edgefalling([], 0, 0.0).state is False
+        assert th._slice_from([], 1) == []
+
+        block = DNxscopeStreamBlock(data=np.array([[0.0], [2.0]]), meta=None)
+        out = th.data_triggered([block])
+        assert out == []
+        assert th._cache
+
+        sliced = th._slice_from([block], 1)
+        assert len(sliced) == 1
+        assert sliced[0].data.shape[0] == 1
+        assert sliced[0].meta is None
+
+        tail = th._cache_tail([block], 1)
+        assert len(tail) == 1
+        assert tail[0].data.shape[0] == 1
+
+        assert th._cache_tail([block], 0) == [block]
+        assert th._cache_tail([], 1) == []
+        assert th._cache_tail([DNxscopeStream((1,), ())], 1) == [
+            DNxscopeStream((1,), ())
+        ]
+
+        block0 = DNxscopeStreamBlock(data=np.array([[0.0], [1.0]]), meta=None)
+        block1 = DNxscopeStreamBlock(data=np.array([[2.0]]), meta=None)
+        assert th._slice_from([block0, block1], 2) == [block1]
+
+        scalar_block = DNxscopeStreamBlock(data=np.array([[7.0]]), meta=None)
+        assert th._combined_vector([scalar_block], 0) == [7.0]
+
+        list_block = DNxscopeStreamBlock(
+            data=np.array([[1.0], [2.0]]), meta=None
+        )
+        assert th._combined_vector([list_block], 0) == [1.0, 2.0]
+
+        concat0 = DNxscopeStreamBlock(data=np.array([[0.0]]), meta=None)
+        concat1 = DNxscopeStreamBlock(data=np.array([[1.0]]), meta=None)
+        assert th._combined_vector([concat0, concat1], 0) == [0.0, 1.0]
+
+        TriggerHandler.cls_cleanup()
+
+
+def test_triggerhandler_block_cache_hoffset_zero_keeps_current_batch() -> None:
+    with global_lock:
+        dtc = DTriggerConfig(ETriggerType.EDGE_RISING, hoffset=0, level=5.0)
+        th = TriggerHandler(0, dtc)
+        block = DNxscopeStreamBlock(data=np.array([[0.0], [2.0]]), meta=None)
+        payload = [block]
+
+        out = th.data_triggered(payload)
+
+        assert out == []
+        assert th._cache is payload
+        TriggerHandler.cls_cleanup()
+
+
 def test_triggerhandle_edgerising_hoffset():
     # TODO
     pass
diff --git a/tests/virtual/test_command.py b/tests/virtual/test_command.py
new file mode 100644
index 0000000..a02e461
--- /dev/null
+++ b/tests/virtual/test_command.py
@@ -0,0 +1,145 @@
+"""Tests for virtual channel CLI command."""
+
+from click.testing import CliRunner
+
+from nxscli.channelref import ChannelRef
+from nxscli.cli.environment import Environment
+from nxscli.commands.config.cmd_vadd import cmd_vadd
+
+
+class _FakeRuntime:
+    def __init__(self) -> None:
+        self.calls = []
+
+    def add_virtual_channel(self, **kwargs):
+        self.calls.append(kwargs)
+        return [("v0", "v0")]
+
+
+def test_cmd_vadd(monkeypatch) -> None:
+    runtime = _FakeRuntime()
+    monkeypatch.setattr(
+        "nxscli.commands.config.cmd_vadd.get_runtime",
+        lambda _p: runtime,
+    )
+
+    env = Environment()
+    env.phandler = object()
+    runner = CliRunner()
+    result = runner.invoke(
+        cmd_vadd,
+        ["0", "0", "--operator", "scale_offset"],
+        obj=env,
+    )
+    assert result.exit_code == 0
+    assert "channel v0" in result.output
+    assert runtime.calls[0]["channel_id"] == 0
+
+
+def test_cmd_vadd_requires_inputs(monkeypatch) -> None:
+    runtime = _FakeRuntime()
+    monkeypatch.setattr(
+        "nxscli.commands.config.cmd_vadd.get_runtime",
+        lambda _p: runtime,
+    )
+    env = Environment()
+    env.phandler = object()
+    runner = CliRunner()
+    result = runner.invoke(cmd_vadd, ["1"], obj=env)
+    assert result.exit_code != 0
+
+
+def test_cmd_vadd_sets_channels_from_physical_inputs(monkeypatch) -> None:
+    runtime = _FakeRuntime()
+    monkeypatch.setattr(
+        "nxscli.commands.config.cmd_vadd.get_runtime",
+        lambda _p: runtime,
+    )
+    env = Environment()
+    env.phandler = object()
+    runner = CliRunner()
+    result = runner.invoke(
+        cmd_vadd,
+        ["100", "0", "--operator", "scale_offset"],
+        obj=env,
+    )
+    assert result.exit_code == 0
+    assert env.channels == ([ChannelRef.physical(0)], 0)
+
+
+def test_cmd_vadd_merges_required_sources(monkeypatch) -> None:
+    runtime = _FakeRuntime()
+    monkeypatch.setattr(
+        "nxscli.commands.config.cmd_vadd.get_runtime",
+        lambda _p: runtime,
+    )
+    env = Environment()
+    env.phandler = object()
+    env.channels = ([ChannelRef.physical(2)], 0)
+    runner = CliRunner()
+    result = runner.invoke(
+        cmd_vadd,
+        ["100", "0", "--operator", "scale_offset"],
+        obj=env,
+    )
+    assert result.exit_code == 0
+    assert env.channels == (
+        [ChannelRef.physical(2), ChannelRef.physical(0)],
+        0,
+    )
+
+
+def test_cmd_vadd_keeps_all_selector(monkeypatch) -> None:
+    runtime = _FakeRuntime()
+    monkeypatch.setattr(
+        "nxscli.commands.config.cmd_vadd.get_runtime",
+        lambda _p: runtime,
+    )
+    env = Environment()
+    env.phandler = object()
+    env.channels = ([ChannelRef.all_channels()], 0)
+    runner = CliRunner()
+    result = runner.invoke(
+        cmd_vadd,
+        ["100", "0", "--operator", "scale_offset"],
+        obj=env,
+    )
+    assert result.exit_code == 0
+    assert env.channels == ([ChannelRef.all_channels()], 0)
+
+
+def test_cmd_vadd_no_required_physical_sources(monkeypatch) -> None:
+    runtime = _FakeRuntime()
+    monkeypatch.setattr(
+        "nxscli.commands.config.cmd_vadd.get_runtime",
+        lambda _p: runtime,
+    )
+    env = Environment()
+    env.phandler = object()
+    runner = CliRunner()
+    result = runner.invoke(
+        cmd_vadd,
+        ["100", "v0", "--operator", "scale_offset"],
+        obj=env,
+    )
+    assert result.exit_code == 0
+    assert env.channels is None
+
+
+def test_cmd_vadd_does_not_duplicate_required_source(monkeypatch) -> None:
+    runtime = _FakeRuntime()
+    monkeypatch.setattr(
+        "nxscli.commands.config.cmd_vadd.get_runtime",
+        lambda _p: runtime,
+    )
+    env = Environment()
+    env.phandler = object()
+    env.channels = ([ChannelRef.physical(0)], 0)
+    runner = CliRunner()
+    result = runner.invoke(
+        cmd_vadd,
+        ["100", "0", "--operator", "scale_offset"],
+        obj=env,
+    )
+    assert result.exit_code == 0
+    assert env.channels == ([ChannelRef.physical(0)], 0)
diff --git a/tests/virtual/test_manager.py b/tests/virtual/test_manager.py
new file mode 100644
index 0000000..4b22d11
--- /dev/null
+++ b/tests/virtual/test_manager.py
@@ -0,0 +1,281 @@
+"""Tests for virtual operator graph manager."""
+
+import pytest
+
+from nxscli.virtual.errors import VirtualChannelError
+from nxscli.virtual.manager import VirtualChannelManager
+from nxscli.virtual.models import ChannelSpec, VirtualChannelSpec
+from nxscli.virtual.operators import default_operator_registry
+
+
+def test_default_registry_contains_builtin_ops() -> None:
+    reg = default_operator_registry()
+    assert set(reg) == {
+        "scale_offset",
+        "math_binary",
+        "stats_running",
+    }
+
+
+def test_virtual_manager_pipeline() -> None:
+    mgr = VirtualChannelManager()
+    mgr.add_physical_channel(ChannelSpec("0", "ch0", "float", 1))
+    mgr.add_physical_channel(ChannelSpec("1", "ch1", "float", 1))
+
+    mgr.add_virtual_channel(
+        VirtualChannelSpec(
+            channel_id="v0",
+            name="scaled",
+            operator="scale_offset",
+            inputs=("0",),
+            params={"scale": 2.0, "offset": 1.0},
+        )
+    )
+    mgr.add_virtual_channel(
+        VirtualChannelSpec(
+            channel_id="v1",
+            name="sum",
+            operator="math_binary",
+            inputs=("v0", "1"),
+            params={"op": "add"},
+        )
+    )
+
+    out = mgr.process_sample({"0": (1.0,), "1": (3.0,)})
+    assert out["v0"] == (3.0,)
+    assert out["v1"] == (6.0,)
+    assert mgr.required_physical_channel_ids() == ("0", "1")
+
+
+def test_virtual_manager_errors() -> None:
+    mgr = VirtualChannelManager()
+    mgr.add_physical_channel(ChannelSpec("0", "ch0", "float", 1))
+
+    with pytest.raises(VirtualChannelError):
+        mgr.add_virtual_channel(
+            VirtualChannelSpec(
+                channel_id="v0",
+                name="bad",
+                operator="unknown",
+                inputs=("0",),
+            )
+        )
+
+    mgr.add_virtual_channel(
+        VirtualChannelSpec(
+            channel_id="v0",
+            name="ok",
+            operator="scale_offset",
+            inputs=("0",),
+        )
+    )
+    with pytest.raises(VirtualChannelError):
+        mgr.add_virtual_channel(
+            VirtualChannelSpec(
+                channel_id="v0",
+                name="dup",
+                operator="scale_offset",
+                inputs=("0",),
+            )
+        )
+
+
+def test_virtual_manager_more_error_paths() -> None:
+    mgr = VirtualChannelManager()
+    mgr.add_physical_channel(ChannelSpec("0", "ch0", "float", 1))
+    with pytest.raises(VirtualChannelError):
+        mgr.add_physical_channel(ChannelSpec("0", "dup", "float", 1))
+
+    with pytest.raises(VirtualChannelError):
+        mgr.add_virtual_channel(
+            VirtualChannelSpec(
+                channel_id="v0",
+                name="noin",
+                operator="scale_offset",
+                inputs=(),
+            )
+        )
+    with pytest.raises(VirtualChannelError):
+        mgr.add_virtual_channel(
+            VirtualChannelSpec(
+                channel_id="v0",
+                name="unknown-input",
+                operator="scale_offset",
+                inputs=("9",),
+            )
+        )
+
+    with pytest.raises(VirtualChannelError):
+        mgr.process_update("v0", (1.0,))
+    with pytest.raises(VirtualChannelError):
+        mgr.process_update("0", (1.0, 2.0))
+
+
+def test_virtual_manager_disabled_and_reset() -> None:
+    mgr = VirtualChannelManager()
+    mgr.add_physical_channel(ChannelSpec("0", "ch0", "float", 1))
+    mgr.add_virtual_channel(
+        VirtualChannelSpec(
+            channel_id="v0",
+            name="scaled",
+            operator="scale_offset",
+            inputs=("0",),
+            enabled=False,
+        )
+    )
+    changed = mgr.process_update("0", (1.0,))
+    assert changed == {}
+    out = mgr.process_sample({"0": (1.0,)})
+    assert out == {"0": (1.0,)}
+    assert mgr.channel_spec("0") is not None
+    assert len(mgr.channel_specs()) >= 1
+    assert mgr.physical_channel_ids() == ("0",)
+    mgr.reset()
+
+
+def test_virtual_manager_cycle_detection() -> None:
+    mgr = VirtualChannelManager()
+    mgr.add_physical_channel(ChannelSpec("0", "ch0", "float", 1))
+    mgr.add_virtual_channel(
+        VirtualChannelSpec(
+            channel_id="v0",
+            name="scaled",
+            operator="scale_offset",
+            inputs=("0",),
+        )
+    )
+    # Force cycle in compiled graph and verify detector.
+    compiled = mgr._compiled["v0"]
+    mgr._compiled["v0"] = type(compiled)(
+        spec=VirtualChannelSpec(
+            channel_id="v0",
+            name="scaled",
+            operator="scale_offset",
+            inputs=("v0",),
+        ),
+        outputs=compiled.outputs,
+        output_ids=compiled.output_ids,
+        operator=compiled.operator,
+    )
+    with pytest.raises(VirtualChannelError):
+        mgr._rebuild_order()
+
+
+class _NoOutputs:
+    def configure(self, spec, inputs) -> None:
+        del spec, inputs
+
+    def describe_outputs(self, spec):
+        del spec
+        return ()
+
+    def process(self, inputs):
+        del inputs
+        return ()
+
+    def reset(self) -> None:
+        return
+
+
+class _BadOutLen:
+    def configure(self, spec, inputs) -> None:
+        del spec, inputs
+
+    def describe_outputs(self, spec):
+        return (
+            ChannelSpec(
+                channel_id=spec.channel_id,
+                name=spec.name,
+                dtype="float",
+                vdim=1,
+            ),
+        )
+
+    def process(self, inputs):
+        del inputs
+        return ((1.0,), (2.0,))
+
+    def reset(self) -> None:
+        return
+
+
+class _Collide:
+    def configure(self, spec, inputs) -> None:
+        del spec, inputs
+
+    def describe_outputs(self, spec):
+        del spec
+        return (ChannelSpec("0", "dup", "float", 1),)
+
+    def process(self, inputs):
+        del inputs
+        return ((1.0,),)
+
+    def reset(self) -> None:
+        return
+
+
+def test_virtual_manager_no_outputs_and_collide() -> None:
+    assert _NoOutputs().process(()) == ()
+    _NoOutputs().reset()
+    assert _Collide().process(()) == ((1.0,),)
+    _Collide().reset()
+    mgr = VirtualChannelManager(operators={"no": _NoOutputs, "col": _Collide})
+    mgr.add_physical_channel(ChannelSpec("0", "ch0", "float", 1))
+    with pytest.raises(VirtualChannelError):
+        mgr.add_virtual_channel(VirtualChannelSpec("v0", "x", "no", ("0",)))
+    with pytest.raises(VirtualChannelError):
+        mgr.add_virtual_channel(VirtualChannelSpec("v0", "x", "col", ("0",)))
+
+
+def test_virtual_manager_bad_output_len() -> None:
+    _BadOutLen().reset()
+    mgr2 = VirtualChannelManager(operators={"badlen": _BadOutLen})
+    mgr2.add_physical_channel(ChannelSpec("0", "ch0", "float", 1))
+    mgr2.add_virtual_channel(VirtualChannelSpec("v0", "x", "badlen", ("0",)))
+    with pytest.raises(VirtualChannelError):
+        mgr2.process_update("0", (1.0,))
+    with pytest.raises(VirtualChannelError):
+        mgr2.process_sample({})
+
+
+def test_virtual_manager_process_sample_missing_input() -> None:
+    mgr = VirtualChannelManager()
+    mgr.add_physical_channel(ChannelSpec("0", "ch0", "float", 1))
+    mgr.add_physical_channel(ChannelSpec("1", "ch1", "float", 1))
+    mgr.add_virtual_channel(
+        VirtualChannelSpec(
+            channel_id="v1",
+            name="sum",
+            operator="math_binary",
+            inputs=("0", "1"),
+            params={"op": "add"},
+        )
+    )
+    with pytest.raises(VirtualChannelError):
+        mgr.process_sample({"0": (1.0,)})
+
+
+def test_virtual_manager_process_sample_invalid_outputs_len() -> None:
+    mgr = VirtualChannelManager(operators={"badlen": _BadOutLen})
+    mgr.add_physical_channel(ChannelSpec("0", "ch0", "float", 1))
+    mgr.add_virtual_channel(VirtualChannelSpec("v0", "x", "badlen", ("0",)))
+    with pytest.raises(VirtualChannelError):
+        mgr.process_sample({"0": (1.0,)})
+
+
+def test_virtual_manager_process_sample_missing_runtime_input() -> None:
+    mgr = VirtualChannelManager()
+    mgr.add_physical_channel(ChannelSpec("0", "ch0", "float", 1))
+    mgr.add_virtual_channel(
+        VirtualChannelSpec("v0", "x", "scale_offset", ("0",))
+    )
+    compiled = mgr._compiled["v0"]
+    mgr._compiled["v0"] = type(compiled)(
+        spec=VirtualChannelSpec("v0", "x", "scale_offset", ("ghost",)),
+        outputs=compiled.outputs,
+        output_ids=compiled.output_ids,
+        operator=compiled.operator,
+    )
+    with pytest.raises(VirtualChannelError):
+        mgr.process_sample({"0": (1.0,)})
diff --git a/tests/virtual/test_models.py b/tests/virtual/test_models.py
new file mode 100644
index 0000000..f48ae07
--- /dev/null
+++ b/tests/virtual/test_models.py
@@ -0,0 +1,36 @@
+"""Tests for virtual data models helpers."""
+
+from nxslib.dev import DeviceChannel, EDeviceChannelType
+
+from nxscli.virtual.models import ChannelSpec, to_float
+
+
+def test_channel_spec_from_device_channel() -> None:
+    ch = DeviceChannel(7, int(EDeviceChannelType.FLOAT.value), 2, "temp")
+    spec = ChannelSpec.from_device_channel(ch, data_kind="stats")
+    assert spec.channel_id == "7"
+    assert spec.name == "temp"
+    assert spec.vdim == 2
+    assert spec.data_kind == "stats"
+
+
+def test_channel_spec_dtype_and_channel_id_parsing() -> None:
+    spec = ChannelSpec(channel_id="v0", name="virt", dtype="int", vdim=1)
+    assert spec.dtype == int(EDeviceChannelType.INT32.value)
+    assert spec.device_channel.data.chan == -1
+
+    spec2 = ChannelSpec(channel_id="3", name="x", dtype=9, vdim=1)
+    assert spec2.dtype == 9
+    assert spec2.device_channel.data.chan == 3
+    spec3 = ChannelSpec(channel_id="4", name="y", dtype="unknown", vdim=1)
+    assert spec3.dtype == int(EDeviceChannelType.FLOAT.value)
+    spec4 = ChannelSpec(channel_id="5", name="z", dtype=object(), vdim=1)
+    assert spec4.dtype == int(EDeviceChannelType.FLOAT.value)
+
+
+def test_to_float_paths() -> None:
+    assert to_float(1, 0.0) == 1.0
+    assert to_float(1.5, 0.0) == 1.5
+    assert to_float("2.5", 0.0) == 2.5
+    assert to_float("bad", 7.0) == 7.0
+    assert to_float(object(), 9.0) == 9.0
diff --git a/tests/virtual/test_operators.py b/tests/virtual/test_operators.py
new file mode 100644
index 0000000..ac504a0
--- /dev/null
+++ b/tests/virtual/test_operators.py
@@ -0,0 +1,126 @@
+"""Tests for virtual built-in operators."""
+
+import math
+
+import pytest
+
+from nxscli.virtual.errors import VirtualChannelError
+from nxscli.virtual.models import ChannelSpec, VirtualChannelSpec
+from nxscli.virtual.operators import (
+    MathBinaryOperator,
+    RunningStatsOperator,
+    ScaleOffsetOperator,
+    default_operator_registry,
+)
+
+
+def _spec(
+    channel_id: str,
+    operator: str,
+    *,
+    params: dict[str, object] | None = None,
+) -> VirtualChannelSpec:
+    return VirtualChannelSpec(
+        channel_id=channel_id,
+        name=channel_id,
+        operator=operator,
+        inputs=("0",),
+        params=params or {},
+    )
+
+
+def test_scale_offset_operator() -> None:
+    op = ScaleOffsetOperator()
+    with pytest.raises(VirtualChannelError):
+        op.configure(_spec("v0", "scale_offset"), ())
+    op.configure(
+        _spec("v0", "scale_offset", params={"scale": "2", "offset": "1"}),
+        (ChannelSpec("0", "ch0", "float", 2),),
+    )
+    out = op.describe_outputs(_spec("v0", "scale_offset"))
+    assert out[0].vdim == 2
+    assert op.process(((1.0, 2.0),))[0] == (3.0, 5.0)
+    op.reset()
+
+
+def test_math_binary_operator_paths() -> None:
+    op = MathBinaryOperator()
+    with pytest.raises(VirtualChannelError):
+        op.configure(
+            _spec("v0", "math_binary"),
+            (ChannelSpec("0", "a", "float", 1),),
+        )
+    with pytest.raises(VirtualChannelError):
+        op.configure(
+            _spec("v0", "math_binary"),
+            (
+                ChannelSpec("0", "a", "float", 1),
+                ChannelSpec("1", "b", "float", 2),
+            ),
+        )
+    with pytest.raises(VirtualChannelError):
+        op.configure(
+            _spec("v0", "math_binary", params={"op": "bad"}),
+            (
+                ChannelSpec("0", "a", "float", 1),
+                ChannelSpec("1", "b", "float", 1),
+            ),
+        )
+    op.configure(
+        _spec("v0", "math_binary", params={"op": "sub"}),
+        (
+            ChannelSpec("0", "a", "float", 1),
+            ChannelSpec("1", "b", "float", 1),
+        ),
+    )
+    assert op.process(((4.0,), (1.5,)))[0] == (2.5,)
+    op.reset()
+
+
+def test_running_stats_operator_paths() -> None:
+    op = RunningStatsOperator()
+    with pytest.raises(VirtualChannelError):
+        op.configure(
+            _spec("v0", "stats_running"),
+            (),
+        )
+    with pytest.raises(VirtualChannelError):
+        op.configure(
+            _spec("v0", "stats_running"),
+            (ChannelSpec("0", "a", "float", 0),),
+        )
+    with pytest.raises(VirtualChannelError):
+        op.configure(
+            _spec("v0", "stats_running", params={"bad": 1}),
+            (ChannelSpec("0", "a", "float", 1),),
+        )
+
+    op.configure(
+        _spec("v0", "stats_running"),
+        (ChannelSpec("0", "a", "float", 1),),
+    )
+    outs = op.describe_outputs(_spec("v0", "stats_running"))
+    assert len(outs) == 4
+    minv, maxv, avgv, rmsv = op.process(((2.0,),))
+    assert minv == (2.0,)
+    assert maxv == (2.0,)
+    assert avgv == (2.0,)
+    assert rmsv == (2.0,)
+    minv, maxv, avgv, rmsv = op.process(((4.0,),))
+    assert minv == (2.0,)
+    assert maxv == (4.0,)
+    assert avgv == (3.0,)
+    assert math.isclose(rmsv[0], math.sqrt((4.0 + 16.0) / 2.0))
+    minv, maxv, _, _ = op.process(((1.0,),))
+    assert minv == (1.0,)
+    assert maxv == (4.0,)
+    with pytest.raises(VirtualChannelError):
+        op.process(((1.0, 2.0),))
+    op.reset()
+
+
+def test_default_registry_factories() -> None:
+    reg = default_operator_registry()
+    assert callable(reg["scale_offset"])
+    assert callable(reg["math_binary"])
+    assert callable(reg["stats_running"])
diff --git a/tests/virtual/test_params.py b/tests/virtual/test_params.py
new file mode 100644
index 0000000..4774509
--- /dev/null
+++ b/tests/virtual/test_params.py
@@ -0,0 +1,22 @@
+"""Tests for virtual command parameter parsing."""
+
+import click
+import pytest
+
+from nxscli.commands.config.cmd_vadd import _parse_params
+
+
+def test_parse_params_ok() -> None:
+    parsed = _parse_params(["a=1", "b=1.5", "c=true", "d=x", " "])
+    assert parsed["a"] == 1
+    assert parsed["b"] == 1.5
+    assert parsed["c"] is True
+    assert parsed["d"] == "x"
+
+
+def test_parse_params_invalid() -> None:
+    with pytest.raises(click.BadParameter):
+        _parse_params(["broken"])
+
+    with pytest.raises(click.BadParameter):
+        _parse_params(["=1"])
diff --git a/tests/virtual/test_runtime.py b/tests/virtual/test_runtime.py
new file mode 100644
index 0000000..fdd0df0
--- /dev/null
+++ b/tests/virtual/test_runtime.py
@@ -0,0 +1,350 @@
+"""Tests for shared virtual runtime provider."""
+
+import queue
+
+import numpy as np
+import pytest
+from nxslib.dev import Device, DeviceChannel
+from nxslib.nxscope import DNxscopeStreamBlock
+
+from nxscli.channelref import ChannelRef
+from nxscli.virtual.errors import VirtualChannelError
+from nxscli.virtual.runtime import VirtualStreamRuntime
+from nxscli.virtual.services import get_runtime
+
+
+class _FakeRegistry:
+    def __init__(self) -> None:
+        self._services = {}
self.providers = [] + + def service_get(self, name: str): + return self._services.get(name) + + def service_set(self, name: str, service) -> None: + self._services[name] = service + + def stream_provider_add(self, provider) -> None: + self.providers.append(provider) + + +class _FakeNxscope: + def __init__(self) -> None: + channels = [ + DeviceChannel(0, 10, 1, "ch0"), + DeviceChannel(1, 10, 1, "ch1"), + ] + self.dev = Device(2, 0, 0, channels) + self._channels = {0: channels[0], 1: channels[1]} + self._subs: dict[int, list[queue.Queue[list[DNxscopeStreamBlock]]]] = { + 0: [], + 1: [], + } + + def dev_channel_get(self, chid: int): + return self._channels.get(chid) + + def stream_sub(self, chan: int): + subq = queue.Queue() + self._subs[chan].append(subq) + return subq + + def stream_unsub(self, subq) -> None: + for chan in self._subs: + if subq in self._subs[chan]: + self._subs[chan].remove(subq) + return + + +class _FakeNxscopeSparse(_FakeNxscope): + def dev_channel_get(self, chid: int): + if chid == 1: + return None + return super().dev_channel_get(chid) + + +def test_runtime_shared_provider() -> None: + reg = _FakeRegistry() + rt1 = get_runtime(reg) + rt2 = get_runtime(reg) + assert rt1 is rt2 + assert len(reg.providers) == 1 + + +def test_runtime_streams_virtual_blocks() -> None: + runtime = VirtualStreamRuntime() + fake = _FakeNxscope() + + runtime.add_virtual_channel( + channel_id=0, + name="v0", + operator="math_binary", + inputs=("0", "1"), + params={"op": "add"}, + ) + + runtime.on_connect(fake) + outq = runtime.stream_sub(ChannelRef.virtual(0)) + assert outq is not None + + runtime.on_stream_start() + + b0 = DNxscopeStreamBlock( + data=np.asarray([[1.0]], dtype=np.float64), meta=None + ) + b1 = DNxscopeStreamBlock( + data=np.asarray([[2.0]], dtype=np.float64), meta=None + ) + fake._subs[0][0].put([b0]) + runtime._thread_common() + fake._subs[1][0].put([b1]) + runtime._thread_common() + + out = outq.get(block=True, timeout=0.2) + assert isinstance(out[0], 
DNxscopeStreamBlock) + assert out[0].data.shape == (1, 1) + assert float(out[0].data[0, 0]) == 3.0 + + runtime.on_stream_stop() + runtime.on_disconnect() + + +def test_runtime_streams_from_1d_blocks() -> None: + runtime = VirtualStreamRuntime() + fake = _FakeNxscope() + + runtime.add_virtual_channel( + channel_id=0, + name="v0", + operator="scale_offset", + inputs=("0",), + params={"scale": 2.0, "offset": 1.0}, + ) + + runtime.on_connect(fake) + outq = runtime.stream_sub(ChannelRef.virtual(0)) + assert outq is not None + runtime.on_stream_start() + + # vdim=1 payload can arrive as 1D array + b0 = DNxscopeStreamBlock( + data=np.asarray([1.0, 2.0, 3.0], dtype=np.float64), + meta=None, + ) + fake._subs[0][0].put([b0]) + runtime._thread_common() + + out = outq.get(block=True, timeout=0.2) + assert isinstance(out[0], DNxscopeStreamBlock) + assert out[0].data.shape == (3, 1) + assert float(out[0].data[0, 0]) == 3.0 + assert float(out[0].data[1, 0]) == 5.0 + assert float(out[0].data[2, 0]) == 7.0 + + runtime.on_stream_stop() + runtime.on_disconnect() + + +def test_runtime_guard_paths_and_clear() -> None: + runtime = VirtualStreamRuntime() + runtime.on_stream_start() # no nxscope + runtime.on_stream_stop() # not started + runtime.on_disconnect() # no connection + runtime.clear() + assert runtime.declared() == () + runtime._thread_common() + + +def test_runtime_alias_and_duplicate_paths() -> None: + runtime = VirtualStreamRuntime() + with pytest.raises(VirtualChannelError): + runtime.add_virtual_channel( + channel_id=-1, + name="bad", + operator="scale_offset", + inputs=("0",), + params={}, + ) + + runtime.add_virtual_channel( + channel_id=10, + name="v10", + operator="scale_offset", + inputs=("0",), + params={}, + ) + with pytest.raises(VirtualChannelError): + runtime.add_virtual_channel( + channel_id=10, + name="dup", + operator="scale_offset", + inputs=("0",), + params={}, + ) + with pytest.raises(VirtualChannelError): + runtime.add_virtual_channel( + channel_id=8, + 
name="dup-alias", + operator="stats_running", + inputs=("0",), + params={}, + ) + with pytest.raises(VirtualChannelError): + runtime._normalize_input_token("vbad") + + +def test_runtime_channel_and_sub_paths() -> None: + runtime = VirtualStreamRuntime() + assert runtime.channel_get(ChannelRef.physical(0)) is None + assert runtime.stream_sub(ChannelRef.physical(0)) is None + assert runtime.stream_unsub(queue.Queue()) is False + assert runtime.channel_list() == () + assert runtime._normalize_input_token("v01") == "v1" + + +def test_runtime_stats_aliases_and_subscribe_unknown() -> None: + runtime = VirtualStreamRuntime() + fake = _FakeNxscope() + runtime.add_virtual_channel( + channel_id=20, + name="stats", + operator="stats_running", + inputs=("0",), + params={}, + ) + runtime.on_connect(fake) + # First alias exists, unknown alias does not. + sub = runtime.stream_sub(ChannelRef.virtual(20)) + assert sub is not None + assert runtime.stream_sub(ChannelRef.virtual(99)) is None + assert runtime.stream_unsub(sub) is True + runtime.on_disconnect() + + +def test_runtime_rebuild_mismatch_and_collect_skip() -> None: + runtime = VirtualStreamRuntime() + fake = _FakeNxscope() + runtime.add_virtual_channel( + channel_id=0, + name="v0", + operator="scale_offset", + inputs=("0",), + params={}, + ) + # Force mismatched declared output ids. + dec = runtime._declared[0] + runtime._declared[0] = type(dec)( + spec=dec.spec, + output_ids=("bad",), + aliases=dec.aliases, + aliased_names=dec.aliased_names, + ) + with pytest.raises(VirtualChannelError): + runtime.on_connect(fake) + + # Restore and connect. 
+ runtime = VirtualStreamRuntime() + runtime.add_virtual_channel( + channel_id=0, + name="v0", + operator="scale_offset", + inputs=("0",), + params={}, + ) + runtime.on_connect(fake) + out = runtime._collect_output_rows( + 0, + [ + DNxscopeStreamBlock( + data=np.asarray([], dtype=np.float64), meta=None + ) + ], + ) + assert out == {} + # channel id mismatch path in collect-update + out2 = runtime._collect_output_rows( + 9, + [ + DNxscopeStreamBlock( + data=np.asarray([[1.0]], dtype=np.float64), meta=None + ) + ], + ) + assert out2 == {} + runtime._output_id_to_alias["v0"] = "missing" + batches = runtime._build_output_blocks({"missing": [(1.0,)]}) + assert batches == {} + runtime._output_id_to_alias["v0"] = "v0" + bad = DNxscopeStreamBlock( + data=np.asarray([["bad"]], dtype=object), meta=None + ) + assert runtime._collect_output_rows(0, [bad]) == {} + runtime._output_id_to_alias.clear() + good = DNxscopeStreamBlock( + data=np.asarray([[1.0]], dtype=np.float64), meta=None + ) + assert runtime._collect_output_rows(0, [good]) == {} + assert runtime._to_sample("bad") is None + assert runtime._normalize_input_token("1") == "1" + runtime.on_disconnect() + + +def test_runtime_start_stop_transitions() -> None: + runtime = VirtualStreamRuntime() + fake = _FakeNxscope() + runtime.on_connect(fake) + runtime.on_stream_start() # no declarations + runtime.add_virtual_channel( + channel_id=0, + name="v0", + operator="scale_offset", + inputs=("0",), + params={}, + ) + runtime.on_stream_start() + runtime.on_stream_start() # already started + runtime.on_stream_stop() + runtime.on_stream_stop() # already stopped + + +def test_runtime_rebuild_skips_missing_device_channels() -> None: + runtime = VirtualStreamRuntime() + fake = _FakeNxscopeSparse() + runtime.add_virtual_channel( + channel_id=0, + name="v0", + operator="scale_offset", + inputs=("0",), + params={}, + ) + runtime.on_connect(fake) + assert runtime.channel_list() + runtime._started = True + runtime._nxscope = None + 
runtime.on_stream_stop() + + +def test_runtime_stream_unsub_branch_paths() -> None: + runtime = VirtualStreamRuntime() + fake = _FakeNxscope() + runtime.add_virtual_channel( + channel_id=0, + name="v0", + operator="scale_offset", + inputs=("0",), + params={}, + ) + runtime.on_connect(fake) + sub1 = runtime.stream_sub(ChannelRef.virtual(0)) + sub2 = runtime.stream_sub(ChannelRef.virtual(0)) + assert sub1 is not None and sub2 is not None + assert runtime.stream_unsub(sub1) is True + # unrelated queue triggers scan-without-match path + assert runtime.stream_unsub(queue.Queue()) is False + assert runtime.stream_unsub(sub2) is True + runtime.on_disconnect() + + +def test_fake_nxscope_stream_unsub_no_match_branch() -> None: + fake = _FakeNxscope() + fake.stream_unsub(queue.Queue()) diff --git a/tox.ini b/tox.ini index f8fb223..20fc139 100644 --- a/tox.ini +++ b/tox.ini @@ -11,6 +11,7 @@ deps = pytest-mock pytest-sugar coverage>=6 + numpy commands = coverage run -m pytest {posargs} # ensure 100% coverage of tests @@ -20,9 +21,7 @@ commands = description = run tests without coverage report (but in parallel) usedevelop=True deps = - pytest - pytest-mock - pytest-sugar + {[testenv]deps} pytest-xdist commands = pytest -n 4 {posargs} @@ -65,6 +64,7 @@ commands = description = run type checks deps = mypy + numpy commands = mypy --strict --pretty {posargs:src}