diff --git a/cirq-core/cirq/sim/density_matrix_simulator_test.py b/cirq-core/cirq/sim/density_matrix_simulator_test.py index 98af33acb0d..1734d019f16 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator_test.py +++ b/cirq-core/cirq/sim/density_matrix_simulator_test.py @@ -16,6 +16,7 @@ import itertools import random +import time from unittest import mock import numpy as np @@ -1585,3 +1586,94 @@ def qubits(self): simulator.simulate_sweep(program=circuit, params=params) assert op1.count == 1 assert op2.count == 2 + + +def _gh5916_regression_circuit() -> cirq.Circuit: + """Circuit from gh-5916 that previously drifted trace under re-simulation.""" + q0, q1 = cirq.NamedQubit('q0'), cirq.NamedQubit('q1') + return cirq.Circuit(cirq.CNOT.on(q1, q0), cirq.H.on(q1), cirq.measure(q1)) + + +def _density_matrix_trace(density_matrix: np.ndarray) -> complex: + size = int(np.prod(density_matrix.shape[: len(density_matrix.shape) // 2])) + return np.trace(density_matrix.reshape(size, size)) + + +def _assert_density_matrix_trace_one(density_matrix: np.ndarray, *, atol: float = 1e-6) -> None: + trace = _density_matrix_trace(density_matrix) + assert np.isclose(trace, 1.0, atol=atol), trace + + +def test_density_matrix_trace_after_each_moment_gh5916() -> None: + """Regression for gh-5916: trace must stay normalized after each moment.""" + circuit = _gh5916_regression_circuit() + sim = cirq.DensityMatrixSimulator(dtype=np.complex128) + for step in sim.simulate_moment_steps(circuit): + _assert_density_matrix_trace_one(step.density_matrix()) + + +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_density_matrix_trace_stable_under_repeated_simulation( + dtype: type[np.complexfloating], +) -> None: + """Regression for gh-5916: reusing final_density_matrix must not drift trace.""" + circuit = _gh5916_regression_circuit() + sim = cirq.DensityMatrixSimulator(dtype=dtype) + initial_state = None + for _ in range(50): + result = sim.simulate(circuit, initial_state=initial_state) + _assert_density_matrix_trace_one(result.final_density_matrix) + initial_state = result.final_density_matrix + + +@pytest.mark.parametrize( + 'circuit', + [ + cirq.Circuit( + cirq.CNOT.on(cirq.LineQubit(1), cirq.LineQubit(0)), + cirq.H(cirq.LineQubit(1)), + cirq.measure(cirq.LineQubit(1)), + ), + cirq.Circuit( + cirq.H(cirq.LineQubit(0)), + cirq.CNOT.on(cirq.LineQubit(0), cirq.LineQubit(1)), + cirq.measure(cirq.LineQubit(0)), + ), + cirq.Circuit( + cirq.H(cirq.LineQubit(0)), + cirq.H(cirq.LineQubit(1)), + cirq.H(cirq.LineQubit(2)), + cirq.CCX(*cirq.LineQubit.range(3)), + cirq.measure(*cirq.LineQubit.range(3)), + ), + ], +) +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_density_matrix_trace_stable_across_circuit_families( + circuit: cirq.Circuit, dtype: type[np.complexfloating] +) -> None: + sim = cirq.DensityMatrixSimulator(dtype=dtype) + initial_state = None + for _ in range(20): + result = sim.simulate(circuit, initial_state=initial_state) + _assert_density_matrix_trace_one(result.final_density_matrix) + for step in sim.simulate_moment_steps(circuit, initial_state=initial_state): + _assert_density_matrix_trace_one(step.density_matrix()) + initial_state = result.final_density_matrix + + +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_density_matrix_trace_repeated_simulation_performance_smoke( + dtype: type[np.complexfloating], +) -> None: + """Smoke test: repeated simulation stays within existing suite runtime scale.""" + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit(cirq.CNOT(q1, q0), cirq.H(q1), cirq.measure(q1)) + sim = cirq.DensityMatrixSimulator(dtype=dtype) + initial_state = None + start = time.perf_counter() + for _ in range(100): + result = sim.simulate(circuit, initial_state=initial_state) + initial_state = result.final_density_matrix + elapsed = time.perf_counter() - start + assert elapsed < 5.0 diff --git a/cirq-core/cirq/sim/density_matrix_utils.py b/cirq-core/cirq/sim/density_matrix_utils.py index 325a69388a4..b9efb1e4a62 100644 --- a/cirq-core/cirq/sim/density_matrix_utils.py +++ b/cirq-core/cirq/sim/density_matrix_utils.py @@ -176,8 +176,9 @@ def measure_density_matrix( # Reshape to a tensor inplace to set the masked values to 0. arrout.reshape(qid_shape * 2, copy=False)[mask] = 0 - # Renormalize. + # Renormalize to the measured outcome, then correct trace drift from float error. arrout /= probs[result] + _renormalize_density_matrix_trace(arrout, qid_shape) return measurement_bits, arrout @@ -192,6 +193,16 @@ def _probs( return simulation_utils.state_probabilities_by_indices(all_probs.real, indices, qid_shape) +def _renormalize_density_matrix_trace( + density_matrix: np.ndarray, qid_shape: tuple[int, ...] +) -> None: + """Rescale ``density_matrix`` in-place so its trace is 1.""" + size = int(np.prod(qid_shape)) + trace = np.trace(density_matrix.reshape(size, size)) + if not np.isclose(trace, 0): + density_matrix /= trace + + def _validate_density_matrix_qid_shape( density_matrix: np.ndarray, qid_shape: tuple[int, ...] ) -> tuple[int, ...]: