diff --git a/changes.d/935.removal b/changes.d/935.removal new file mode 100644 index 00000000..84549094 --- /dev/null +++ b/changes.d/935.removal @@ -0,0 +1,2 @@ +The keyword argument `to_single_waveform` of `PulseTemplate.create_waveform` no longer accepts pulse template objects, only identifiers. +Use `PulseTemplate.metadata.to_single_waveform` to enforce single waveform translation. diff --git a/qupulse/program/protocol.py b/qupulse/program/protocol.py index a7656a8e..93e0c51d 100644 --- a/qupulse/program/protocol.py +++ b/qupulse/program/protocol.py @@ -1,4 +1,5 @@ """Definition of the program builder protocol.""" +import contextlib import copy import dataclasses from abc import abstractmethod, ABC @@ -63,7 +64,7 @@ def apply_mappings(self, @dataclasses.dataclass(frozen=True) class BuildSettings: """This dataclass bundles the immutable settings.""" - to_single_waveform: AbstractSet[str | object] + to_single_waveform: AbstractSet[str] @runtime_checkable @@ -122,7 +123,7 @@ def with_transformation(self, transformation: Transformation) -> ContextManager[ """Modify the build context for the duration of the context manager.""" @abstractmethod - def with_metadata(self, metadata: TemplateMetadata) -> ContextManager['ProgramBuilder']: + def with_metadata(self, metadata: TemplateMetadata, identifier: str | None) -> ContextManager['ProgramBuilder']: """Modify the build context for the duration of the context manager.""" @abstractmethod @@ -272,13 +273,19 @@ def _with_patched_context(self, **kwargs): self._build_context_stack.pop() @contextmanager - def with_metadata(self, metadata: TemplateMetadata): - # metadata.to_single_waveform == "always" is handled in PulseTemplate._build_program + def with_metadata(self, metadata: TemplateMetadata, identifier: str | None): + stack = contextlib.ExitStack() + + builder = self + if metadata.minimal_sample_rate is not None: - with self._with_patched_context(minimal_sample_rate=metadata.minimal_sample_rate) as builder: - yield builder - else: - yield self + builder = stack.enter_context(builder._with_patched_context(minimal_sample_rate=metadata.minimal_sample_rate)) + + if metadata.to_single_waveform == "always" or identifier in self.build_settings.to_single_waveform: + builder = stack.enter_context(builder.new_subprogram()) + + with stack: + yield builder @contextmanager def with_transformation(self, transformation: Transformation): diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index 9c2b0bc3..9169b8d0 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -163,7 +163,7 @@ def create_program(self, *, measurement_mapping: Optional[Mapping[str, Optional[str]]]=None, channel_mapping: Optional[Mapping[ChannelID, Optional[ChannelID]]]=None, global_transformation: Optional[Transformation]=None, - to_single_waveform: Set[Union[str, 'PulseTemplate']]=None, + to_single_waveform: Set[str]=None, volatile: Union[Set[str], str] = None, program_builder: ProgramBuilder = None) -> Optional[Program]: """Translates this PulseTemplate into a program Loop. @@ -177,8 +177,7 @@ def create_program(self, *, measurement_mapping: A mapping of measurement window names. Windows that are mapped to None are omitted. channel_mapping: A mapping of channel names. Channels that are mapped to None are omitted. global_transformation: This transformation is applied to every waveform - to_single_waveform: A set of pulse templates (or identifiers) which are directly translated to a - waveform. This might change how transformations are applied. TODO: clarify + to_single_waveform: A set of identifiers which are directly translated to a waveform. volatile: Everything in the final program that depends on these parameters is marked as volatile program_builder: This program builder is used to build the return value. If `None` `default_program_builder` is used. @@ -193,6 +192,11 @@ def create_program(self, *, channel_mapping = dict() if to_single_waveform is None: to_single_waveform = set() + elif not all(isinstance(elem, str) for elem in to_single_waveform): + non_str_types = {type(elem).__name__ for elem in to_single_waveform if not isinstance(elem, str)} + warnings.warn(f"Non str to_single_waveform members are ignored: {non_str_types}", + category=NonStrToSingleWaveformWarning, stacklevel=2) + to_single_waveform = {elem for elem in to_single_waveform if isinstance(elem, str)} if volatile is None: volatile = set() elif isinstance(volatile, str): @@ -254,12 +258,8 @@ def _build_program(self, program_builder: ProgramBuilder): if (validate_scope := getattr(self, "validate_scope", None)) is not None: validate_scope(program_builder.build_context.scope) - to_single_waveform = program_builder.build_settings.to_single_waveform - if self.metadata.to_single_waveform == 'always' or self.identifier in to_single_waveform or self in to_single_waveform: - with program_builder.new_subprogram() as inner_program_builder: - self._internal_build_program(inner_program_builder) - else: - self._internal_build_program(program_builder) + with program_builder.with_metadata(self.metadata, self.identifier) as inner_program_builder: + self._internal_build_program(inner_program_builder) def _internal_build_program(self, program_builder: ProgramBuilder): """The subclass specific implementation of create_program().""" @@ -783,3 +783,6 @@ class UnknownVolatileParameter(RuntimeWarning): class MetadataComparison(RuntimeWarning): pass + +class NonStrToSingleWaveformWarning(RuntimeWarning): + pass diff --git a/tests/program/linspace_tests.py b/tests/program/linspace_tests.py index deab804a..29f6713d 100644 --- a/tests/program/linspace_tests.py +++ b/tests/program/linspace_tests.py @@ -568,11 +568,11 @@ def test_global_trafo_program(self): def test_local_trafo_program(self): program_builder = LinSpaceBuilder(('a',)) + self.pulse_template.metadata.to_single_waveform = "always" with self.assertRaises(NotImplementedError): # not implemented yet. This test should work as soon as its implemented program = self.pulse_template.create_program(program_builder=program_builder, - global_transformation=self.transformation, - to_single_waveform={self.pulse_template}) + global_transformation=self.transformation) self.assertEqual(self.program, program)