diff --git a/cirq-core/cirq/sim/simulation_product_state.py b/cirq-core/cirq/sim/simulation_product_state.py index c4421b41ec2..78ec5ceb60d 100644 --- a/cirq-core/cirq/sim/simulation_product_state.py +++ b/cirq-core/cirq/sim/simulation_product_state.py @@ -63,12 +63,19 @@ def split_untangled_states(self) -> bool: return self._split_untangled_states def create_merged_state(self) -> TSimulationState: + merged_state = self.sim_states[None] if not self.split_untangled_states: - return self.sim_states[None] - final_args = self.sim_states[None] - for args in set([self.sim_states[k] for k in self.sim_states.keys() if k is not None]): - final_args = final_args.kronecker_product(args) - return final_args.transpose_to_qubit_order(self.qubits) + return merged_state + extra_states = set([self.sim_states[k] for k in self.sim_states.keys() if k is not None]) + if not extra_states: + return merged_state + + # This comes from a member variable so we need to copy it if we're going to modify inplace + # before returning. We're not running a step currently, so no need to copy buffers. + merged_state = merged_state.copy(deep_copy_buffers=False) + for state in extra_states: + merged_state.kronecker_product(state, inplace=True) + return merged_state.transpose_to_qubit_order(self.qubits, inplace=True) def _act_on_fallback_( self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True @@ -106,7 +113,7 @@ def _act_on_fallback_( if op_args_opt is None: op_args_opt = self.sim_states[q] elif q not in op_args_opt.qubits: - op_args_opt = op_args_opt.kronecker_product(self.sim_states[q]) + op_args_opt.kronecker_product(self.sim_states[q], inplace=True) op_args = op_args_opt or self.sim_states[None] # (Backfill the args map with the new value) @@ -123,7 +130,7 @@ def _act_on_fallback_( ): for q in qubits: if op_args.allows_factoring and len(op_args.qubits) > 1: - q_args, op_args = op_args.factor((q,), validate=False) + q_args, _ = op_args.factor((q,), validate=False, inplace=True) self._sim_states[q] = q_args # (Backfill the args map with the new value)