diff --git a/glue/stimflow/doc/api.md b/glue/stimflow/doc/api.md index f913d27ec..6baa8bfd5 100644 --- a/glue/stimflow/doc/api.md +++ b/glue/stimflow/doc/api.md @@ -251,6 +251,7 @@ from typing import overload, TYPE_CHECKING, Any, Iterable import io import pathlib import numpy as np +import functools ``` @@ -1410,6 +1411,67 @@ def append_magic_end_chunk( expected: Defaults to None (unused). If set to None, no extra checks are performed. If set to a ChunkInterface, it is verified that the open flows actually correspond to this interface. + + Examples: + >>> import stim + >>> import stimflow as sf + + >>> zz = sf.PauliMap({0: 'Z', 1 + 1j: 'Z'}) + >>> lz = sf.PauliMap({0: 'Z'}, obs_name='LZ') + >>> lx = sf.PauliMap({0: 'X', 1 + 1j: 'X'}, obs_name='LX') + >>> idle_chunk = sf.Chunk( + ... stim.Circuit(''' + ... QUBIT_COORDS(0, 0) 0 + ... QUBIT_COORDS(0, 1) 1 + ... QUBIT_COORDS(1, 1) 2 + ... R 1 + ... TICK + ... CX 0 1 + ... TICK + ... CX 2 1 + ... TICK + ... M 1 + ... '''), + ... flows=[ + ... sf.Flow(start=zz, measurement_indices=[0]), + ... sf.Flow(end=zz, measurement_indices=[0]), + ... sf.Flow(start=lz, end=lz), + ... sf.Flow(start=lx, end=lx), + ... ] + ... ) + + >>> compiler = sf.ChunkCompiler() + >>> compiler.append_magic_init_chunk() + >>> compiler.append(idle_chunk) + >>> compiler.append_magic_end_chunk() + >>> compiler.finish_circuit() + stim.Circuit(''' + QUBIT_COORDS(0, 0) 0 + QUBIT_COORDS(0, 1) 1 + QUBIT_COORDS(1, 1) 2 + OBSERVABLE_INCLUDE(0) X0 X2 + TICK + OBSERVABLE_INCLUDE(1) Z0 + TICK + MPP Z0*Z2 + TICK + R 1 + TICK + CX 0 1 + TICK + CX 2 1 + TICK + M 1 + DETECTOR(0.5, 0.5, 0) rec[-2] rec[-1] + SHIFT_COORDS(0, 0, 1) + TICK + MPP Z0*Z2 + DETECTOR(0.5, 0.5, 0) rec[-2] rec[-1] + TICK + OBSERVABLE_INCLUDE(0) X0 X2 + TICK + OBSERVABLE_INCLUDE(1) Z0 + ''') """ ``` @@ -1429,6 +1491,70 @@ def append_magic_init_chunk( verified that the next appended chunk actually has a start interface matching the given expected interface. If set to None, then no checks are performed; no constraints are placed on the next chunk. + + Examples: + >>> import stim + >>> import stimflow as sf + + >>> zz = sf.PauliMap({0: 'Z', 1 + 1j: 'Z'}) + >>> lz = sf.PauliMap({0: 'Z'}, obs_name='LZ') + >>> lx = sf.PauliMap({0: 'X', 1 + 1j: 'X'}, obs_name='LX') + >>> idle_chunk = sf.Chunk( + ... stim.Circuit(''' + ... QUBIT_COORDS(0, 0) 0 + ... QUBIT_COORDS(0, 1) 1 + ... QUBIT_COORDS(1, 1) 2 + ... R 1 + ... TICK + ... CX 0 1 + ... TICK + ... CX 2 1 + ... TICK + ... M 1 + ... '''), + ... flows=[ + ... sf.Flow(start=zz, measurement_indices=[0]), + ... sf.Flow(end=zz, measurement_indices=[0]), + ... sf.Flow(start=lz, end=lz), + ... sf.Flow(start=lx, end=lx), + ... ] + ... ) + + >>> compiler = sf.ChunkCompiler() + >>> # Tell the compiler to somehow satisfy whatever chunk comes next. + >>> compiler.append_magic_init_chunk() + >>> # As the next chunk is appended, the compiler notes its expected inputs and + >>> # adds corresponding MPP and OBSERVABLE_INCLUDE instructions: + >>> compiler.append(idle_chunk) + >>> compiler.append_magic_end_chunk() + >>> compiler.finish_circuit() + stim.Circuit(''' + QUBIT_COORDS(0, 0) 0 + QUBIT_COORDS(0, 1) 1 + QUBIT_COORDS(1, 1) 2 + OBSERVABLE_INCLUDE(0) X0 X2 + TICK + OBSERVABLE_INCLUDE(1) Z0 + TICK + MPP Z0*Z2 + TICK + R 1 + TICK + CX 0 1 + TICK + CX 2 1 + TICK + M 1 + DETECTOR(0.5, 0.5, 0) rec[-2] rec[-1] + SHIFT_COORDS(0, 0, 1) + TICK + MPP Z0*Z2 + DETECTOR(0.5, 0.5, 0) rec[-2] rec[-1] + TICK + OBSERVABLE_INCLUDE(0) X0 X2 + TICK + OBSERVABLE_INCLUDE(1) Z0 + ''') """ ``` @@ -1519,7 +1645,10 @@ class ChunkInterface: # stimflow.ChunkInterface.data_set # (in class stimflow.ChunkInterface) -class data_set: +@functools.cached_property +def data_set(self) -> frozenset[complex]: + """Returns the set of qubits used by the interface's stabilizers and observables. + """ ``` @@ -1586,7 +1715,8 @@ def to_svg( # stimflow.ChunkInterface.used_set # (in class stimflow.ChunkInterface) -class used_set: +@functools.cached_property +def used_set(self) -> frozenset[complex]: """Returns the set of qubits used in any flow mentioned by the chunk interface. """ ``` @@ -1939,7 +2069,12 @@ def from_auto_rewrite_transitions_using_stable( # stimflow.ChunkReflow.removed_inputs # (in class stimflow.ChunkReflow) -class removed_inputs: +@functools.cached_property +def removed_inputs(self) -> frozenset[PauliMap]: + """Returns the set of inputs expected by the reflow chunk. + + This includes stabilizer inputs, observable inputs, and discarded inputs. + """ ``` @@ -2049,6 +2184,15 @@ def __init__( code operating on chunks for a variety of purposes. For example, this could identify the "color" of the flow in a color code. sign: Defaults to None (unsigned). + + Examples: + >>> import stimflow as sf + >>> sf.Flow(start=sf.PauliMap.from_xs([0]), measurement_indices=[1]) + stimflow.Flow( + start=stimflow.PauliMap({0j: 'X'}), + measurement_indices=(1,), + center=0j, + ) """ ``` @@ -2161,6 +2305,19 @@ def fused_with_next_flow( def obs_name( self, ): + """The name of the observable that the flow is mapping. + + If the flow is not acting on a logical operator, this returns None. + + Examples: + >>> import stimflow as sf + >>> sf.Flow(start=sf.PauliMap.from_xs([0], obs_name='test')).obs_name + 'test' + >>> sf.Flow(end=sf.PauliMap.from_xs([0], obs_name='rest')).obs_name + 'rest' + >>> sf.Flow(start=sf.PauliMap.from_xs([0])).obs_name is None + True + """ ``` @@ -2253,6 +2410,18 @@ def with_edits( OR The edits produced an invalid flow (stimflow.Flow.__init__ raised an error). + + + Examples: + >>> import stimflow as sf + >>> flow = sf.Flow(start=sf.PauliMap.from_xs([0]), measurement_indices=[1]) + >>> flow.with_edits(end=sf.PauliMap.from_xs([1j])) + stimflow.Flow( + start=stimflow.PauliMap({0j: 'X'}), + end=stimflow.PauliMap({1j: 'X'}), + measurement_indices=(1,), + center=0j, + ) """ ``` @@ -2877,7 +3046,8 @@ class Patch: # stimflow.Patch.data_set # (in class stimflow.Patch) -class data_set: +@functools.cached_property +def data_set(self) -> frozenset[complex]: """Returns the set of all data qubits used by tiles in the patch. """ ``` @@ -2887,7 +3057,15 @@ class data_set: # stimflow.Patch.m2tile # (in class stimflow.Patch) -class m2tile: +@functools.cached_property +def m2tile(self) -> dict[complex, Tile]: + """Returns a measure-qubit-to-tile dictionary for the patch's tiles. + + Assumes all tiles have a unique measure qubit. Ignores tiles with no measure qubit. + + WARNING: Do not edit the returned dictionary! It is cached and returned by all + future calls to this property. Editing it will break future results. + """ ``` @@ -2895,7 +3073,8 @@ class m2tile: # stimflow.Patch.measure_set # (in class stimflow.Patch) -class measure_set: +@functools.cached_property +def measure_set(self) -> frozenset[complex]: """Returns the set of all measure qubits used by tiles in the patch. """ ``` @@ -2905,7 +3084,8 @@ class measure_set: # stimflow.Patch.partitioned_tiles # (in class stimflow.Patch) -class partitioned_tiles: +@functools.cached_property +def partitioned_tiles(self) -> tuple[tuple[Tile, ...], ...]: """Returns the tiles of the patch, but split into non-overlapping groups. """ ``` @@ -2938,7 +3118,8 @@ def to_svg( # stimflow.Patch.used_set # (in class stimflow.Patch) -class used_set: +@functools.cached_property +def used_set(self) -> frozenset[complex]: """Returns the set of all data and measure qubits used by tiles in the patch. """ ``` @@ -3395,7 +3576,10 @@ def concat_over( # stimflow.StabilizerCode.data_set # (in class stimflow.StabilizerCode) -class data_set: +@functools.cached_property +def data_set(self) -> frozenset[complex]: + """Returns the set of data qubits used by the stabilizers/logicals of the code. + """ ``` @@ -3428,7 +3612,8 @@ def find_logical_error( # stimflow.StabilizerCode.flat_logicals # (in class stimflow.StabilizerCode) -class flat_logicals: +@functools.cached_property +def flat_logicals(self) -> tuple[PauliMap, ...]: """Returns a list of the logical operators defined by the stabilizer code. It's "flat" because paired X/Z logicals are returned separately instead of @@ -3507,7 +3692,13 @@ def make_phenom_circuit( # stimflow.StabilizerCode.measure_set # (in class stimflow.StabilizerCode) -class measure_set: +@functools.cached_property +def measure_set(self) -> frozenset[complex]: + """Returns the set of measure qubits used by tiles of the code. + + Note that tiles may not specify measure qubits, in which case this will return + the empty set. + """ ``` @@ -3621,7 +3812,12 @@ def transversal_measure_chunk( # stimflow.StabilizerCode.used_set # (in class stimflow.StabilizerCode) -class used_set: +@functools.cached_property +def used_set(self) -> frozenset[complex]: + """Returns the set of all qubits mentioned by this code. + + This includes data qubits *and* measure qubits. + """ ``` @@ -3964,7 +4160,23 @@ def __init__( # stimflow.Tile.basis # (in class stimflow.Tile) -class basis: +@functools.cached_property +def basis(self) -> Literal['X', 'Y', 'Z'] | None: + """Returns the basis of the stabilizer, assuming it has exactly one. + + Returns: + If all data qubits have the same basis, returns that basis. + Otherwise, returns None. + + Examples: + >>> import stimflow as sf + >>> sf.Tile(bases="X", data_qubits=[0, 1, 1j]).basis + 'X' + >>> sf.Tile(bases="ZZZ", data_qubits=[0, 1, 1j]).basis + 'Z' + >>> sf.Tile(bases="XYZ", data_qubits=[0, 1, 1j]).basis is None + True + """ ``` @@ -3982,7 +4194,10 @@ def center( # stimflow.Tile.data_set # (in class stimflow.Tile) -class data_set: +@functools.cached_property +def data_set(self) -> frozenset[complex]: + """Returns the set of data qubits used by the Tile. + """ ``` @@ -4000,7 +4215,10 @@ def to_pauli_map( # stimflow.Tile.used_set # (in class stimflow.Tile) -class used_set: +@functools.cached_property +def used_set(self) -> frozenset[complex]: + """Returns the set of data and/or measure qubits used by the Tile. + """ ``` diff --git a/glue/stimflow/src/stimflow/_chunk/_chunk_compiler.py b/glue/stimflow/src/stimflow/_chunk/_chunk_compiler.py index ad77c09bc..66a8f7a25 100644 --- a/glue/stimflow/src/stimflow/_chunk/_chunk_compiler.py +++ b/glue/stimflow/src/stimflow/_chunk/_chunk_compiler.py @@ -286,6 +286,70 @@ def append_magic_init_chunk(self, expected: ChunkInterface | None = None) -> Non verified that the next appended chunk actually has a start interface matching the given expected interface. If set to None, then no checks are performed; no constraints are placed on the next chunk. + + Examples: + >>> import stim + >>> import stimflow as sf + + >>> zz = sf.PauliMap({0: 'Z', 1 + 1j: 'Z'}) + >>> lz = sf.PauliMap({0: 'Z'}, obs_name='LZ') + >>> lx = sf.PauliMap({0: 'X', 1 + 1j: 'X'}, obs_name='LX') + >>> idle_chunk = sf.Chunk( + ... stim.Circuit(''' + ... QUBIT_COORDS(0, 0) 0 + ... QUBIT_COORDS(0, 1) 1 + ... QUBIT_COORDS(1, 1) 2 + ... R 1 + ... TICK + ... CX 0 1 + ... TICK + ... CX 2 1 + ... TICK + ... M 1 + ... '''), + ... flows=[ + ... sf.Flow(start=zz, measurement_indices=[0]), + ... sf.Flow(end=zz, measurement_indices=[0]), + ... sf.Flow(start=lz, end=lz), + ... sf.Flow(start=lx, end=lx), + ... ] + ... ) + + >>> compiler = sf.ChunkCompiler() + >>> # Tell the compiler to somehow satisfy whatever chunk comes next. + >>> compiler.append_magic_init_chunk() + >>> # As the next chunk is appended, the compiler notes its expected inputs and + >>> # adds corresponding MPP and OBSERVABLE_INCLUDE instructions: + >>> compiler.append(idle_chunk) + >>> compiler.append_magic_end_chunk() + >>> compiler.finish_circuit() + stim.Circuit(''' + QUBIT_COORDS(0, 0) 0 + QUBIT_COORDS(0, 1) 1 + QUBIT_COORDS(1, 1) 2 + OBSERVABLE_INCLUDE(0) X0 X2 + TICK + OBSERVABLE_INCLUDE(1) Z0 + TICK + MPP Z0*Z2 + TICK + R 1 + TICK + CX 0 1 + TICK + CX 2 1 + TICK + M 1 + DETECTOR(0.5, 0.5, 0) rec[-2] rec[-1] + SHIFT_COORDS(0, 0, 1) + TICK + MPP Z0*Z2 + DETECTOR(0.5, 0.5, 0) rec[-2] rec[-1] + TICK + OBSERVABLE_INCLUDE(0) X0 X2 + TICK + OBSERVABLE_INCLUDE(1) Z0 + ''') """ if expected is None: self.waiting_for_magic_init = True @@ -322,6 +386,67 @@ def append_magic_end_chunk(self, expected: ChunkInterface | None = None) -> None expected: Defaults to None (unused). If set to None, no extra checks are performed. If set to a ChunkInterface, it is verified that the open flows actually correspond to this interface. + + Examples: + >>> import stim + >>> import stimflow as sf + + >>> zz = sf.PauliMap({0: 'Z', 1 + 1j: 'Z'}) + >>> lz = sf.PauliMap({0: 'Z'}, obs_name='LZ') + >>> lx = sf.PauliMap({0: 'X', 1 + 1j: 'X'}, obs_name='LX') + >>> idle_chunk = sf.Chunk( + ... stim.Circuit(''' + ... QUBIT_COORDS(0, 0) 0 + ... QUBIT_COORDS(0, 1) 1 + ... QUBIT_COORDS(1, 1) 2 + ... R 1 + ... TICK + ... CX 0 1 + ... TICK + ... CX 2 1 + ... TICK + ... M 1 + ... '''), + ... flows=[ + ... sf.Flow(start=zz, measurement_indices=[0]), + ... sf.Flow(end=zz, measurement_indices=[0]), + ... sf.Flow(start=lz, end=lz), + ... sf.Flow(start=lx, end=lx), + ... ] + ... ) + + >>> compiler = sf.ChunkCompiler() + >>> compiler.append_magic_init_chunk() + >>> compiler.append(idle_chunk) + >>> compiler.append_magic_end_chunk() + >>> compiler.finish_circuit() + stim.Circuit(''' + QUBIT_COORDS(0, 0) 0 + QUBIT_COORDS(0, 1) 1 + QUBIT_COORDS(1, 1) 2 + OBSERVABLE_INCLUDE(0) X0 X2 + TICK + OBSERVABLE_INCLUDE(1) Z0 + TICK + MPP Z0*Z2 + TICK + R 1 + TICK + CX 0 1 + TICK + CX 2 1 + TICK + M 1 + DETECTOR(0.5, 0.5, 0) rec[-2] rec[-1] + SHIFT_COORDS(0, 0, 1) + TICK + MPP Z0*Z2 + DETECTOR(0.5, 0.5, 0) rec[-2] rec[-1] + TICK + OBSERVABLE_INCLUDE(0) X0 X2 + TICK + OBSERVABLE_INCLUDE(1) Z0 + ''') """ if self.waiting_for_magic_init: self.waiting_for_magic_init = False diff --git a/glue/stimflow/src/stimflow/_chunk/_chunk_interface.py b/glue/stimflow/src/stimflow/_chunk/_chunk_interface.py index ce0c6bbbf..8b6ef9c7c 100644 --- a/glue/stimflow/src/stimflow/_chunk/_chunk_interface.py +++ b/glue/stimflow/src/stimflow/_chunk/_chunk_interface.py @@ -137,6 +137,7 @@ def __eq__(self, other): @functools.cached_property def data_set(self) -> frozenset[complex]: + """Returns the set of qubits used by the interface's stabilizers and observables.""" return frozenset( q for pauli_string_list in [self.ports, self.discards] diff --git a/glue/stimflow/src/stimflow/_chunk/_chunk_reflow.py b/glue/stimflow/src/stimflow/_chunk/_chunk_reflow.py index d27af7c5c..71a28198c 100644 --- a/glue/stimflow/src/stimflow/_chunk/_chunk_reflow.py +++ b/glue/stimflow/src/stimflow/_chunk/_chunk_reflow.py @@ -196,6 +196,10 @@ def end_patch(self) -> Patch: @functools.cached_property def removed_inputs(self) -> frozenset[PauliMap]: + """Returns the set of inputs expected by the reflow chunk. + + This includes stabilizer inputs, observable inputs, and discarded inputs. + """ return frozenset(v for vs in self.out2in.values() for v in vs) | frozenset(self.discard_in) def verify( diff --git a/glue/stimflow/src/stimflow/_chunk/_patch.py b/glue/stimflow/src/stimflow/_chunk/_patch.py index 2a52f3f94..3195036a2 100644 --- a/glue/stimflow/src/stimflow/_chunk/_patch.py +++ b/glue/stimflow/src/stimflow/_chunk/_patch.py @@ -94,7 +94,14 @@ def with_only_z_tiles(self) -> Patch: @functools.cached_property def m2tile(self) -> dict[complex, Tile]: - return {e.measure_qubit: e for e in self.tiles} + """Returns a measure-qubit-to-tile dictionary for the patch's tiles. + + Assumes all tiles have a unique measure qubit. Ignores tiles with no measure qubit. + + WARNING: Do not edit the returned dictionary! It is cached and returned by all + future calls to this property. Editing it will break future results. + """ + return {e.measure_qubit: e for e in self.tiles if e.measure_qubit is not None} def _repr_svg_(self) -> str: return self.to_svg() diff --git a/glue/stimflow/src/stimflow/_chunk/_stabilizer_code.py b/glue/stimflow/src/stimflow/_chunk/_stabilizer_code.py index bde5e9658..b4034ae64 100644 --- a/glue/stimflow/src/stimflow/_chunk/_stabilizer_code.py +++ b/glue/stimflow/src/stimflow/_chunk/_stabilizer_code.py @@ -422,6 +422,8 @@ def with_edits( @functools.cached_property def data_set(self) -> frozenset[complex]: + """Returns the set of data qubits used by the stabilizers/logicals of the code. + """ result = set(self.stabilizers.data_set) for obs in self.logicals: if isinstance(obs, PauliMap): @@ -434,10 +436,19 @@ def data_set(self) -> frozenset[complex]: @functools.cached_property def measure_set(self) -> frozenset[complex]: + """Returns the set of measure qubits used by tiles of the code. + + Note that tiles may not specify measure qubits, in which case this will return + the empty set. + """ return self.stabilizers.measure_set @functools.cached_property def used_set(self) -> frozenset[complex]: + """Returns the set of all qubits mentioned by this code. + + This includes data qubits *and* measure qubits. + """ result = set(self.stabilizers.used_set) for obs in self.logicals: if isinstance(obs, PauliMap): diff --git a/glue/stimflow/src/stimflow/_core/_flow.py b/glue/stimflow/src/stimflow/_core/_flow.py index 02b49fd30..be6e23489 100644 --- a/glue/stimflow/src/stimflow/_core/_flow.py +++ b/glue/stimflow/src/stimflow/_core/_flow.py @@ -45,6 +45,15 @@ def __init__( code operating on chunks for a variety of purposes. For example, this could identify the "color" of the flow in a color code. sign: Defaults to None (unsigned). + + Examples: + >>> import stimflow as sf + >>> sf.Flow(start=sf.PauliMap.from_xs([0]), measurement_indices=[1]) + stimflow.Flow( + start=stimflow.PauliMap({0j: 'X'}), + measurement_indices=(1,), + center=0j, + ) """ if start is not None and not isinstance(start, (PauliMap, Tile)): raise TypeError( @@ -148,6 +157,19 @@ def to_stim_flow( @property def obs_name(self) -> Any: + """The name of the observable that the flow is mapping. + + If the flow is not acting on a logical operator, this returns None. + + Examples: + >>> import stimflow as sf + >>> sf.Flow(start=sf.PauliMap.from_xs([0], obs_name='test')).obs_name + 'test' + >>> sf.Flow(end=sf.PauliMap.from_xs([0], obs_name='rest')).obs_name + 'rest' + >>> sf.Flow(start=sf.PauliMap.from_xs([0])).obs_name is None + True + """ return self.start.obs_name def with_edits( @@ -197,6 +219,18 @@ def with_edits( OR The edits produced an invalid flow (stimflow.Flow.__init__ raised an error). + + + Examples: + >>> import stimflow as sf + >>> flow = sf.Flow(start=sf.PauliMap.from_xs([0]), measurement_indices=[1]) + >>> flow.with_edits(end=sf.PauliMap.from_xs([1j])) + stimflow.Flow( + start=stimflow.PauliMap({0j: 'X'}), + end=stimflow.PauliMap({1j: 'X'}), + measurement_indices=(1,), + center=0j, + ) """ if start is not _UNSPECIFIED and obs_name is not _UNSPECIFIED and start.obs_name != obs_name: raise ValueError(f"Specified contradictory observable names in `start` and `obs_name`.\n {start.obs_name=}\n {obs_name=}") diff --git a/glue/stimflow/src/stimflow/_core/_tile.py b/glue/stimflow/src/stimflow/_core/_tile.py index 9bc27a20c..6f84d6aab 100644 --- a/glue/stimflow/src/stimflow/_core/_tile.py +++ b/glue/stimflow/src/stimflow/_core/_tile.py @@ -163,16 +163,33 @@ def __len__(self) -> int: @functools.cached_property def data_set(self) -> frozenset[complex]: + """Returns the set of data qubits used by the Tile.""" return frozenset(e for e in self.data_qubits if e is not None) @functools.cached_property def used_set(self) -> frozenset[complex]: + """Returns the set of data and/or measure qubits used by the Tile.""" if self.measure_qubit is None: return self.data_set return self.data_set | frozenset([self.measure_qubit]) @functools.cached_property def basis(self) -> Literal["X", "Y", "Z"] | None: + """Returns the basis of the stabilizer, assuming it has exactly one. + + Returns: + If all data qubits have the same basis, returns that basis. + Otherwise, returns None. + + Examples: + >>> import stimflow as sf + >>> sf.Tile(bases="X", data_qubits=[0, 1, 1j]).basis + 'X' + >>> sf.Tile(bases="ZZZ", data_qubits=[0, 1, 1j]).basis + 'Z' + >>> sf.Tile(bases="XYZ", data_qubits=[0, 1, 1j]).basis is None + True + """ bs: set[Literal["X", "Y", "Z"]] bs = cast(Any, {b for q, b in zip(self.data_qubits, self.bases) if q is not None}) if len(bs) == 0: diff --git a/glue/stimflow/tools/gen_api_reference.py b/glue/stimflow/tools/gen_api_reference.py index f0883930a..75daa04d5 100755 --- a/glue/stimflow/tools/gen_api_reference.py +++ b/glue/stimflow/tools/gen_api_reference.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import dataclasses +import functools import inspect import sys import types @@ -145,11 +146,13 @@ def splay_signature(sig: str) -> list[str]: def _handle_pybind_method( - *, obj: Any, is_property: bool, out_obj: DescribedObject, parent: Any, full_name: str + *, obj: Any, is_property: bool, is_cached_property: bool, out_obj: DescribedObject, parent: Any, full_name: str ) -> tuple[str, bool, str, str]: doc = normalize_doc_string(getattr(obj, "__doc__", "") or "") if is_property: out_obj.lines.append("@property") + if is_cached_property: + out_obj.lines.append("@functools.cached_property") doc_lines = doc.splitlines() new_args_name = None was_args = False @@ -163,6 +166,14 @@ def _handle_pybind_method( doc_lines_left.append(line) was_args = "Args:" in line + if not doc_lines_left: + doc_lines_left.append("") + print(f"WARNING: no documentation for {full_name}", file=sys.stderr) + + if is_cached_property: + return_type = obj.func.__annotations__['return'] + return f"def {term_name}(self) -> {return_type}:", False, "\n".join(doc_lines_left).lstrip(), "" + if is_property: sig_name = f"{term_name}(self)" if getattr(obj, "fset", None) is not None: @@ -206,6 +217,7 @@ def print_doc(*, full_name: str, parent: object, obj: object, level: int) -> Des term_name = full_name.split(".")[-1] is_property = isinstance(obj, property) + is_cached_property = isinstance(obj, functools.cached_property) is_method = doc.startswith(term_name) has_setter = False is_normal_method = isinstance(obj, types.FunctionType) @@ -228,9 +240,9 @@ def print_doc(*, full_name: str, parent: object, obj: object, level: int) -> Des text = text.replace(",,", ",") text = text.replace("numpy.", "np.") - elif is_method or is_property: + elif is_method or is_property or is_cached_property: text, has_setter, doc, sig_name = _handle_pybind_method( - obj=obj, is_property=is_property, out_obj=out_obj, parent=parent, full_name=full_name + obj=obj, is_property=is_property, is_cached_property=is_cached_property, out_obj=out_obj, parent=parent, full_name=full_name ) elif isinstance(obj, (int, str)): text = f"{term_name}: {type(obj).__name__} = {obj!r}" @@ -339,6 +351,7 @@ def main(): import io import pathlib import numpy as np +import functools ``` """.strip() )