Skip to content
49 changes: 40 additions & 9 deletions cirq-core/cirq/sim/state_vector_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,41 @@
class _BufferedStateVector(qis.QuantumStateRepresentation):
"""Contains the state vector and buffer for efficient state evolution."""

def __init__(self, state_vector: np.ndarray, buffer: np.ndarray | None = None):
def __init__(
self,
state_vector: np.ndarray,
buffer: np.ndarray | None = None,
should_preserve_initial_state: bool = True,
):
"""Initializes the object with the inputs.

This initializer creates the buffer if necessary.

Args:
state_vector: The state vector, must be correctly formatted. The data is not checked
for validity here due to performance concerns.
buffer: Optional, must be same shape as the state vector. If not provided, a buffer
will be created automatically.
buffer: Optional, must be same shape as the state vector. If not provided and
should_preserve_initial_state is True, a buffer will be created automatically.
If should_preserve_initial_state is False, the buffer is allocated lazily on
first use.
should_preserve_initial_state: If False, skips eager buffer allocation to avoid
the memory overhead of an extra state-vector-sized array.
"""
self._state_vector = state_vector
if buffer is None:
buffer = np.empty_like(state_vector)
self._buffer = buffer
self._qid_shape = state_vector.shape
self._raw_buffer: np.ndarray | None = buffer
if self._raw_buffer is None and should_preserve_initial_state:
self._raw_buffer = np.empty_like(state_vector)

@property
def _buffer(self) -> np.ndarray:
if self._raw_buffer is None:
self._raw_buffer = np.empty_like(self._state_vector)
return self._raw_buffer

@_buffer.setter
def _buffer(self, value: np.ndarray) -> None:
self._raw_buffer = value

