From f28f3835fb9e7d92b23a15eacceccf3d59dc8af4 Mon Sep 17 00:00:00 2001 From: ewinston Date: Tue, 21 Jun 2022 13:40:00 -0400 Subject: [PATCH] support control flow in BasisTranslator pass (#7808) * create two tests and 1st modification of unroller * if_else test, parameter test * black * linting * change shallow copy of control flow ops to not copy body * add special copy * debug * add `replace_blocks` method * minor update * clean debug code * linting fix bugs * minor commit * linting * don't recurse on run * linting * don't mutate basis in "_update_basis". * Update qiskit/transpiler/passes/basis/basis_translator.py Co-authored-by: Jake Lishman * apply_translation returns bool * factor out "replace_node" function * linting * singledispatchmethod -> singledispatch for python 3.7 * black * fix indentation bug * linting * black * changed _get_example_gates following @jakelishman suggestion. Co-authored-by: Jake Lishman Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- qiskit/circuit/controlflow/while_loop.py | 2 - qiskit/circuit/quantumcircuit.py | 17 ++ qiskit/dagcircuit/dagcircuit.py | 1 - .../passes/basis/basis_translator.py | 248 +++++++++++------- .../transpiler/test_basis_translator.py | 93 ++++++- 5 files changed, 256 insertions(+), 105 deletions(-) diff --git a/qiskit/circuit/controlflow/while_loop.py b/qiskit/circuit/controlflow/while_loop.py index d920ce565fc8..bc5c30973087 100644 --- a/qiskit/circuit/controlflow/while_loop.py +++ b/qiskit/circuit/controlflow/while_loop.py @@ -97,8 +97,6 @@ def blocks(self): def replace_blocks(self, blocks): (body,) = blocks - if not isinstance(body, QuantumCircuit): - raise CircuitError("WhileLoopOp expects a single QuantumCircuit when setting blocks") return WhileLoopOp(self.condition, body, label=self.label) def c_if(self, classical, val): diff --git a/qiskit/circuit/quantumcircuit.py b/qiskit/circuit/quantumcircuit.py index 628063aadf8f..58adf5921217 100644 --- a/qiskit/circuit/quantumcircuit.py +++ b/qiskit/circuit/quantumcircuit.py @@ -362,6 +362,23 @@ def calibrations(self, calibrations: dict): """ self._calibrations = defaultdict(dict, calibrations) + def has_calibration_for(self, instr_context: Tuple): + """Return True if the circuit has a calibration defined for the instruction context. In this + case, the operation does not need to be translated to the device basis. + """ + instr, qargs, _ = instr_context + if not self.calibrations or instr.name not in self.calibrations: + return False + qubits = tuple(self.qubits.index(qubit) for qubit in qargs) + params = [] + for p in instr.params: + if isinstance(p, ParameterExpression) and not p.parameters: + params.append(float(p)) + else: + params.append(p) + params = tuple(params) + return (qubits, params) in self.calibrations[instr.name] + @property def metadata(self) -> dict: """The user provided metadata associated with the circuit diff --git a/qiskit/dagcircuit/dagcircuit.py b/qiskit/dagcircuit/dagcircuit.py index bec20973f3bf..c31d47f0739e 100644 --- a/qiskit/dagcircuit/dagcircuit.py +++ b/qiskit/dagcircuit/dagcircuit.py @@ -964,7 +964,6 @@ def __eq__(self, other): # Try to convert to float, but in case of unbound ParameterExpressions # a TypeError will be raise, fallback to normal equality in those # cases - try: self_phase = float(self.global_phase) other_phase = float(other.global_phase) diff --git a/qiskit/transpiler/passes/basis/basis_translator.py b/qiskit/transpiler/passes/basis/basis_translator.py index a379c8249ae3..aea47b1dd226 100644 --- a/qiskit/transpiler/passes/basis/basis_translator.py +++ b/qiskit/transpiler/passes/basis/basis_translator.py @@ -17,12 +17,14 @@ from itertools import zip_longest from collections import defaultdict +from functools import singledispatch import retworkx -from qiskit.circuit import Gate, ParameterVector, QuantumRegister -from qiskit.circuit.equivalence import Key +from qiskit.circuit import Gate, ParameterVector, QuantumRegister, ControlFlowOp, QuantumCircuit from qiskit.dagcircuit import DAGCircuit +from qiskit.converters import circuit_to_dag, dag_to_circuit +from qiskit.circuit.equivalence import Key from qiskit.transpiler.basepasses import TransformationPass from qiskit.transpiler.exceptions import TranspilerError @@ -119,36 +121,12 @@ def run(self, dag): if self._target is None: basic_instrs = ["measure", "reset", "barrier", "snapshot", "delay"] target_basis = set(self._target_basis) - source_basis = set() - for node in dag.op_nodes(): - if not dag.has_calibration_for(node): - source_basis.add((node.name, node.op.num_qubits)) + source_basis = set(_extract_basis(dag)) qargs_local_source_basis = {} else: basic_instrs = ["barrier", "snapshot"] - source_basis = set() target_basis = self._target.keys() - set(self._non_global_operations) - qargs_local_source_basis = defaultdict(set) - for node in dag.op_nodes(): - qargs = tuple(qarg_indices[bit] for bit in node.qargs) - if dag.has_calibration_for(node): - continue - # Treat the instruction as on an incomplete basis if the qargs are in the - # qargs_with_non_global_operation dictionary or if any of the qubits in qargs - # are a superset for a non-local operation. For example, if the qargs - # are (0, 1) and that's a global (ie no non-local operations on (0, 1) - # operation but there is a non-local operation on (1,) we need to - # do an extra non-local search for this op to ensure we include any - # single qubit operation for (1,) as valid. This pattern also holds - # true for > 2q ops too (so for 4q operations we need to check for 3q, 2q, - # and 1q operations in the same manner) - if qargs in self._qargs_with_non_global_operation or any( - frozenset(qargs).issuperset(incomplete_qargs) - for incomplete_qargs in self._qargs_with_non_global_operation - ): - qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits)) - else: - source_basis.add((node.name, node.op.num_qubits)) + source_basis, qargs_local_source_basis = self._extract_basis_target(dag, qarg_indices) target_basis = set(target_basis).union(basic_instrs) @@ -225,68 +203,43 @@ def run(self, dag): # Replace source instructions with target translations. replace_start_time = time.time() - for node in dag.op_nodes(): - node_qargs = tuple(qarg_indices[bit] for bit in node.qargs) - qubit_set = frozenset(node_qargs) - - if node.name in target_basis: - continue - if ( - node_qargs in self._qargs_with_non_global_operation - and node.name in self._qargs_with_non_global_operation[node_qargs] - ): - continue - if dag.has_calibration_for(node): - continue - - def replace_node(node, instr_map): - target_params, target_dag = instr_map[node.op.name, node.op.num_qubits] - if len(node.op.params) != len(target_params): - raise TranspilerError( - "Translation num_params not equal to op num_params." - "Op: {} {} Translation: {}\n{}".format( - node.op.params, node.op.name, target_params, target_dag - ) - ) - - if node.op.params: - # Convert target to circ and back to assign_parameters, since - # DAGCircuits won't have a ParameterTable. - from qiskit.converters import dag_to_circuit, circuit_to_dag - - target_circuit = dag_to_circuit(target_dag) - - target_circuit.assign_parameters( - dict(zip_longest(target_params, node.op.params)), inplace=True - ) + def apply_translation(dag): + dag_updated = False + for node in dag.op_nodes(): + node_qargs = tuple(qarg_indices[bit] for bit in node.qargs) + qubit_set = frozenset(node_qargs) + if node.name in target_basis: + if isinstance(node.op, ControlFlowOp): + flow_blocks = [] + for block in node.op.blocks: + dag_block = circuit_to_dag(block) + dag_updated = apply_translation(dag_block) + if dag_updated: + flow_circ_block = dag_to_circuit(dag_block) + else: + flow_circ_block = block + flow_blocks.append(flow_circ_block) + node.op = node.op.replace_blocks(flow_blocks) + continue + if ( + node_qargs in self._qargs_with_non_global_operation + and node.name in self._qargs_with_non_global_operation[node_qargs] + ): + continue - bound_target_dag = circuit_to_dag(target_circuit) - else: - bound_target_dag = target_dag - - if len(bound_target_dag.op_nodes()) == 1 and len( - bound_target_dag.op_nodes()[0].qargs - ) == len(node.qargs): - dag_op = bound_target_dag.op_nodes()[0].op - # dag_op may be the same instance as other ops in the dag, - # so if there is a condition, need to copy - if node.op.condition: - dag_op = dag_op.copy() - dag.substitute_node(node, dag_op, inplace=True) - - if bound_target_dag.global_phase: - dag.global_phase += bound_target_dag.global_phase + if dag.has_calibration_for(node): + continue + if qubit_set in extra_instr_map: + self._replace_node(dag, node, extra_instr_map[qubit_set]) + elif (node.op.name, node.op.num_qubits) in instr_map: + self._replace_node(dag, node, instr_map) else: - dag.substitute_node_with_dag(node, bound_target_dag) - - if qubit_set in extra_instr_map: - replace_node(node, extra_instr_map[qubit_set]) - elif (node.op.name, node.op.num_qubits) in instr_map: - replace_node(node, instr_map) - else: - raise TranspilerError(f"BasisTranslator did not map {node.name}.") + raise TranspilerError(f"BasisTranslator did not map {node.name}.") + dag_updated = True + return dag_updated + apply_translation(dag) replace_end_time = time.time() logger.info( "Basis translation instructions replaced in %.3fs.", @@ -295,6 +248,110 @@ def replace_node(node, instr_map): return dag + def _replace_node(self, dag, node, instr_map): + target_params, target_dag = instr_map[node.op.name, node.op.num_qubits] + if len(node.op.params) != len(target_params): + raise TranspilerError( + "Translation num_params not equal to op num_params." + "Op: {} {} Translation: {}\n{}".format( + node.op.params, node.op.name, target_params, target_dag + ) + ) + + if node.op.params: + # Convert target to circ and back to assign_parameters, since + # DAGCircuits won't have a ParameterTable. + target_circuit = dag_to_circuit(target_dag) + + target_circuit.assign_parameters( + dict(zip_longest(target_params, node.op.params)), inplace=True + ) + + bound_target_dag = circuit_to_dag(target_circuit) + else: + bound_target_dag = target_dag + + if len(bound_target_dag.op_nodes()) == 1 and len( + bound_target_dag.op_nodes()[0].qargs + ) == len(node.qargs): + dag_op = bound_target_dag.op_nodes()[0].op + # dag_op may be the same instance as other ops in the dag, + # so if there is a condition, need to copy + if node.op.condition: + dag_op = dag_op.copy() + dag.substitute_node(node, dag_op, inplace=True) + + if bound_target_dag.global_phase: + dag.global_phase += bound_target_dag.global_phase + else: + dag.substitute_node_with_dag(node, bound_target_dag) + + def _extract_basis_target( + self, dag, qarg_indices, source_basis=None, qargs_local_source_basis=None + ): + if source_basis is None: + source_basis = set() + if qargs_local_source_basis is None: + qargs_local_source_basis = defaultdict(set) + for node in dag.op_nodes(): + qargs = tuple(qarg_indices[bit] for bit in node.qargs) + if dag.has_calibration_for(node): + continue + # Treat the instruction as on an incomplete basis if the qargs are in the + # qargs_with_non_global_operation dictionary or if any of the qubits in qargs + # are a superset for a non-local operation. For example, if the qargs + # are (0, 1) and that's a global (ie no non-local operations on (0, 1) + # operation but there is a non-local operation on (1,) we need to + # do an extra non-local search for this op to ensure we include any + # single qubit operation for (1,) as valid. This pattern also holds + # true for > 2q ops too (so for 4q operations we need to check for 3q, 2q, + # and 1q operations in the same manner) + if qargs in self._qargs_with_non_global_operation or any( + frozenset(qargs).issuperset(incomplete_qargs) + for incomplete_qargs in self._qargs_with_non_global_operation + ): + qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits)) + else: + source_basis.add((node.name, node.op.num_qubits)) + if isinstance(node.op, ControlFlowOp): + for block in node.op.blocks: + block_dag = circuit_to_dag(block) + source_basis, qargs_local_source_basis = self._extract_basis_target( + block_dag, + qarg_indices, + source_basis=source_basis, + qargs_local_source_basis=qargs_local_source_basis, + ) + return source_basis, qargs_local_source_basis + + +# this could be singledispatchmethod and included in above class when minimum +# supported python version=3.8. +@singledispatch +def _extract_basis(circuit): + return circuit + + +@_extract_basis.register +def _(dag: DAGCircuit): + for node in dag.op_nodes(): + if not dag.has_calibration_for(node): + yield (node.name, node.op.num_qubits) + if isinstance(node.op, ControlFlowOp): + for block in node.op.blocks: + yield from _extract_basis(block) + + +@_extract_basis.register +def _(circ: QuantumCircuit): + for instr_context in circ.data: + instr, _, _ = instr_context + if not circ.has_calibration_for(instr_context): + yield (instr.name, instr.num_qubits) + if isinstance(instr, ControlFlowOp): + for block in instr.blocks: + yield from _extract_basis(block) + class StopIfBasisRewritable(Exception): """Custom exception that signals `retworkx.dijkstra_search` to stop.""" @@ -486,8 +543,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag): source_basis but not affected by basis_transforms will be included as a key mapping to itself. """ - - example_gates = {(node.op.name, node.op.num_qubits): node.op for node in source_dag.op_nodes()} + example_gates = _get_example_gates(source_dag) mapped_instrs = {} for gate_name, gate_num_qubits in source_basis: @@ -523,7 +579,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag): ] if doomed_nodes and logger.isEnabledFor(logging.DEBUG): - from qiskit.converters import dag_to_circuit logger.debug( "Updating transform for mapped instr %s %s from \n%s", @@ -533,7 +588,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag): ) for node in doomed_nodes: - from qiskit.converters import circuit_to_dag replacement = equiv.assign_parameters( dict(zip_longest(equiv_params, node.op.params)) @@ -544,7 +598,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag): dag.substitute_node_with_dag(node, replacement_dag) if doomed_nodes and logger.isEnabledFor(logging.DEBUG): - from qiskit.converters import dag_to_circuit logger.debug( "Updated transform for mapped instr %s %s to\n%s", @@ -554,3 +607,16 @@ def _compose_transforms(basis_transforms, source_basis, source_dag): ) return mapped_instrs + + +def _get_example_gates(source_dag): + def recurse(dag, example_gates=None): + example_gates = example_gates or {} + for node in dag.op_nodes(): + example_gates[(node.op.name, node.op.num_qubits)] = node.op + if isinstance(node.op, ControlFlowOp): + for block in node.op.blocks: + example_gates = recurse(circuit_to_dag(block), example_gates) + return example_gates + + return recurse(source_dag) diff --git a/test/python/transpiler/test_basis_translator.py b/test/python/transpiler/test_basis_translator.py index a492bf91081a..3de3d1d9116f 100644 --- a/test/python/transpiler/test_basis_translator.py +++ b/test/python/transpiler/test_basis_translator.py @@ -20,7 +20,7 @@ from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit from qiskit import transpile from qiskit.test import QiskitTestCase -from qiskit.circuit import Gate, Parameter, EquivalenceLibrary +from qiskit.circuit import Gate, Parameter, EquivalenceLibrary, Qubit, Clbit from qiskit.circuit.library import ( U1Gate, U2Gate, @@ -49,15 +49,15 @@ class OneQubitZeroParamGate(Gate): """Mock one qubit zero param gate.""" - def __init__(self): - super().__init__("1q0p", 1, []) + def __init__(self, name="1q0p"): + super().__init__(name, 1, []) class OneQubitOneParamGate(Gate): """Mock one qubit one param gate.""" - def __init__(self, theta): - super().__init__("1q1p", 1, [theta]) + def __init__(self, theta, name="1q1p"): + super().__init__(name, 1, [theta]) class OneQubitOneParamPrimeGate(Gate): @@ -70,22 +70,22 @@ def __init__(self, alpha): class OneQubitTwoParamGate(Gate): """Mock one qubit two param gate.""" - def __init__(self, phi, lam): - super().__init__("1q2p", 1, [phi, lam]) + def __init__(self, phi, lam, name="1q2p"): + super().__init__(name, 1, [phi, lam]) class TwoQubitZeroParamGate(Gate): """Mock one qubit zero param gate.""" - def __init__(self): - super().__init__("2q0p", 2, []) + def __init__(self, name="2q0p"): + super().__init__(name, 2, []) class VariadicZeroParamGate(Gate): """Mock variadic zero param gate.""" - def __init__(self, num_qubits): - super().__init__("vq0p", num_qubits, []) + def __init__(self, num_qubits, name="vq0p"): + super().__init__(name, num_qubits, []) class TestBasisTranslator(QiskitTestCase): @@ -382,6 +382,77 @@ def test_diamond_path(self): self.assertEqual(actual, expected_dag) + def test_if_else(self): + """Test a simple if-else with parameters.""" + qubits = [Qubit(), Qubit()] + clbits = [Clbit(), Clbit()] + alpha = Parameter("alpha") + beta = Parameter("beta") + gate = OneQubitOneParamGate(alpha) + equiv = QuantumCircuit([qubits[0]]) + equiv.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]]) + equiv.append(OneQubitOneParamGate(alpha, name="1q1p_2"), [qubits[0]]) + + eq_lib = EquivalenceLibrary() + eq_lib.add_equivalence(gate, equiv) + + circ = QuantumCircuit(qubits, clbits) + circ.append(OneQubitOneParamGate(beta), [qubits[0]]) + circ.measure(qubits[0], clbits[1]) + with circ.if_test((clbits[1], 0)) as else_: + circ.append(OneQubitOneParamGate(alpha), [qubits[0]]) + circ.append(TwoQubitZeroParamGate(), qubits) + with else_: + circ.append(TwoQubitZeroParamGate(), [qubits[1], qubits[0]]) + dag = circuit_to_dag(circ) + dag_translated = BasisTranslator(eq_lib, ["if_else", "1q0p_2", "1q1p_2", "2q0p"]).run(dag) + + expected = QuantumCircuit(qubits, clbits) + expected.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]]) + expected.append(OneQubitOneParamGate(beta, name="1q1p_2"), [qubits[0]]) + expected.measure(qubits[0], clbits[1]) + with expected.if_test((clbits[1], 0)) as else_: + expected.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]]) + expected.append(OneQubitOneParamGate(alpha, name="1q1p_2"), [qubits[0]]) + expected.append(TwoQubitZeroParamGate(), qubits) + with else_: + expected.append(TwoQubitZeroParamGate(), [qubits[1], qubits[0]]) + dag_expected = circuit_to_dag(expected) + self.assertEqual(dag_translated, dag_expected) + + def test_nested_loop(self): + """Test a simple if-else with parameters.""" + qubits = [Qubit(), Qubit()] + clbits = [Clbit(), Clbit()] + cr = ClassicalRegister(bits=clbits) + index1 = Parameter("index1") + alpha = Parameter("alpha") + + gate = OneQubitOneParamGate(alpha) + equiv = QuantumCircuit([qubits[0]]) + equiv.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]]) + equiv.append(OneQubitOneParamGate(alpha, name="1q1p_2"), [qubits[0]]) + + eq_lib = EquivalenceLibrary() + eq_lib.add_equivalence(gate, equiv) + + circ = QuantumCircuit(qubits, cr) + with circ.for_loop(range(3), loop_parameter=index1) as ind: + with circ.while_loop((cr, 0)): + circ.append(OneQubitOneParamGate(alpha * ind), [qubits[0]]) + dag = circuit_to_dag(circ) + dag_translated = BasisTranslator( + eq_lib, ["if_else", "for_loop", "while_loop", "1q0p_2", "1q1p_2"] + ).run(dag) + + expected = QuantumCircuit(qubits, cr) + with expected.for_loop(range(3), loop_parameter=index1) as ind: + with expected.while_loop((cr, 0)): + expected.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]]) + expected.append(OneQubitOneParamGate(alpha * ind, name="1q1p_2"), [qubits[0]]) + dag_expected = circuit_to_dag(expected) + self.assertEqual(dag_translated, dag_expected) + class TestUnrollerCompatability(QiskitTestCase): """Tests backward compatability with the Unroller pass.