Skip to content

Commit

Permalink
Remove transpilations inside subexperiment generation (#556)
Browse files Browse the repository at this point in the history
* No coverage for inplace

* black

* Convert transpiler tests

* comments

* Add remaining tests

* Ensure bumping lint version fixes CI

* Bump python lint version back down

* peaceiris

* restore peaceiris version

* No parallel sphinx

* Revert tox change

* Don't use reno 4

* revert pyproject

* Don't copy list. Remove inplace

* release note

* Update docstring

* Add zero state test

* Update circuit_knitting/cutting/cutting_experiments.py

Co-authored-by: Jim Garrison <[email protected]>

* Update circuit_knitting/cutting/cutting_experiments.py

Co-authored-by: Jim Garrison <[email protected]>

* Use QC.copy

* Update releasenotes/notes/subexperiment-gen-speedup-41a4e8679353d1d9.yaml

Co-authored-by: Jim Garrison <[email protected]>

* remove deep copies

* type hints

---------

Co-authored-by: Jim Garrison <[email protected]>
  • Loading branch information
caleb-johnson and garrison authored Apr 24, 2024
1 parent 69adfc5 commit f3fd581
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 25 deletions.
110 changes: 86 additions & 24 deletions circuit_knitting/cutting/cutting_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@
import numpy as np
from qiskit.circuit import QuantumCircuit, ClassicalRegister
from qiskit.quantum_info import PauliList
from qiskit.transpiler import PassManager
from qiskit.transpiler.passes import RemoveResetInZeroState, DAGFixedPoint
from qiskit.passmanager.flow_controllers import DoWhileController

from ..utils.iteration import strict_zip
from ..utils.transpiler_passes import RemoveFinalReset, ConsolidateResets
from ..utils.observable_grouping import ObservableCollection, CommutingObservableGroup
from .qpd import (
WeightType,
Expand Down Expand Up @@ -62,12 +58,6 @@ def generate_cutting_experiments(
The coefficients will always be returned as a 1D array -- one coefficient for each unique sample.
Note that this function also runs some transpiler passes on each generated
circuit, namely :class:`~qiskit.transpiler.passes.RemoveResetInZeroState`,
:class:`.RemoveFinalReset`, and :class:`.ConsolidateResets`, in order to
remove unnecessary :class:`~qiskit.circuit.library.Reset`\ s from the
circuit that are added by the subexperiment decompositions for cut wires.
Args:
circuits: The circuit(s) to partition and separate
observables: The observable(s) to evaluate for each unique sample
Expand Down Expand Up @@ -172,20 +162,11 @@ def generate_cutting_experiments(
# https://github.com/Qiskit-Extensions/circuit-knitting-toolbox/issues/452.
# While we are at it, we also consolidate each run of multiple resets
# (which can arise when re-using qubits) into a single reset.
pass_manager = PassManager()
passes = [
RemoveResetInZeroState(),
RemoveFinalReset(),
ConsolidateResets(),
DAGFixedPoint(),
]
pass_manager.append(
DoWhileController(
passes, do_while=lambda property_set: not property_set["dag_fixed_point"]
)
)
for label, subexperiments in subexperiments_dict.items():
subexperiments_dict[label] = pass_manager.run(subexperiments)
for subexperiments in subexperiments_dict.values():
for circ in subexperiments:
_remove_resets_in_zero_state(circ)
_remove_final_resets(circ)
_consolidate_resets(circ)

# If the input was a single quantum circuit, return the subexperiments as a list
subexperiments_out: list[QuantumCircuit] | dict[Hashable, list[QuantumCircuit]] = (
Expand Down Expand Up @@ -389,3 +370,84 @@ def _get_pauli_indices(cog: CommutingObservableGroup) -> list[int]:
if not pauli_indices:
pauli_indices = [0]
return pauli_indices


def _consolidate_resets(
circuit: QuantumCircuit, inplace: bool = True
) -> QuantumCircuit:
"""Consolidate redundant resets into a single reset."""
if not inplace: # pragma: no cover
circuit = circuit.copy()

# Keep up with whether the previous instruction on a given qubit was a reset
resets = [False] * circuit.num_qubits

# Remove resets which are immediately following other resets
remove_ids = []
for i, inst in enumerate(circuit.data):
qargs = [circuit.find_bit(q).index for q in inst.qubits]
if inst.operation.name == "reset":
if resets[qargs[0]]:
remove_ids.append(i)
else:
resets[qargs[0]] = True
else:
for q in qargs:
resets[q] = False

for i in sorted(remove_ids, reverse=True):
del circuit.data[i]

return circuit


def _remove_resets_in_zero_state(
circuit: QuantumCircuit, inplace: bool = True
) -> QuantumCircuit:
"""Remove resets if they are the first instruction on a qubit."""
if not inplace: # pragma: no cover
circuit = circuit.copy()

# Keep up with which qubits have at least one non-reset instruction
active_qubits = set()
remove_ids = []
for i, inst in enumerate(circuit.data):
qargs = [circuit.find_bit(q).index for q in inst.qubits]
if inst.operation.name == "reset":
if qargs[0] not in active_qubits:
remove_ids.append(i)
else:
for q in qargs:
active_qubits.add(q)

for i in sorted(remove_ids, reverse=True):
del circuit.data[i]

return circuit


def _remove_final_resets(
circuit: QuantumCircuit, inplace: bool = True
) -> QuantumCircuit:
"""Remove resets if they are the final instruction on a qubit."""
if not inplace: # pragma: no cover
circuit = circuit.copy()

# Keep up with whether we are at the end of a qubit
# We iterate in reverse, so all qubits begin in the "end" state
qubit_ended = set(range(circuit.num_qubits))
remove_ids = []
num_inst = len(circuit.data)
for i, inst in enumerate(reversed(circuit.data)):
qargs = [circuit.find_bit(q).index for q in inst.qubits]
if inst.operation.name == "reset":
if qargs[0] in qubit_ended:
remove_ids.append(num_inst - 1 - i)
else:
for q in qargs:
qubit_ended.discard(q)

for i in sorted(remove_ids, reverse=True):
del circuit.data[i]

return circuit
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
The :func:`.generate_cutting_experiments` function has been optimized for faster execution.
106 changes: 105 additions & 1 deletion test/cutting/test_cutting_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
import numpy as np
from qiskit.quantum_info import PauliList, Pauli
from qiskit.circuit import QuantumCircuit
from qiskit.circuit import QuantumCircuit, QuantumRegister
from qiskit.circuit.library.standard_gates import CXGate

from circuit_knitting.cutting.qpd import (
Expand All @@ -30,6 +30,9 @@
from circuit_knitting.cutting.cutting_experiments import (
_append_measurement_register,
_append_measurement_circuit,
_remove_final_resets,
_consolidate_resets,
_remove_resets_in_zero_state,
)


Expand Down Expand Up @@ -219,3 +222,104 @@ def test_append_measurement_circuit(self):
e_info.value.args[0]
== "Quantum circuit qubit count (2) does not match qubit count of observable(s) (1). Try providing `qubit_locations` explicitly."
)

def test_consolidate_double_reset(self):
"""Consolidate a pair of resets.
qr0:--|0>--|0>-- ==> qr0:--|0>--
"""
qr = QuantumRegister(1, "qr")
circuit = QuantumCircuit(qr)
circuit.reset(qr)
circuit.reset(qr)

expected = QuantumCircuit(qr)
expected.reset(qr)

_consolidate_resets(circuit)

self.assertEqual(expected, circuit)

def test_two_resets(self):
"""Remove two final resets
qr0:--[H]-|0>-|0>-- ==> qr0:--[H]--
"""
qr = QuantumRegister(1, "qr")
circuit = QuantumCircuit(qr)
circuit.h(qr[0])
circuit.reset(qr[0])
circuit.reset(qr[0])

expected = QuantumCircuit(qr)
expected.h(qr[0])

_remove_final_resets(circuit)

self.assertEqual(expected, circuit)

def test_optimize_single_reset_in_diff_qubits(self):
"""Remove a single final reset in different qubits
qr0:--[H]--|0>-- qr0:--[H]--
==>
qr1:--[X]--|0>-- qr1:--[X]----
"""
qr = QuantumRegister(2, "qr")
circuit = QuantumCircuit(qr)
circuit.h(0)
circuit.x(1)
circuit.reset(qr)

expected = QuantumCircuit(qr)
expected.h(0)
expected.x(1)

_remove_final_resets(circuit)
self.assertEqual(expected, circuit)

def test_optimize_single_reset(self):
"""Remove a single final reset
qr0:--[H]--|0>-- ==> qr0:--[H]--
"""
qr = QuantumRegister(1, "qr")
circuit = QuantumCircuit(qr)
circuit.h(0)
circuit.reset(qr)

expected = QuantumCircuit(qr)
expected.h(0)

_remove_final_resets(circuit)

self.assertEqual(expected, circuit)

def test_dont_optimize_non_final_reset(self):
"""Do not remove reset if not final instruction
qr0:--|0>--[H]-- ==> qr0:--|0>--[H]--
"""
qr = QuantumRegister(1, "qr")
circuit = QuantumCircuit(qr)
circuit.reset(qr)
circuit.h(qr)

expected = QuantumCircuit(qr)
expected.reset(qr)
expected.h(qr)

_remove_final_resets(circuit)

self.assertEqual(expected, circuit)

def test_remove_reset_in_zero_state(self):
"""Remove reset if first instruction on qubit
qr0:--|0>--[H]-- ==> qr0:--|0>--[H]--
"""
qr = QuantumRegister(1, "qr")
circuit = QuantumCircuit(qr)
circuit.reset(qr)
circuit.h(qr)

expected = QuantumCircuit(qr)
expected.h(qr)

_remove_resets_in_zero_state(circuit)

self.assertEqual(expected, circuit)

0 comments on commit f3fd581

Please sign in to comment.