diff --git a/cirq-core/cirq/sim/state_vector_simulation_state.py b/cirq-core/cirq/sim/state_vector_simulation_state.py index 0e611a4c415..efd902755fc 100644 --- a/cirq-core/cirq/sim/state_vector_simulation_state.py +++ b/cirq-core/cirq/sim/state_vector_simulation_state.py @@ -33,7 +33,12 @@ class _BufferedStateVector(qis.QuantumStateRepresentation): """Contains the state vector and buffer for efficient state evolution.""" - def __init__(self, state_vector: np.ndarray, buffer: np.ndarray | None = None): + def __init__( + self, + state_vector: np.ndarray, + buffer: np.ndarray | None = None, + should_preserve_initial_state: bool = True, + ): """Initializes the object with the inputs. This initializer creates the buffer if necessary. @@ -41,14 +46,28 @@ def __init__(self, state_vector: np.ndarray, buffer: np.ndarray | None = None): Args: state_vector: The state vector, must be correctly formatted. The data is not checked for validity here due to performance concerns. - buffer: Optional, must be same shape as the state vector. If not provided, a buffer - will be created automatically. + buffer: Optional, must be same shape as the state vector. If not provided and + should_preserve_initial_state is True, a buffer will be created automatically. + If should_preserve_initial_state is False, the buffer is allocated lazily on + first use. + should_preserve_initial_state: If False, skips eager buffer allocation to avoid + the memory overhead of an extra state-vector-sized array. """ self._state_vector = state_vector - if buffer is None: - buffer = np.empty_like(state_vector) - self._buffer = buffer self._qid_shape = state_vector.shape + self._raw_buffer: np.ndarray | None = buffer + if self._raw_buffer is None and should_preserve_initial_state: + self._raw_buffer = np.empty_like(state_vector) + + @property + def _buffer(self) -> np.ndarray: + if self._raw_buffer is None: + self._raw_buffer = np.empty_like(self._state_vector) + return self._raw_buffer + + @_buffer.setter + def _buffer(self, value: np.ndarray) -> None: + self._raw_buffer = value @classmethod def create( @@ -58,6 +77,7 @@ def create( qid_shape: tuple[int, ...] | None = None, dtype: type[np.complexfloating] | np.dtype[np.complexfloating] | None = None, buffer: np.ndarray | None = None, + should_preserve_initial_state: bool = True, ): """Initializes the object with the inputs. @@ -70,6 +90,8 @@ def create( dtype: The dtype of the state vector, if the initial state is provided as an int. buffer: Optional, must be length 3 and same shape as the state vector. If not provided, a buffer will be created automatically. + should_preserve_initial_state: If False, skips copying the initial state and defers + buffer allocation to first use, reducing memory overhead for large state vectors. Raises: ValueError: If initial state is provided as integer, but qid_shape is not provided. """ @@ -84,10 +106,10 @@ def create( state_vector = initial_state.reshape(qid_shape) else: state_vector = initial_state - if np.may_share_memory(state_vector, initial_state): + if should_preserve_initial_state and np.may_share_memory(state_vector, initial_state): state_vector = state_vector.copy() state_vector = state_vector.astype(dtype, copy=False) - return cls(state_vector, buffer) + return cls(state_vector, buffer, should_preserve_initial_state) def copy(self, deep_copy_buffers: bool = True) -> _BufferedStateVector: """Copies the object. @@ -97,9 +119,16 @@ def copy(self, deep_copy_buffers: bool = True) -> _BufferedStateVector: Returns: A copy of the object. """ + preserve = self._raw_buffer is not None + buf: np.ndarray | None + if self._raw_buffer is not None and deep_copy_buffers: + buf = self._raw_buffer.copy() + else: + buf = self._raw_buffer return _BufferedStateVector( state_vector=self._state_vector.copy(), - buffer=self._buffer.copy() if deep_copy_buffers else self._buffer, + buffer=buf, + should_preserve_initial_state=preserve, ) def kron(self, other: _BufferedStateVector) -> _BufferedStateVector: @@ -327,6 +356,7 @@ def __init__( initial_state: np.ndarray | cirq.STATE_VECTOR_LIKE = 0, dtype: type[np.complexfloating] | np.dtype[np.complexfloating] = np.complex64, classical_data: cirq.ClassicalDataStore | None = None, + should_preserve_initial_state: bool = True, ): """Inits StateVectorSimulationState. @@ -354,6 +384,7 @@ def __init__( qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None, dtype=dtype, buffer=available_buffer, + should_preserve_initial_state=should_preserve_initial_state, ) super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) diff --git a/cirq-core/cirq/sim/state_vector_simulation_state_test.py b/cirq-core/cirq/sim/state_vector_simulation_state_test.py index a738aa086c9..4e8c29e3b19 100644 --- a/cirq-core/cirq/sim/state_vector_simulation_state_test.py +++ b/cirq-core/cirq/sim/state_vector_simulation_state_test.py @@ -274,3 +274,89 @@ def test_measured_mixture() -> None: def test_qid_shape_error() -> None: with pytest.raises(ValueError, match="qid_shape must be provided"): cirq.sim.state_vector_simulation_state._BufferedStateVector.create(initial_state=0) + + +def test_should_preserve_initial_state_false_skips_copy() -> None: + state_vector = np.array([1.0, 0.0], dtype=np.complex64) + bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create( + initial_state=state_vector, + qid_shape=(2,), + dtype=np.complex64, + should_preserve_initial_state=False, + ) + assert np.may_share_memory(bsv._state_vector, state_vector) + + +def test_should_preserve_initial_state_true_copies() -> None: + state_vector = np.array([1.0, 0.0], dtype=np.complex64) + bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create( + initial_state=state_vector, + qid_shape=(2,), + dtype=np.complex64, + should_preserve_initial_state=True, + ) + assert not np.may_share_memory(bsv._state_vector, state_vector) + + +def test_should_preserve_initial_state_false_no_eager_buffer() -> None: + bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create( + initial_state=0, qid_shape=(2,), dtype=np.complex64, should_preserve_initial_state=False + ) + assert bsv._raw_buffer is None + + +def test_should_preserve_initial_state_false_lazy_buffer_on_access() -> None: + bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create( + initial_state=0, qid_shape=(2,), dtype=np.complex64, should_preserve_initial_state=False + ) + assert bsv._raw_buffer is None + buf = bsv._buffer + assert bsv._raw_buffer is not None + assert buf.shape == bsv._state_vector.shape + assert buf.dtype == bsv._state_vector.dtype + + +def test_should_preserve_initial_state_false_correct_simulation() -> None: + state_vector = cirq.one_hot(shape=(2, 2), dtype=np.complex64) + args = cirq.StateVectorSimulationState( + qubits=cirq.LineQubit.range(2), + initial_state=state_vector, + dtype=np.complex64, + should_preserve_initial_state=False, + ) + assert args._state._raw_buffer is None + cirq.act_on(cirq.X, args, [cirq.LineQubit(0)]) + assert args._state._raw_buffer is not None + np.testing.assert_allclose( + args.target_tensor, cirq.one_hot(index=(1, 0), shape=(2, 2), dtype=np.complex64) + ) + + +def test_copy_with_unallocated_buffer_preserves_lazy_state() -> None: + bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create( + initial_state=0, qid_shape=(2,), dtype=np.complex64, should_preserve_initial_state=False + ) + copy = bsv.copy() + assert copy._raw_buffer is None + assert not np.may_share_memory(copy._state_vector, bsv._state_vector) + np.testing.assert_array_equal(copy._state_vector, bsv._state_vector) + + +def test_shallow_copy_with_unallocated_buffer() -> None: + bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create( + initial_state=0, qid_shape=(2,), dtype=np.complex64, should_preserve_initial_state=False + ) + copy = bsv.copy(deep_copy_buffers=False) + assert copy._raw_buffer is None + assert not np.may_share_memory(copy._state_vector, bsv._state_vector) + np.testing.assert_array_equal(copy._state_vector, bsv._state_vector) + + +def test_deep_copy_with_allocated_buffer() -> None: + bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create( + initial_state=0, qid_shape=(2,), dtype=np.complex64, should_preserve_initial_state=True + ) + copy = bsv.copy(deep_copy_buffers=True) + assert copy._raw_buffer is not None + assert not np.may_share_memory(copy._raw_buffer, bsv._raw_buffer) + np.testing.assert_array_equal(copy._raw_buffer, bsv._raw_buffer) diff --git a/cirq-core/cirq/sim/state_vector_simulator.py b/cirq-core/cirq/sim/state_vector_simulator.py index 22bfb2ae083..010d03eb8d5 100644 --- a/cirq-core/cirq/sim/state_vector_simulator.py +++ b/cirq-core/cirq/sim/state_vector_simulator.py @@ -139,12 +139,12 @@ def final_state_vector(self) -> np.ndarray: ) return ret # normalize only if doing so improves the round-off on total probability - ret_norm = ret / norm + ret_norm = ret * (1.0 / norm) round_off_change = abs(np.vdot(ret_norm, ret_norm) - 1) - abs(np.vdot(ret, ret) - 1) result = ret_norm if round_off_change < 0 else ret return result - def state_vector(self, copy: bool = False) -> np.ndarray: + def state_vector(self, copy: bool = False, normalize: bool = True) -> np.ndarray: """Return the state vector at the end of the computation. The state is returned in the computational basis with these basis @@ -172,11 +172,19 @@ def state_vector(self, copy: bool = False) -> np.ndarray: Args: copy: If True, the returned state vector will be a copy of that - stored by the object. This is potentially expensive for large - state vectors, but prevents mutation of the object state, e.g. for - operating on intermediate states of a circuit. - Defaults to False. + stored by the object. This is potentially expensive for large + state vectors, but prevents mutation of the object state, e.g. + for operating on intermediate states of a circuit. + Defaults to False. + normalize: If True, the returned state vector is normalized and + the result is cached. If False, the raw target tensor is + returned directly, avoiding the extra allocation in + `final_state_vector`. + Defaults to True. """ + if not normalize: + ret = self._get_merged_sim_state().target_tensor.reshape(-1) + return ret.copy() if copy else ret return self.final_state_vector.copy() if copy else self.final_state_vector def _value_equality_values_(self):