Skip to content

Commit

Permalink
Run product state functions inplace to avoid copies where possible (#…
Browse files Browse the repository at this point in the history
…6396)

* Run product state merges inplace to avoid copies

* rename to merged_state

* comment
  • Loading branch information
daxfohl authored Feb 3, 2024
1 parent e9e12ee commit 5dd05bf
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions cirq-core/cirq/sim/simulation_product_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,19 @@ def split_untangled_states(self) -> bool:
return self._split_untangled_states

def create_merged_state(self) -> TSimulationState:
merged_state = self.sim_states[None]
if not self.split_untangled_states:
return self.sim_states[None]
final_args = self.sim_states[None]
for args in set([self.sim_states[k] for k in self.sim_states.keys() if k is not None]):
final_args = final_args.kronecker_product(args)
return final_args.transpose_to_qubit_order(self.qubits)
return merged_state
extra_states = set([self.sim_states[k] for k in self.sim_states.keys() if k is not None])
if not extra_states:
return merged_state

# This comes from a member variable so we need to copy it if we're going to modify inplace
# before returning. We're not running a step currently, so no need to copy buffers.
merged_state = merged_state.copy(deep_copy_buffers=False)
for state in extra_states:
merged_state.kronecker_product(state, inplace=True)
return merged_state.transpose_to_qubit_order(self.qubits, inplace=True)

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
Expand Down Expand Up @@ -106,7 +113,7 @@ def _act_on_fallback_(
if op_args_opt is None:
op_args_opt = self.sim_states[q]
elif q not in op_args_opt.qubits:
op_args_opt = op_args_opt.kronecker_product(self.sim_states[q])
op_args_opt.kronecker_product(self.sim_states[q], inplace=True)
op_args = op_args_opt or self.sim_states[None]

# (Backfill the args map with the new value)
Expand All @@ -123,7 +130,7 @@ def _act_on_fallback_(
):
for q in qubits:
if op_args.allows_factoring and len(op_args.qubits) > 1:
q_args, op_args = op_args.factor((q,), validate=False)
q_args, _ = op_args.factor((q,), validate=False, inplace=True)
self._sim_states[q] = q_args

# (Backfill the args map with the new value)
Expand Down

0 comments on commit 5dd05bf

Please sign in to comment.