@classmethod
def create(
Expand All @@ -58,6 +77,7 @@ def create(
qid_shape: tuple[int, ...] | None = None,
dtype: type[np.complexfloating] | np.dtype[np.complexfloating] | None = None,
buffer: np.ndarray | None = None,
should_preserve_initial_state: bool = True,
):
"""Initializes the object with the inputs.

Expand All @@ -70,6 +90,8 @@ def create(
dtype: The dtype of the state vector, if the initial state is provided as an int.
buffer: Optional, must be length 3 and same shape as the state vector. If not
provided, a buffer will be created automatically.
should_preserve_initial_state: If False, skips copying the initial state and defers
buffer allocation to first use, reducing memory overhead for large state vectors.
Raises:
ValueError: If initial state is provided as integer, but qid_shape is not provided.
"""
Expand All @@ -84,10 +106,10 @@ def create(
state_vector = initial_state.reshape(qid_shape)
else:
state_vector = initial_state
if np.may_share_memory(state_vector, initial_state):
if should_preserve_initial_state and np.may_share_memory(state_vector, initial_state):
state_vector = state_vector.copy()
state_vector = state_vector.astype(dtype, copy=False)
return cls(state_vector, buffer)
return cls(state_vector, buffer, should_preserve_initial_state)

def copy(self, deep_copy_buffers: bool = True) -> _BufferedStateVector:
"""Copies the object.
Expand All @@ -97,9 +119,16 @@ def copy(self, deep_copy_buffers: bool = True) -> _BufferedStateVector:
Returns:
A copy of the object.
"""
preserve = self._raw_buffer is not None
buf: np.ndarray | None
if self._raw_buffer is not None and deep_copy_buffers:
buf = self._raw_buffer.copy()
else:
buf = self._raw_buffer
return _BufferedStateVector(
state_vector=self._state_vector.copy(),
buffer=self._buffer.copy() if deep_copy_buffers else self._buffer,
buffer=buf,
should_preserve_initial_state=preserve,
)

def kron(self, other: _BufferedStateVector) -> _BufferedStateVector:
Expand Down Expand Up @@ -327,6 +356,7 @@ def __init__(
initial_state: np.ndarray | cirq.STATE_VECTOR_LIKE = 0,
dtype: type[np.complexfloating] | np.dtype[np.complexfloating] = np.complex64,
classical_data: cirq.ClassicalDataStore | None = None,
should_preserve_initial_state: bool = True,
):
"""Inits StateVectorSimulationState.

Expand Down Expand Up @@ -354,6 +384,7 @@ def __init__(
qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None,
dtype=dtype,
buffer=available_buffer,
should_preserve_initial_state=should_preserve_initial_state,
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

Expand Down
86 changes: 86 additions & 0 deletions cirq-core/cirq/sim/state_vector_simulation_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,89 @@ def test_measured_mixture() -> None:
def test_qid_shape_error() -> None:
with pytest.raises(ValueError, match="qid_shape must be provided"):
cirq.sim.state_vector_simulation_state._BufferedStateVector.create(initial_state=0)


def test_should_preserve_initial_state_false_skips_copy() -> None:
state_vector = np.array([1.0, 0.0], dtype=np.complex64)
bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create(
initial_state=state_vector,
qid_shape=(2,),
dtype=np.complex64,
should_preserve_initial_state=False,
)
assert np.may_share_memory(bsv._state_vector, state_vector)


def test_should_preserve_initial_state_true_copies() -> None:
state_vector = np.array([1.0, 0.0], dtype=np.complex64)
bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create(
initial_state=state_vector,
qid_shape=(2,),
dtype=np.complex64,
should_preserve_initial_state=True,
)
assert not np.may_share_memory(bsv._state_vector, state_vector)


def test_should_preserve_initial_state_false_no_eager_buffer() -> None:
bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create(
initial_state=0, qid_shape=(2,), dtype=np.complex64, should_preserve_initial_state=False
)
assert bsv._raw_buffer is None


def test_should_preserve_initial_state_false_lazy_buffer_on_access() -> None:
bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create(
initial_state=0, qid_shape=(2,), dtype=np.complex64, should_preserve_initial_state=False
)
assert bsv._raw_buffer is None
buf = bsv._buffer
assert bsv._raw_buffer is not None
assert buf.shape == bsv._state_vector.shape
assert buf.dtype == bsv._state_vector.dtype


def test_should_preserve_initial_state_false_correct_simulation() -> None:
state_vector = cirq.one_hot(shape=(2, 2), dtype=np.complex64)
args = cirq.StateVectorSimulationState(
qubits=cirq.LineQubit.range(2),
initial_state=state_vector,
dtype=np.complex64,
should_preserve_initial_state=False,
)
assert args._state._raw_buffer is None
cirq.act_on(cirq.X, args, [cirq.LineQubit(0)])
assert args._state._raw_buffer is not None
np.testing.assert_allclose(
args.target_tensor, cirq.one_hot(index=(1, 0), shape=(2, 2), dtype=np.complex64)
)


def test_copy_with_unallocated_buffer_preserves_lazy_state() -> None:
bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create(
initial_state=0, qid_shape=(2,), dtype=np.complex64, should_preserve_initial_state=False
)
copy = bsv.copy()
assert copy._raw_buffer is None
assert not np.may_share_memory(copy._state_vector, bsv._state_vector)
np.testing.assert_array_equal(copy._state_vector, bsv._state_vector)


def test_shallow_copy_with_unallocated_buffer() -> None:
bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create(
initial_state=0, qid_shape=(2,), dtype=np.complex64, should_preserve_initial_state=False
)
copy = bsv.copy(deep_copy_buffers=False)
assert copy._raw_buffer is None
assert not np.may_share_memory(copy._state_vector, bsv._state_vector)
np.testing.assert_array_equal(copy._state_vector, bsv._state_vector)


def test_deep_copy_with_allocated_buffer() -> None:
bsv = cirq.sim.state_vector_simulation_state._BufferedStateVector.create(
initial_state=0, qid_shape=(2,), dtype=np.complex64, should_preserve_initial_state=True
)
copy = bsv.copy(deep_copy_buffers=True)
assert copy._raw_buffer is not None
assert not np.may_share_memory(copy._raw_buffer, bsv._raw_buffer)
np.testing.assert_array_equal(copy._raw_buffer, bsv._raw_buffer)
20 changes: 14 additions & 6 deletions cirq-core/cirq/sim/state_vector_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,12 @@ def final_state_vector(self) -> np.ndarray:
)
return ret
# normalize only if doing so improves the round-off on total probability
ret_norm = ret / norm
ret_norm = ret * (1.0 / norm)
round_off_change = abs(np.vdot(ret_norm, ret_norm) - 1) - abs(np.vdot(ret, ret) - 1)
result = ret_norm if round_off_change < 0 else ret
return result

def state_vector(self, copy: bool = False) -> np.ndarray:
def state_vector(self, copy: bool = False, normalize: bool = True) -> np.ndarray:
"""Return the state vector at the end of the computation.

The state is returned in the computational basis with these basis
Expand Down Expand Up @@ -172,11 +172,19 @@ def state_vector(self, copy: bool = False) -> np.ndarray:

Args:
copy: If True, the returned state vector will be a copy of that
stored by the object. This is potentially expensive for large
state vectors, but prevents mutation of the object state, e.g. for
operating on intermediate states of a circuit.
Defaults to False.
stored by the object. This is potentially expensive for large
state vectors, but prevents mutation of the object state, e.g.
for operating on intermediate states of a circuit.
Defaults to False.
normalize: If True, the returned state vector is normalized and
the result is cached. If False, the raw target tensor is
returned directly, avoiding the extra allocation in
`final_state_vector`.
Defaults to True.
"""
if not normalize:
ret = self._get_merged_sim_state().target_tensor.reshape(-1)
return ret.copy() if copy else ret
return self.final_state_vector.copy() if copy else self.final_state_vector

def _value_equality_values_(self):
Expand Down
Loading