Skip to content

Commit

Permalink
Merge pull request #74 from qiboteam/dmcopy
Browse files Browse the repository at this point in the history
Change `.flatten()` to `.ravel()` to avoid copies
  • Loading branch information
scarrazza authored Mar 16, 2022
2 parents d266b0d + c02e38e commit d3ff74e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
10 changes: 5 additions & 5 deletions src/qibojit/custom_operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,15 @@ def _density_matrix_call(self, gate, state):
qubits = gate.cache.qubits_tensor + gate.nqubits
shape = state.shape
gate_op = self.get_gate_op(gate)
state = gate_op(state.flatten(), 2 * gate.nqubits, gate.target_qubits, qubits)
state = gate_op(state.ravel(), 2 * gate.nqubits, gate.target_qubits, qubits)
state = gate_op(state, 2 * gate.nqubits, gate.cache.target_qubits_dm, gate.cache.qubits_tensor)
return self.reshape(state, shape)

def density_matrix_matrix_call(self, gate, state):
qubits = gate.cache.qubits_tensor + gate.nqubits
shape = state.shape
gate_op = self.get_gate_op(gate)
state = gate_op(state.flatten(), gate.custom_op_matrix, 2 * gate.nqubits, gate.target_qubits, qubits)
state = gate_op(state.ravel(), gate.custom_op_matrix, 2 * gate.nqubits, gate.target_qubits, qubits)
adjmatrix = self.conj(gate.custom_op_matrix)
state = gate_op(state, adjmatrix, 2 * gate.nqubits, gate.cache.target_qubits_dm, gate.cache.qubits_tensor)
return self.reshape(state, shape)
Expand All @@ -287,14 +287,14 @@ def _density_matrix_half_call(self, gate, state):
qubits = gate.cache.qubits_tensor + gate.nqubits
shape = state.shape
gate_op = self.get_gate_op(gate)
state = gate_op(state.flatten(), 2 * gate.nqubits, gate.target_qubits, qubits)
state = gate_op(state.ravel(), 2 * gate.nqubits, gate.target_qubits, qubits)
return self.reshape(state, shape)

def density_matrix_half_matrix_call(self, gate, state):
qubits = gate.cache.qubits_tensor + gate.nqubits
shape = state.shape
gate_op = self.get_gate_op(gate)
state = gate_op(state.flatten(), gate.custom_op_matrix, 2 * gate.nqubits, gate.target_qubits, qubits)
state = gate_op(state.ravel(), gate.custom_op_matrix, 2 * gate.nqubits, gate.target_qubits, qubits)
return self.reshape(state, shape)

def _result_tensor(self, result):
Expand All @@ -309,7 +309,7 @@ def density_matrix_collapse(self, gate, state, result):
result = self._result_tensor(result)
qubits = gate.cache.qubits_tensor + gate.nqubits
shape = state.shape
state = self.collapse_state(state.flatten(), qubits, result, 2 * gate.nqubits, False)
state = self.collapse_state(state.ravel(), qubits, result, 2 * gate.nqubits, False)
state = self.collapse_state(state, gate.cache.qubits_tensor, result, 2 * gate.nqubits, False)
state = self.reshape(state, shape)
return state / self.trace(state)
Expand Down
6 changes: 3 additions & 3 deletions src/qibojit/custom_operators/platforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def one_qubit_base(self, state, nqubits, target, kernel, gate, qubits=None):
if kernel in ("apply_x", "apply_y", "apply_z"):
args = (state, tk, m)
else:
args = (state, tk, m, self.cast(gate, dtype=state.dtype).flatten())
args = (state, tk, m, self.cast(gate, dtype=state.dtype).ravel())

ktype = self.get_kernel_type(state)
if ncontrols:
Expand Down Expand Up @@ -366,7 +366,7 @@ def two_qubit_base(self, state, nqubits, target1, target2, kernel, gate, qubits=
if kernel == "apply_swap":
args = (state, tk1, tk2, m1, m2, uk1, uk2)
else:
args = (state, tk1, tk2, m1, m2, uk1, uk2, self.cast(gate).flatten())
args = (state, tk1, tk2, m1, m2, uk1, uk2, self.cast(gate).ravel())
assert state.dtype == args[-1].dtype

ktype = self.get_kernel_type(state)
Expand All @@ -384,7 +384,7 @@ def two_qubit_base(self, state, nqubits, target1, target2, kernel, gate, qubits=
def multi_qubit_base(self, state, nqubits, targets, gate, qubits=None):
assert gate is not None
state = self.cast(state)
gate = self.cast(gate.flatten())
gate = self.cast(gate.ravel())
assert state.dtype == gate.dtype

ntargets = len(targets)
Expand Down

0 comments on commit d3ff74e

Please sign in to comment.