Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option for skipping OptimizeSwapBeforeMeasure when not all the wires are measured. #5890

Closed
wants to merge 14 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ class OptimizeSwapBeforeMeasure(TransformationPass):

Transpiler pass to remove swaps in front of measurements by re-targeting
the classical bit of the measure instruction.

If `all_measurement` is `True` (default is `False`)`, the SWAP to be removed
has to be measure on both wires. Otherwise, it stays.
"""
def __init__(self, all_measurement=False):
self.all_measurement = all_measurement
super().__init__()

def run(self, dag):
"""Run the OptimizeSwapBeforeMeasure pass on `dag`.
Expand All @@ -39,8 +45,10 @@ def run(self, dag):
for swap in swaps[::-1]:
final_successor = []
for successor in dag.successors(swap):
final_successor.append(successor.type == 'out' or (successor.type == 'op' and
successor.op.name == 'measure'))
is_final_successor = successor.type == 'op' and successor.op.name == 'measure'
if not self.all_measurement:
is_final_successor = is_final_successor or successor.type == 'out'
final_successor.append(is_final_successor)
if all(final_successor):
# the node swap needs to be removed and, if a measure follows, needs to be adapted
swap_qargs = swap.qargs
Expand Down
2 changes: 1 addition & 1 deletion qiskit/transpiler/preset_passmanagers/level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _opt_control(property_set):

_reset = [RemoveResetInZeroState()]

_meas = [OptimizeSwapBeforeMeasure(), RemoveDiagonalGatesBeforeMeasure()]
_meas = [OptimizeSwapBeforeMeasure(all_measurement=True), RemoveDiagonalGatesBeforeMeasure()]

_opt = [
Collect2qBlocks(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
fixes:
- |
Fixes #4911. Level 3 now calls
:class:`qiskit.transpiler.passes.optimization.optimize_swap_before_measure.OptimizeSwapBeforeMeasure`
with a special parameter to avoid removing SWAP gates when not followed by measurment on both wires.
80 changes: 80 additions & 0 deletions test/python/transpiler/test_optimize_swap_before_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,86 @@ def test_optimize_overlap_swap(self):

self.assertEqual(expected, after)

def test_all_measurement(self):
"""OptimizeSwapBeforeMeasure(all_measurement=True) on total measurment
qr0:--X-----m--
| |
qr1:--X--m--|--
| |
cr :-----0--1--
"""
qr = QuantumRegister(2, 'qr')
cr = ClassicalRegister(2, 'cr')
circuit = QuantumCircuit(qr, cr)
circuit.swap(qr[0], qr[1])
circuit.measure(qr[1], cr[0])
circuit.measure(qr[0], cr[1])

expected = QuantumCircuit(qr, cr)
expected.measure(qr[0], cr[0])
expected.measure(qr[1], cr[1])

pass_manager = PassManager()
pass_manager.append(
[OptimizeSwapBeforeMeasure(all_measurement=True), DAGFixedPoint()],
do_while=lambda property_set: not property_set['dag_fixed_point'])
after = pass_manager.run(circuit)

self.assertEqual(expected, after)

def test_all_measurement_skip(self):
"""OptimizeSwapBeforeMeasure(all_measurement=True) on no total measurment
qr0:--X-----
|
qr1:--X--m--
|
cr0:-----.--
"""
qr = QuantumRegister(2, 'qr')
cr = ClassicalRegister(1, 'cr')
circuit = QuantumCircuit(qr, cr)
circuit.swap(qr[0], qr[1])
circuit.measure(qr[1], cr[0])

pass_manager = PassManager()
pass_manager.append(
[OptimizeSwapBeforeMeasure(all_measurement=True), DAGFixedPoint()],
do_while=lambda property_set: not property_set['dag_fixed_point'])
after = pass_manager.run(circuit)

self.assertEqual(circuit, after)

def test_all_measurement_mixed(self):
"""OptimizeSwapBeforeMeasure(all_measurement=True) on mixed measurment
qr0:--X-----------
|
qr1:--X--X-----m--
| |
qr2:-----X--m--|--
| |
cr :--------0--1--
"""
qr = QuantumRegister(3, 'qr')
cr = ClassicalRegister(2, 'cr')
circuit = QuantumCircuit(qr, cr)
circuit.swap(qr[0], qr[1])
circuit.swap(qr[1], qr[2])
circuit.measure(qr[2], cr[0])
circuit.measure(qr[1], cr[1])

expected = QuantumCircuit(qr, cr)
expected.swap(qr[0], qr[1])
expected.measure(qr[1], cr[0])
expected.measure(qr[2], cr[1])

pass_manager = PassManager()
pass_manager.append(
[OptimizeSwapBeforeMeasure(all_measurement=True), DAGFixedPoint()],
do_while=lambda property_set: not property_set['dag_fixed_point'])
after = pass_manager.run(circuit)

self.assertEqual(expected, after)


if __name__ == '__main__':
unittest.main()