From 4094a35b14b7317d34481eb79c6c747fe1ee4351 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Tue, 30 Jun 2026 15:30:34 -0700 Subject: [PATCH 1/3] Add unit tests for `stimflow.Flow` - Fix indentantion of `stimflow.Chunk.__repr__` - Make flow centers default to `'infer'` rather than `None`, to distinguish the two desires - Add `skip_verification_before_append=False` to `stimflow.ChunkCompiler.__init__` - Make `verify` calls to objects passed into `stimflow.ChunkCompiler.append`, unless the above property was set to True - Add unit tests and better docstrings to `stimflow.Flow` - BREAKING CHANGE: rename `name=` to `obs_name=` in `stimflow.PauliMap.from_{x,y,z}s` --- dev/doctest_proper.py | 3 +- glue/stimflow/src/stimflow/_chunk/_chunk.py | 9 +- .../src/stimflow/_chunk/_chunk_builder.py | 2 +- .../src/stimflow/_chunk/_chunk_compiler.py | 26 +- glue/stimflow/src/stimflow/_core/_flow.py | 201 +++++++++++-- .../stimflow/src/stimflow/_core/_flow_test.py | 280 ++++++++++++++++++ .../stimflow/src/stimflow/_core/_pauli_map.py | 14 +- 7 files changed, 497 insertions(+), 38 deletions(-) diff --git a/dev/doctest_proper.py b/dev/doctest_proper.py index de9fb8394..040a58894 100755 --- a/dev/doctest_proper.py +++ b/dev/doctest_proper.py @@ -89,6 +89,7 @@ def main(): type=str, help="Objects that don't need an 'examples:' section in their documentation.") args = parser.parse_args() + suppressed = args.suppress_examples_warning_for globs = { k: __import__(k) for k in getattr(args, 'import') @@ -105,7 +106,7 @@ def main(): if '\n' in v.strip() and 'examples:' not in v and 'example:' not in v and '[deprecated]' not in v: if k.split('.')[-1] not in ['__format__', '__next__', '__iter__', '__init_subclass__', '__module__', '__eq__', '__ne__', '__str__', '__repr__']: if all(not (e.startswith('_') and not e.startswith('__')) for e in k.split('.')): - if all(not k.startswith(prefix) for prefix in args.suppress_examples_warning_for): + if all(not k.startswith(prefix) for prefix in suppressed): print(f" Warning: Missing 'examples:' section in docstring of {k!r}", file=sys.stderr) module.__test__ = {k: v for k, v in out.items()} diff --git a/glue/stimflow/src/stimflow/_chunk/_chunk.py b/glue/stimflow/src/stimflow/_chunk/_chunk.py index 43d311930..293b6c37b 100644 --- a/glue/stimflow/src/stimflow/_chunk/_chunk.py +++ b/glue/stimflow/src/stimflow/_chunk/_chunk.py @@ -186,7 +186,7 @@ def _then_reflow(self, other: ChunkReflow) -> Chunk: else: new_flows.append(flow) for out, inputs in other.out2in.items(): - acc = None + acc: Flow | None = None used_outputs.update(inputs) for inp in inputs: if inp in old_discarded_outputs: @@ -316,7 +316,10 @@ def __repr__(self) -> str: lines.append(f" q2i={self.q2i!r},") lines.append(f" circuit={self.circuit!r},".replace("\n", "\n ")) if self.flows: - lines.append(f" flows={self.flows!r},") + lines.append(f" flows=[") + for flow in self.flows: + lines.append(f" {flow!r},".replace('\n', '\n ')) + lines.append(" ],") if self.discarded_inputs: lines.append(f" discarded_inputs={self.discarded_inputs!r},") if self.discarded_outputs: @@ -324,7 +327,7 @@ def __repr__(self) -> str: if self.wants_to_merge_with_prev: lines.append(f" wants_to_merge_with_prev={self.wants_to_merge_with_prev!r},") if self.wants_to_merge_with_next: - lines.append(f" discarded_outputs={self.wants_to_merge_with_next!r},") + lines.append(f" wants_to_merge_with_next={self.wants_to_merge_with_next!r},") lines.append(")") return "\n".join(lines) diff --git a/glue/stimflow/src/stimflow/_chunk/_chunk_builder.py b/glue/stimflow/src/stimflow/_chunk/_chunk_builder.py index baa20badf..61011fbd5 100644 --- a/glue/stimflow/src/stimflow/_chunk/_chunk_builder.py +++ b/glue/stimflow/src/stimflow/_chunk/_chunk_builder.py @@ -391,7 +391,7 @@ def add_flow( end: PauliMap | Tile | Literal["auto"] | None = None, measurements: Iterable[Any] | Literal["auto"] = (), ignore_unknown_measurements: bool = False, - center: complex | None = None, + center: complex | None | Literal['infer'] = 'infer', flags: Iterable[str] = frozenset(), sign: bool | None = None, ) -> None: diff --git a/glue/stimflow/src/stimflow/_chunk/_chunk_compiler.py b/glue/stimflow/src/stimflow/_chunk/_chunk_compiler.py index a62adda9d..ad77c09bc 100644 --- a/glue/stimflow/src/stimflow/_chunk/_chunk_compiler.py +++ b/glue/stimflow/src/stimflow/_chunk/_chunk_compiler.py @@ -81,12 +81,22 @@ class ChunkCompiler: ''') """ - def __init__(self, *, metadata_func: Callable[[Flow], FlowMetadata] | None = None): + def __init__( + self, + *, + metadata_func: Callable[[Flow], FlowMetadata] | None = None, + skip_verification_before_append: bool = False, + ): """ Args: metadata_func: Determines coordinate data appended to detectors (after x, y, and t). Defaults to None (no extra metadata). + skip_verification_before_append: Defaults to False. When False, the + `verify` method if chunks (or other objects being appended) are + verified before being appended. When True, this verification step + is skipped. Setting to True will improve performance at the cost + of safety. Examples: >>> import stim @@ -159,6 +169,7 @@ def __init__(self, *, metadata_func: Callable[[Flow], FlowMetadata] | None = Non self.discarded_observables: set[int] = set() self.metadata_func: Callable[[Flow], FlowMetadata] = cast(Any, metadata_func) self.prev_chunk_wants_to_merge_with_next: bool = False + self.skip_verification_before_append: bool = skip_verification_before_append def ensure_qubits_included(self, qubits: Iterable[complex]): """Adds the given qubit positions to the indexed positions, if they aren't already.""" @@ -213,7 +224,8 @@ def cur_circuit_html_viewer(self) -> stimflow.str_html: from stimflow._viz import html_viewer return html_viewer( - circuit=copy.finish_circuit(), background=self.cur_end_interface() + copy.finish_circuit(), + background=self.cur_end_interface() ) def finish_circuit(self) -> stim.Circuit: @@ -372,8 +384,18 @@ def append(self, appended: Chunk | ChunkLoop | ChunkReflow) -> None: The input flows of the appended chunk must exactly match the open outgoing flows of the circuit so far. + + Args: + appended: The object to append to the circuit. + + Unless `skip_verification_before_append=True` was specified when constructing the + compiler, the `verify` method of this object will be called in order to ensure it + is well form. If verification is skipped and the object is not well-formed, the + compiler may output an invalid Stim circuit (e.g. with non-deterministic detectors). """ __tracebackhide__ = True + if not self.skip_verification_before_append: + appended.verify() if self.waiting_for_magic_init: self.append_magic_init_chunk(appended.start_interface()) diff --git a/glue/stimflow/src/stimflow/_core/_flow.py b/glue/stimflow/src/stimflow/_core/_flow.py index 6beda72f8..02b49fd30 100644 --- a/glue/stimflow/src/stimflow/_core/_flow.py +++ b/glue/stimflow/src/stimflow/_core/_flow.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable, Mapping -from typing import Any, cast +from typing import Any, cast, Literal import stim @@ -12,7 +12,7 @@ class _UNSPECIFIED_: def __repr__(self): - return "" + return "_UNSPECIFIED" _UNSPECIFIED: Any = _UNSPECIFIED_() @@ -25,25 +25,26 @@ def __init__( start: PauliMap | Tile | None = None, end: PauliMap | Tile | None = None, measurement_indices: Iterable[int] = (), - center: complex | None = None, + center: complex | None | Literal['infer'] = 'infer', flags: Iterable[Any] = frozenset(), sign: bool | None = None, ): """Initializes a Flow. Args: - start: Defaults to None (empty). The Pauli product operator at the beginning of the - circuit (before *all* operations, including resets). + start: Defaults to None (empty). The Pauli product operator at the beginning of + the circuit (before *all* operations, including resets). end: Defaults to None (empty). The Pauli product operator at the end of the circuit (after *all* operations, including measurements). - measurement_indices: Defaults to empty. Indices of measurements that mediate the flow (that multiply - into it as it traverses the circuit). - center: Defaults to None (unspecified). Specifies a 2d coordinate to use in metadata - when the flow is completed into a detector. Incompatible with obs_name. - flags: Defaults to empty. Custom information about the flow, that can be used by code - operating on chunks for a variety of purposes. For example, this could identify the - "color" of the flow in a color code. - sign: Defaults to None (unsigned). The expected sign of the flow. + measurement_indices: Defaults to empty. Indices of measurements that mediate + the flow (that multiply into it as it traverses the circuit). + center: Defaults to 'infer' (attempt to infer). Specifies a 2d coordinate to + use in metadata, when the flow is completed into a detector. Can be set to a + complex number or to None. + flags: Defaults to empty. Custom information about the flow, that can be used by + code operating on chunks for a variety of purposes. For example, this could + identify the "color" of the flow in a color code. + sign: Defaults to None (unsigned). """ if start is not None and not isinstance(start, (PauliMap, Tile)): raise TypeError( @@ -57,10 +58,12 @@ def __init__( raise TypeError(f"{flags=} is a str instead of a set") if isinstance(start, PauliMap) and isinstance(end, PauliMap) and start.obs_name != end.obs_name: raise ValueError(f'{start.obs_name=} != {end.obs_name=}') + if sign == -1: + raise ValueError(f"sign is a bool, not an int. Specify sign=True instead of {sign=}.") - if center is None and isinstance(start, Tile): + if center == 'infer' and isinstance(start, Tile): center = start.measure_qubit - if center is None and isinstance(end, Tile): + if center == 'infer' and isinstance(end, Tile): center = end.measure_qubit if isinstance(start, PauliMap): @@ -78,12 +81,14 @@ def __init__( elif end is None: end = PauliMap(obs_name=obs_name) - if center is None: + if center == 'infer': qubits: list[complex] = [] qubits.extend(start.keys()) qubits.extend(end.keys()) if qubits: center = sum(qubits) / len(qubits) + else: + center = None self.start: PauliMap = start self.end: PauliMap = end @@ -95,6 +100,31 @@ def __init__( def to_stim_flow( self, *, q2i: dict[complex, int], o2i: Mapping[Any, int | None] | None = None ) -> stim.Flow: + """Converts this `stimflow.Flow` into a `stim.Flow`. + + Args: + q2i: A mapping from stimflow qubit positions to stim qubit indices. + o2i: A mapping from stimflow obs names to stim obs indices. + This argument can be skipped if the flow has no obs_name. + + Returns: + The stim flow. + + Raise: + ValueError: + The flow has an `obs_name` but `o2i` wasn't specified. + + Examples: + >>> import stimflow as sf + >>> flow = sf.Flow( + ... start=sf.PauliMap({'Z': 1j}, obs_name="test"), + ... end=sf.PauliMap({'X': 1 + 1j}, obs_name="test"), + ... measurement_indices=[1, 2], + ... sign=True, + ... ) + >>> flow.to_stim_flow(q2i={1j: 2, 1 + 1j: 3}, o2i={"test": 0}) + stim.Flow("__Z -> -___X xor rec[1] xor rec[2] xor obs[0]") + """ out = self.end.to_stim_pauli_string(q2i) if self.sign: out.sign = -1 @@ -129,10 +159,59 @@ def with_edits( center: complex | None = _UNSPECIFIED, flags: Iterable[str] = _UNSPECIFIED, sign: Any = _UNSPECIFIED, + obs_name: None | str = _UNSPECIFIED, ) -> Flow: + """Returns the same flow but with specified edits. + + Args: + start: If specified, the returned flow has the specified start instead of the + start used by the original flow. Note: if `obs_name` is also specified, + the obs_name of this argument must be consistent with the given `obs_name`. + end: If specified, the returned flow has the specified end instead of the + end used by the original flow. Note: if `obs_name` is also specified, + the obs_name of this argument must be consistent with the given `obs_name`. + measurement_indices: If specified, the returned flow has the specified + measurement_indices instead of the measurement_indices used by the original + flow. + center: If specified, the returned flow has the specified center instead of the + center used by the original flow. + flags: If specified, the returned flow has the specified flags instead of the + flags used by the original flow. + sign: If specified, the returned flow has the specified sign instead of the + sign used by the original flow. + obs_name: If specified, the returned flow has the obs_name of both its start and + end changed to the given value. If `start` or `end` are specified alongside + this argument, they must use the same observable name. + + Returns: + The edited flow. + + Raises: + ValueError: + Specified contradictory `obs_name=` and `start=` values. + + OR + + Specified contradictory `obs_name=` and `end=` values. + + OR + + The edits produced an invalid flow (stimflow.Flow.__init__ raised an error). + """ + if start is not _UNSPECIFIED and obs_name is not _UNSPECIFIED and start.obs_name != obs_name: + raise ValueError(f"Specified contradictory observable names in `start` and `obs_name`.\n {start.obs_name=}\n {obs_name=}") + if end is not _UNSPECIFIED and obs_name is not _UNSPECIFIED and end.obs_name != obs_name: + raise ValueError(f"Specified contradictory observable names in `end` and `obs_name`.\n {end.obs_name=}\n {obs_name=}") + + start = self.start if start is _UNSPECIFIED else start + end = self.end if end is _UNSPECIFIED else end + if obs_name is not _UNSPECIFIED: + start = start.with_obs_name(obs_name) + end = end.with_obs_name(obs_name) + return Flow( - start=self.start if start is _UNSPECIFIED else start, - end=self.end if end is _UNSPECIFIED else end, + start=start, + end=end, measurement_indices=( self.measurement_indices if measurement_indices is _UNSPECIFIED @@ -211,17 +290,17 @@ def __str__(self) -> str: def __repr__(self): lines = ["stimflow.Flow("] if self.start: - lines.append(f"start={self.start!r},") + lines.append(f" start={self.start!r},") if self.end: - lines.append(f"end={self.end!r},") + lines.append(f" end={self.end!r},") if self.measurement_indices: - lines.append(f"measurement_indices={self.measurement_indices!r},") + lines.append(f" measurement_indices={self.measurement_indices!r},") if self.flags: - lines.append(f"flags={self.flags!r},") + lines.append(f" flags={self.flags!r},") if self.center is not None: - lines.append(f"center={self.center!r},") + lines.append(f" center={self.center!r},") if self.sign is not None: - lines.append(f"sign={self.sign!r},") + lines.append(f" sign={self.sign!r},") lines.append(")") return '\n'.join(lines) @@ -236,6 +315,43 @@ def with_transformed_coords(self, transform: Callable[[complex], complex]) -> Fl ) def fused_with_next_flow(self, next_flow: Flow, *, next_flow_measure_offset: int) -> Flow: + """Combines flows tail-to-head. + + For example, fusing X1 -> Y2 with Y2 -> Z3 produces X1 -> Z3. + + Measurement sets are xored, adjusting for the offset. Centers are + taken as is, preferring the center of the prior flow. Signs are xored. + flags are union'd. + + Args: + next_flow: The flow that occurs after this flow. Must have a start + that matches the end of this flow. + next_flow_measure_offset: What offset to add into measurement indices + used by the other flow. + + Returns: + The fused flow. + + Examples: + >>> import stimflow as sf + >>> a = sf.Flow( + ... start=sf.PauliMap({1: 'X'}), + ... end=sf.PauliMap({2: 'Y'}), + ... measurement_indices=[-1, 2], + ... ) + >>> b = sf.Flow( + ... start=sf.PauliMap({2: 'Y'}), + ... end=sf.PauliMap({3: 'Z'}), + ... measurement_indices=[-10, 20], + ... ) + >>> a.fused_with_next_flow(b, next_flow_measure_offset=100) + stimflow.Flow( + start=stimflow.PauliMap({(1+0j): 'X'}), + end=stimflow.PauliMap({(3+0j): 'Z'}), + measurement_indices=(2, 90, 99, 120), + center=(2+0j), + ) + """ if next_flow.start != self.end: raise ValueError("other.start != self.end") if next_flow.obs_name != self.obs_name: @@ -265,7 +381,42 @@ def fused_with_next_flow(self, next_flow: Flow, *, next_flow_measure_offset: int def __mul__(self, other: Flow) -> Flow: """Computes the product of two flows. - The product of A -> B and C -> D is (A*C) -> (B*D). + The product of two flows sends the product of their inputs to the product of their + outputs. For example, (A -> B) * (C -> D) = (A*C) -> (B*D). + + Starts are multiplied. Ends are multiplied. Measurement sets are xored. Centers are + averaged. Signs are xored. flags are union'd. + + Args: + other: The other flow in the multiplication. + + Raises: + ValueError: + The flows have incompatible observable names. + + OR + + The flows disagree on whether they're unsigned. + + Examples: + >>> import stimflow as sf + >>> a = sf.Flow( + ... start=sf.PauliMap({1: 'X'}), + ... end=sf.PauliMap({2: 'Y'}), + ... measurement_indices=[-1, 2], + ... ) + >>> b = sf.Flow( + ... start=sf.PauliMap({2: 'Y'}), + ... end=sf.PauliMap({3: 'Z'}), + ... measurement_indices=[-10, 20], + ... ) + >>> a * b + stimflow.Flow( + start=stimflow.PauliMap({(1+0j): 'X', (2+0j): 'Y'}), + end=stimflow.PauliMap({(2+0j): 'Y', (3+0j): 'Z'}), + measurement_indices=(-10, -1, 2, 20), + center=(2+0j), + ) """ if self.obs_name != other.obs_name: raise ValueError(f"{self.obs_name=} != {other.obs_name=}") diff --git a/glue/stimflow/src/stimflow/_core/_flow_test.py b/glue/stimflow/src/stimflow/_core/_flow_test.py index 2618da15a..e936d6eab 100644 --- a/glue/stimflow/src/stimflow/_core/_flow_test.py +++ b/glue/stimflow/src/stimflow/_core/_flow_test.py @@ -1,7 +1,287 @@ +import pytest +import stim + import stimflow +def test_with_edits(): + flow = stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=None, + ) + + assert flow.with_edits(start=stimflow.PauliMap()) == stimflow.Flow( + start=stimflow.PauliMap(), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=None, + ) + assert flow.with_edits(end=stimflow.PauliMap({1: 'X'})) == stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap({1: 'X'}), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=None, + ) + assert flow.with_edits(measurement_indices=[2]) == stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[2], + center=1j, + flags={'test'}, + sign=None, + ) + assert flow.with_edits(center=None) == stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=None, + flags={'test'}, + sign=None, + ) + assert flow.with_edits(flags=set()) == stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=1j, + flags=set(), + sign=None, + ) + assert flow.with_edits(sign=True) == stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=True, + ) + assert flow.with_edits(obs_name='test') == stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}, obs_name='test'), + end=stimflow.PauliMap(obs_name='test'), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=None, + ) + assert flow.with_edits(start=stimflow.PauliMap({1: 'X'}, obs_name='test'), obs_name='test') == stimflow.Flow( + start=stimflow.PauliMap({1: 'X'}, obs_name='test'), + end=stimflow.PauliMap(obs_name='test'), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=None, + ) + with pytest.raises(ValueError, match='contradict'): + flow.with_edits(start=stimflow.PauliMap(obs_name='xx'), obs_name='test') + with pytest.raises(ValueError, match='contradict'): + flow.with_edits(end=stimflow.PauliMap(obs_name='xx'), obs_name='test') + + +def test_repr(): + flow = stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=None, + ) + repr_text = repr(flow) + round_tripped = eval(repr_text, {}, {'stimflow': stimflow}) + assert round_tripped == flow + assert repr_text == repr(round_tripped) + + flow = stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + ) + assert len(repr(flow)) < len(repr_text) # Shorter. + repr_text = repr(flow) + round_tripped = eval(repr_text, {}, {'stimflow': stimflow}) + assert round_tripped == flow + assert repr_text == repr(round_tripped) + + +def test_equality(): + ref = stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=None, + ) + ref2 = stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=None, + ) + assert ref == ref2 + assert hash(ref) == hash(ref2) + assert ref != stimflow.Flow( + start=stimflow.PauliMap({1: 'Y', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=None, + ) + assert ref != stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap({1: 'Z'}), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=None, + ) + assert ref != stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[2], + center=1j, + flags={'test'}, + sign=None, + ) + assert ref != stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=2j, + flags={'test'}, + sign=None, + ) + assert ref != stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=1j, + flags={'test2'}, + sign=None, + ) + assert ref != stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=True, + ) + + +def test_to_stim_flow(): + assert stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}, obs_name='test'), + end=stimflow.PauliMap(obs_name='test'), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=True, + ).to_stim_flow(q2i={1: 0, 2: 1}, o2i={'test': 3}) == stim.Flow("XZ -> -rec[1] xor obs[3]") + + +def test_obs_name(): + assert stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}, obs_name='test'), + end=stimflow.PauliMap(obs_name='test'), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=True, + ).obs_name == 'test' + + assert stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap(), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=True, + ).obs_name is None + + +def test_str(): + assert str(stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}), + end=stimflow.PauliMap({3: 'Y'}), + measurement_indices=[1], + )) == '''X[1+0j]*Z[2+0j] -> Y[3+0j]*rec[1]''' + assert str(stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}, obs_name='test'), + end=stimflow.PauliMap(obs_name='test'), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=False, + )) == '''+X[1+0j]*Z[2+0j] -> rec[1] (obs=test) (flags=['test'])''' + assert str(stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}, obs_name='test'), + end=stimflow.PauliMap(obs_name='test'), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=True, + )) == '''-X[1+0j]*Z[2+0j] -> rec[1] (obs=test) (flags=['test'])''' + + def test_with_xz_flipped(): assert stimflow.Flow(start=stimflow.PauliMap({1: "X", 2: "Z"}), center=0).with_xz_flipped() == stimflow.Flow( start=stimflow.PauliMap({1: "Z", 2: "X"}), center=0 ) + assert stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}, obs_name='test'), + end=stimflow.PauliMap({3: 'Y', 4: 'X'}, obs_name='test'), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=True, + ).with_xz_flipped() == stimflow.Flow( + start=stimflow.PauliMap({1: 'Z', 2: 'X'}, obs_name='test'), + end=stimflow.PauliMap({3: 'Y', 4: 'Z'}, obs_name='test'), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=True, + ) + + +def test_with_transformed_coords(): + assert stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}, obs_name='test'), + end=stimflow.PauliMap({3: 'Y', 4: 'X'}, obs_name='test'), + measurement_indices=[1], + center=1j, + flags={'test'}, + sign=True, + ).with_transformed_coords(lambda e: e*2 + 1) == stimflow.Flow( + start=stimflow.PauliMap({3: 'X', 5: 'Z'}, obs_name='test'), + end=stimflow.PauliMap({7: 'Y', 9: 'X'}, obs_name='test'), + measurement_indices=[1], + center=2j + 1, + flags={'test'}, + sign=True, + ) + + assert stimflow.Flow( + start=stimflow.PauliMap({1: 'X', 2: 'Z'}, obs_name='test'), + end=stimflow.PauliMap({3: 'Y', 4: 'X'}, obs_name='test'), + measurement_indices=[1], + center=None, + flags={'test'}, + sign=True, + ).with_transformed_coords(lambda e: e*2 + 1) == stimflow.Flow( + start=stimflow.PauliMap({3: 'X', 5: 'Z'}, obs_name='test'), + end=stimflow.PauliMap({7: 'Y', 9: 'X'}, obs_name='test'), + measurement_indices=[1], + center=None, + flags={'test'}, + sign=True, + ) diff --git a/glue/stimflow/src/stimflow/_core/_pauli_map.py b/glue/stimflow/src/stimflow/_core/_pauli_map.py index 0605072db..9874db4e1 100644 --- a/glue/stimflow/src/stimflow/_core/_pauli_map.py +++ b/glue/stimflow/src/stimflow/_core/_pauli_map.py @@ -132,19 +132,19 @@ def _inline_svg_(self, *, q2p: Callable[[complex], complex], out_lines: list[str out_lines.append(f'''{p}''') @staticmethod - def from_xs(xs: Iterable[complex], *, name: Any = None) -> PauliMap: + def from_xs(xs: Iterable[complex], *, obs_name: Any = None) -> PauliMap: """Returns a PauliMap mapping the given qubits to the X basis.""" - return PauliMap({"X": xs}, obs_name=name) + return PauliMap({"X": xs}, obs_name=obs_name) @staticmethod - def from_ys(ys: Iterable[complex], *, name: Any = None) -> PauliMap: + def from_ys(ys: Iterable[complex], *, obs_name: Any = None) -> PauliMap: """Returns a PauliMap mapping the given qubits to the Y basis.""" - return PauliMap({"Y": ys}, obs_name=name) + return PauliMap({"Y": ys}, obs_name=obs_name) @staticmethod - def from_zs(zs: Iterable[complex], *, name: Any = None) -> PauliMap: + def from_zs(zs: Iterable[complex], *, obs_name: Any = None) -> PauliMap: """Returns a PauliMap mapping the given qubits to the Z basis.""" - return PauliMap({"Z": zs}, obs_name=name) + return PauliMap({"Z": zs}, obs_name=obs_name) def __contains__(self, item: complex) -> bool: """Determines if the PauliMap maps the given qubit to a non-identity Pauli.""" @@ -224,6 +224,8 @@ def __repr__(self) -> str: p = set(self.values()) if p == {'X'}: return f"stimflow.PauliMap.from_xs({qs!r}{s2})" + if p == {'Y'}: + return f"stimflow.PauliMap.from_ys({qs!r}{s2})" if p == {'Z'}: return f"stimflow.PauliMap.from_zs({qs!r}{s2})" s = {q: self._dict[q] for q in qs} From 99973cc8e8b7d0e3cf5c6cf451df0ab98e6e78d5 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Tue, 30 Jun 2026 15:32:51 -0700 Subject: [PATCH 2/3] regen --- glue/stimflow/doc/api.md | 194 ++++++++++++++++++++++++++++++++++----- 1 file changed, 172 insertions(+), 22 deletions(-) diff --git a/glue/stimflow/doc/api.md b/glue/stimflow/doc/api.md index 04958e58f..f913d27ec 100644 --- a/glue/stimflow/doc/api.md +++ b/glue/stimflow/doc/api.md @@ -921,7 +921,7 @@ def add_flow( end: "PauliMap | Tile | Literal['auto'] | None" = None, measurements: "Iterable[Any] | Literal['auto']" = (), ignore_unknown_measurements: bool = False, - center: complex | None = None, + center: "complex | None | Literal['infer']" = 'infer, flags: Iterable[str] = frozenset(), sign: bool | None = None, ) -> None: @@ -1296,12 +1296,18 @@ def __init__( self, *, metadata_func: Callable[[Flow], FlowMetadata] | None = None, + skip_verification_before_append: bool = False, ): """ Args: metadata_func: Determines coordinate data appended to detectors (after x, y, and t). Defaults to None (no extra metadata). + skip_verification_before_append: Defaults to False. When False, the + `verify` method if chunks (or other objects being appended) are + verified before being appended. When True, this verification step + is skipped. Setting to True will improve performance at the cost + of safety. Examples: >>> import stim @@ -1378,6 +1384,14 @@ def append( The input flows of the appended chunk must exactly match the open outgoing flows of the circuit so far. + + Args: + appended: The object to append to the circuit. + + Unless `skip_verification_before_append=True` was specified when constructing the + compiler, the `verify` method of this object will be called in order to ensure it + is well form. If verification is skipped and the object is not well-formed, the + compiler may output an invalid Stim circuit (e.g. with non-deterministic detectors). """ ``` @@ -2015,25 +2029,26 @@ def __init__( start: PauliMap | Tile | None = None, end: PauliMap | Tile | None = None, measurement_indices: Iterable[int] = (), - center: complex | None = None, + center: "complex | None | Literal['infer']" = 'infer, flags: Iterable[Any] = frozenset(), sign: bool | None = None, ): """Initializes a Flow. Args: - start: Defaults to None (empty). The Pauli product operator at the beginning of the - circuit (before *all* operations, including resets). + start: Defaults to None (empty). The Pauli product operator at the beginning of + the circuit (before *all* operations, including resets). end: Defaults to None (empty). The Pauli product operator at the end of the circuit (after *all* operations, including measurements). - measurement_indices: Defaults to empty. Indices of measurements that mediate the flow (that multiply - into it as it traverses the circuit). - center: Defaults to None (unspecified). Specifies a 2d coordinate to use in metadata - when the flow is completed into a detector. Incompatible with obs_name. - flags: Defaults to empty. Custom information about the flow, that can be used by code - operating on chunks for a variety of purposes. For example, this could identify the - "color" of the flow in a color code. - sign: Defaults to None (unsigned). The expected sign of the flow. + measurement_indices: Defaults to empty. Indices of measurements that mediate + the flow (that multiply into it as it traverses the circuit). + center: Defaults to 'infer' (attempt to infer). Specifies a 2d coordinate to + use in metadata, when the flow is completed into a detector. Can be set to a + complex number or to None. + flags: Defaults to empty. Custom information about the flow, that can be used by + code operating on chunks for a variety of purposes. For example, this could + identify the "color" of the flow in a color code. + sign: Defaults to None (unsigned). """ ``` @@ -2048,7 +2063,42 @@ def __mul__( ) -> Flow: """Computes the product of two flows. - The product of A -> B and C -> D is (A*C) -> (B*D). + The product of two flows sends the product of their inputs to the product of their + outputs. For example, (A -> B) * (C -> D) = (A*C) -> (B*D). + + Starts are multiplied. Ends are multiplied. Measurement sets are xored. Centers are + averaged. Signs are xored. flags are union'd. + + Args: + other: The other flow in the multiplication. + + Raises: + ValueError: + The flows have incompatible observable names. + + OR + + The flows disagree on whether they're unsigned. + + Examples: + >>> import stimflow as sf + >>> a = sf.Flow( + ... start=sf.PauliMap({1: 'X'}), + ... end=sf.PauliMap({2: 'Y'}), + ... measurement_indices=[-1, 2], + ... ) + >>> b = sf.Flow( + ... start=sf.PauliMap({2: 'Y'}), + ... end=sf.PauliMap({3: 'Z'}), + ... measurement_indices=[-10, 20], + ... ) + >>> a * b + stimflow.Flow( + start=stimflow.PauliMap({(1+0j): 'X', (2+0j): 'Y'}), + end=stimflow.PauliMap({(2+0j): 'Y', (3+0j): 'Z'}), + measurement_indices=(-10, -1, 2, 20), + center=(2+0j), + ) """ ``` @@ -2063,6 +2113,43 @@ def fused_with_next_flow( *, next_flow_measure_offset: int, ) -> Flow: + """Combines flows tail-to-head. + + For example, fusing X1 -> Y2 with Y2 -> Z3 produces X1 -> Z3. + + Measurement sets are xored, adjusting for the offset. Centers are + taken as is, preferring the center of the prior flow. Signs are xored. + flags are union'd. + + Args: + next_flow: The flow that occurs after this flow. Must have a start + that matches the end of this flow. + next_flow_measure_offset: What offset to add into measurement indices + used by the other flow. + + Returns: + The fused flow. + + Examples: + >>> import stimflow as sf + >>> a = sf.Flow( + ... start=sf.PauliMap({1: 'X'}), + ... end=sf.PauliMap({2: 'Y'}), + ... measurement_indices=[-1, 2], + ... ) + >>> b = sf.Flow( + ... start=sf.PauliMap({2: 'Y'}), + ... end=sf.PauliMap({3: 'Z'}), + ... measurement_indices=[-10, 20], + ... ) + >>> a.fused_with_next_flow(b, next_flow_measure_offset=100) + stimflow.Flow( + start=stimflow.PauliMap({(1+0j): 'X'}), + end=stimflow.PauliMap({(3+0j): 'Z'}), + measurement_indices=(2, 90, 99, 120), + center=(2+0j), + ) + """ ``` @@ -2087,6 +2174,31 @@ def to_stim_flow( q2i: dict[complex, int], o2i: Mapping[Any, int | None] | None = None, ) -> stim.Flow: + """Converts this `stimflow.Flow` into a `stim.Flow`. + + Args: + q2i: A mapping from stimflow qubit positions to stim qubit indices. + o2i: A mapping from stimflow obs names to stim obs indices. + This argument can be skipped if the flow has no obs_name. + + Returns: + The stim flow. + + Raise: + ValueError: + The flow has an `obs_name` but `o2i` wasn't specified. + + Examples: + >>> import stimflow as sf + >>> flow = sf.Flow( + ... start=sf.PauliMap({'Z': 1j}, obs_name="test"), + ... end=sf.PauliMap({'X': 1 + 1j}, obs_name="test"), + ... measurement_indices=[1, 2], + ... sign=True, + ... ) + >>> flow.to_stim_flow(q2i={1j: 2, 1 + 1j: 3}, o2i={"test": 0}) + stim.Flow("__Z -> -___X xor rec[1] xor rec[2] xor obs[0]") + """ ``` @@ -2097,13 +2209,51 @@ def to_stim_flow( def with_edits( self, *, - start: PauliMap = , - end: PauliMap = , - measurement_indices: Iterable[int] = , - center: complex | None = , - flags: Iterable[str] = , - sign: Any = , + start: PauliMap = _UNSPECIFIED, + end: PauliMap = _UNSPECIFIED, + measurement_indices: Iterable[int] = _UNSPECIFIED, + center: complex | None = _UNSPECIFIED, + flags: Iterable[str] = _UNSPECIFIED, + sign: Any = _UNSPECIFIED, + obs_name: None | str = _UNSPECIFIED, ) -> Flow: + """Returns the same flow but with specified edits. + + Args: + start: If specified, the returned flow has the specified start instead of the + start used by the original flow. Note: if `obs_name` is also specified, + the obs_name of this argument must be consistent with the given `obs_name`. + end: If specified, the returned flow has the specified end instead of the + end used by the original flow. Note: if `obs_name` is also specified, + the obs_name of this argument must be consistent with the given `obs_name`. + measurement_indices: If specified, the returned flow has the specified + measurement_indices instead of the measurement_indices used by the original + flow. + center: If specified, the returned flow has the specified center instead of the + center used by the original flow. + flags: If specified, the returned flow has the specified flags instead of the + flags used by the original flow. + sign: If specified, the returned flow has the specified sign instead of the + sign used by the original flow. + obs_name: If specified, the returned flow has the obs_name of both its start and + end changed to the given value. If `start` or `end` are specified alongside + this argument, they must use the same observable name. + + Returns: + The edited flow. + + Raises: + ValueError: + Specified contradictory `obs_name=` and `start=` values. + + OR + + Specified contradictory `obs_name=` and `end=` values. + + OR + + The edits produced an invalid flow (stimflow.Flow.__init__ raised an error). + """ ``` @@ -2983,7 +3133,7 @@ def commutes( def from_xs( xs: Iterable[complex], *, - name: Any = None, + obs_name: Any = None, ) -> PauliMap: """Returns a PauliMap mapping the given qubits to the X basis. """ @@ -2997,7 +3147,7 @@ def from_xs( def from_ys( ys: Iterable[complex], *, - name: Any = None, + obs_name: Any = None, ) -> PauliMap: """Returns a PauliMap mapping the given qubits to the Y basis. """ @@ -3011,7 +3161,7 @@ def from_ys( def from_zs( zs: Iterable[complex], *, - name: Any = None, + obs_name: Any = None, ) -> PauliMap: """Returns a PauliMap mapping the given qubits to the Z basis. """ From 4182627a431ade991322e37f4c8243ce3ad85d4c Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Tue, 30 Jun 2026 15:52:32 -0700 Subject: [PATCH 3/3] Fix unit tests caught by compiler verification --- glue/stimflow/src/stimflow/_chunk/_chunk.py | 9 ++- .../stimflow/_chunk/_chunk_compiler_test.py | 72 ++++++++++--------- 2 files changed, 45 insertions(+), 36 deletions(-) diff --git a/glue/stimflow/src/stimflow/_chunk/_chunk.py b/glue/stimflow/src/stimflow/_chunk/_chunk.py index 293b6c37b..3d052f566 100644 --- a/glue/stimflow/src/stimflow/_chunk/_chunk.py +++ b/glue/stimflow/src/stimflow/_chunk/_chunk.py @@ -110,8 +110,16 @@ def __init__( ... ) >>> chunk.verify() """ + flows = tuple(flows) if q2i is None: q2i = {x + 1j * y: i for i, (x, y) in circuit.get_final_qubit_coordinates().items()} + for flow in flows: + for pauli_string in flow.start, flow.end: + for q in pauli_string.keys(): + if q not in q2i: + raise ValueError( + f"The given flows use the qubit position {q}, but the given circuit doesn't include a QUBIT_COORDS for this position.\n" + f" Affected flow: {flow}") if len(q2i) != circuit.num_qubits: raise ValueError( "The given circuit doesn't have enough `QUBIT_COORDS` instructions to " @@ -119,7 +127,6 @@ def __init__( "specify it by passing a `q2i={...}` argument, or add the missing " "`QUBIT_COORDS`." ) - flows = tuple(flows) if o2i is None: if circuit.num_observables: raise ValueError( diff --git a/glue/stimflow/src/stimflow/_chunk/_chunk_compiler_test.py b/glue/stimflow/src/stimflow/_chunk/_chunk_compiler_test.py index 018698393..bf64c0121 100644 --- a/glue/stimflow/src/stimflow/_chunk/_chunk_compiler_test.py +++ b/glue/stimflow/src/stimflow/_chunk/_chunk_compiler_test.py @@ -149,14 +149,14 @@ def test_chunk_compiler_loop(): stimflow.Chunk( circuit=stim.Circuit( """ - QUBIT_COORDS(0, 0) 0 - QUBIT_COORDS(0, 1) 1 - QUBIT_COORDS(0, 2) 2 - QUBIT_COORDS(0, 3) 3 - R 0 1 2 3 - """ + QUBIT_COORDS(0, 0) 0 + QUBIT_COORDS(0, 1) 1 + QUBIT_COORDS(0, 2) 2 + QUBIT_COORDS(0, 3) 3 + R 0 1 2 3 + """ ), - flows=[stimflow.Flow(end=stimflow.PauliMap.from_zs([k]), center=0) for k in range(4)], + flows=[stimflow.Flow(end=stimflow.PauliMap.from_zs([k*1j]), center=0) for k in range(4)], ) ) compiler.append( @@ -165,27 +165,27 @@ def test_chunk_compiler_loop(): stimflow.Chunk( circuit=stim.Circuit( """ - QUBIT_COORDS(0, 0) 0 - QUBIT_COORDS(0, 1) 1 - QUBIT_COORDS(0, 2) 2 - QUBIT_COORDS(0, 3) 3 - SWAP 0 1 - SWAP 1 2 - SWAP 2 3 - M 3 - """ + QUBIT_COORDS(0, 0) 0 + QUBIT_COORDS(0, 1) 1 + QUBIT_COORDS(0, 2) 2 + QUBIT_COORDS(0, 3) 3 + SWAP 0 1 + SWAP 1 2 + SWAP 2 3 + M 3 + """ ), flows=[ stimflow.Flow(start=stimflow.PauliMap.from_zs([0]), measurement_indices=[0], center=0), - stimflow.Flow(end=stimflow.PauliMap.from_zs([3]), measurement_indices=[0], center=0), + stimflow.Flow(end=stimflow.PauliMap.from_zs([3j]), measurement_indices=[0], center=0), stimflow.Flow( - start=stimflow.PauliMap.from_zs([1]), end=stimflow.PauliMap.from_zs([0]), center=0 + start=stimflow.PauliMap.from_zs([1j]), end=stimflow.PauliMap.from_zs([0]), center=0 ), stimflow.Flow( - start=stimflow.PauliMap.from_zs([2]), end=stimflow.PauliMap.from_zs([1]), center=0 + start=stimflow.PauliMap.from_zs([2j]), end=stimflow.PauliMap.from_zs([1j]), center=0 ), stimflow.Flow( - start=stimflow.PauliMap.from_zs([3]), end=stimflow.PauliMap.from_zs([2]), center=0 + start=stimflow.PauliMap.from_zs([3j]), end=stimflow.PauliMap.from_zs([2j]), center=0 ), ], ) @@ -197,14 +197,14 @@ def test_chunk_compiler_loop(): stimflow.Chunk( circuit=stim.Circuit( """ - QUBIT_COORDS(0, 0) 0 - QUBIT_COORDS(0, 1) 1 - QUBIT_COORDS(0, 2) 2 - QUBIT_COORDS(0, 3) 3 - M 0 1 2 3 - """ + QUBIT_COORDS(0, 0) 0 + QUBIT_COORDS(0, 1) 1 + QUBIT_COORDS(0, 2) 2 + QUBIT_COORDS(0, 3) 3 + M 0 1 2 3 + """ ), - flows=[stimflow.Flow(start=stimflow.PauliMap.from_zs([k]), measurement_indices=[k], center=0) for k in range(4)], + flows=[stimflow.Flow(start=stimflow.PauliMap.from_zs([k*1j]), measurement_indices=[k], center=0) for k in range(4)], ) ) assert compiler.finish_circuit() == stim.Circuit( @@ -449,7 +449,7 @@ def test_chunk_compiler_propagate_discards(): stimflow.Chunk( stim.Circuit( """ - MZZ 0 1 + MZZ 0 1 """ ), q2i={0: 0, 1: 1}, @@ -464,14 +464,15 @@ def test_chunk_compiler_propagate_discards(): stimflow.Chunk( stim.Circuit( """ - MX 0 1 + MX 0 1 """ ), q2i={0: 0, 1: 1}, discarded_inputs=[zz], - flows=[stimflow.Flow(start=xx, center=0, measurement_indices=[0])], + flows=[stimflow.Flow(start=xx, center=0, measurement_indices=[0, 1])], ) ) + print(c.finish_circuit()) assert c.finish_circuit() == stim.Circuit( """ QUBIT_COORDS(0, 0) 0 @@ -483,7 +484,7 @@ def test_chunk_compiler_propagate_discards(): SHIFT_COORDS(0, 0, 1) TICK MX 0 1 - """ + """ ) @@ -675,10 +676,11 @@ def test_merges_with_loop(): stimflow.Chunk( circuit=stim.Circuit( """ - QUBIT_COORDS(0, 1) 0 - R 0 - M 0 - """ + QUBIT_COORDS(0, 0) 1 + QUBIT_COORDS(0, 1) 0 + R 0 + M 0 + """ ), flows=[ stimflow.Flow(start=stimflow.PauliMap.from_zs([0]), end=stimflow.PauliMap.from_zs([0])),