Skip to content

Commit

Permalink
support control flow in BasisTranslator pass (Qiskit#7808)
Browse files Browse the repository at this point in the history
* create two tests and 1st modification of unroller

* if_else test, parameter test

* black

* linting

* change shallow copy of control flow ops to not copy body

* add special copy

* debug

* add `replace_blocks` method

* minor update

* clean debug code

* linting fix bugs

* minor commit

* linting

* don't recurse on run

* linting

* don't mutate basis in "_update_basis".

* Update qiskit/transpiler/passes/basis/basis_translator.py

Co-authored-by: Jake Lishman <[email protected]>

* apply_translation returns bool

* factor out "replace_node" function

* linting

* singledispatchmethod -> singledispatch for python 3.7

* black

* fix indentation bug

* linting

* black

* changed _get_example_gates following @jakelishman suggestion.

Co-authored-by: Jake Lishman <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 21, 2022
1 parent 4414c4e commit f28f383
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 105 deletions.
2 changes: 0 additions & 2 deletions qiskit/circuit/controlflow/while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ def blocks(self):

def replace_blocks(self, blocks):
(body,) = blocks
if not isinstance(body, QuantumCircuit):
raise CircuitError("WhileLoopOp expects a single QuantumCircuit when setting blocks")
return WhileLoopOp(self.condition, body, label=self.label)

def c_if(self, classical, val):
Expand Down
17 changes: 17 additions & 0 deletions qiskit/circuit/quantumcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,23 @@ def calibrations(self, calibrations: dict):
"""
self._calibrations = defaultdict(dict, calibrations)

def has_calibration_for(self, instr_context: Tuple):
"""Return True if the circuit has a calibration defined for the instruction context. In this
case, the operation does not need to be translated to the device basis.
"""
instr, qargs, _ = instr_context
if not self.calibrations or instr.name not in self.calibrations:
return False
qubits = tuple(self.qubits.index(qubit) for qubit in qargs)
params = []
for p in instr.params:
if isinstance(p, ParameterExpression) and not p.parameters:
params.append(float(p))
else:
params.append(p)
params = tuple(params)
return (qubits, params) in self.calibrations[instr.name]

@property
def metadata(self) -> dict:
"""The user provided metadata associated with the circuit
Expand Down
1 change: 0 additions & 1 deletion qiskit/dagcircuit/dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,6 @@ def __eq__(self, other):
# Try to convert to float, but in case of unbound ParameterExpressions
# a TypeError will be raise, fallback to normal equality in those
# cases

try:
self_phase = float(self.global_phase)
other_phase = float(other.global_phase)
Expand Down
248 changes: 157 additions & 91 deletions qiskit/transpiler/passes/basis/basis_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

from itertools import zip_longest
from collections import defaultdict
from functools import singledispatch

import retworkx

from qiskit.circuit import Gate, ParameterVector, QuantumRegister
from qiskit.circuit.equivalence import Key
from qiskit.circuit import Gate, ParameterVector, QuantumRegister, ControlFlowOp, QuantumCircuit
from qiskit.dagcircuit import DAGCircuit
from qiskit.converters import circuit_to_dag, dag_to_circuit
from qiskit.circuit.equivalence import Key
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.exceptions import TranspilerError

Expand Down Expand Up @@ -119,36 +121,12 @@ def run(self, dag):
if self._target is None:
basic_instrs = ["measure", "reset", "barrier", "snapshot", "delay"]
target_basis = set(self._target_basis)
source_basis = set()
for node in dag.op_nodes():
if not dag.has_calibration_for(node):
source_basis.add((node.name, node.op.num_qubits))
source_basis = set(_extract_basis(dag))
qargs_local_source_basis = {}
else:
basic_instrs = ["barrier", "snapshot"]
source_basis = set()
target_basis = self._target.keys() - set(self._non_global_operations)
qargs_local_source_basis = defaultdict(set)
for node in dag.op_nodes():
qargs = tuple(qarg_indices[bit] for bit in node.qargs)
if dag.has_calibration_for(node):
continue
# Treat the instruction as on an incomplete basis if the qargs are in the
# qargs_with_non_global_operation dictionary or if any of the qubits in qargs
# are a superset for a non-local operation. For example, if the qargs
# are (0, 1) and that's a global (ie no non-local operations on (0, 1)
# operation but there is a non-local operation on (1,) we need to
# do an extra non-local search for this op to ensure we include any
# single qubit operation for (1,) as valid. This pattern also holds
# true for > 2q ops too (so for 4q operations we need to check for 3q, 2q,
# and 1q operations in the same manner)
if qargs in self._qargs_with_non_global_operation or any(
frozenset(qargs).issuperset(incomplete_qargs)
for incomplete_qargs in self._qargs_with_non_global_operation
):
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits))
else:
source_basis.add((node.name, node.op.num_qubits))
source_basis, qargs_local_source_basis = self._extract_basis_target(dag, qarg_indices)

target_basis = set(target_basis).union(basic_instrs)

Expand Down Expand Up @@ -225,68 +203,43 @@ def run(self, dag):
# Replace source instructions with target translations.

replace_start_time = time.time()
for node in dag.op_nodes():
node_qargs = tuple(qarg_indices[bit] for bit in node.qargs)
qubit_set = frozenset(node_qargs)

if node.name in target_basis:
continue
if (
node_qargs in self._qargs_with_non_global_operation
and node.name in self._qargs_with_non_global_operation[node_qargs]
):
continue

if dag.has_calibration_for(node):
continue

def replace_node(node, instr_map):
target_params, target_dag = instr_map[node.op.name, node.op.num_qubits]
if len(node.op.params) != len(target_params):
raise TranspilerError(
"Translation num_params not equal to op num_params."
"Op: {} {} Translation: {}\n{}".format(
node.op.params, node.op.name, target_params, target_dag
)
)

if node.op.params:
# Convert target to circ and back to assign_parameters, since
# DAGCircuits won't have a ParameterTable.
from qiskit.converters import dag_to_circuit, circuit_to_dag

target_circuit = dag_to_circuit(target_dag)

target_circuit.assign_parameters(
dict(zip_longest(target_params, node.op.params)), inplace=True
)
def apply_translation(dag):
dag_updated = False
for node in dag.op_nodes():
node_qargs = tuple(qarg_indices[bit] for bit in node.qargs)
qubit_set = frozenset(node_qargs)
if node.name in target_basis:
if isinstance(node.op, ControlFlowOp):
flow_blocks = []
for block in node.op.blocks:
dag_block = circuit_to_dag(block)
dag_updated = apply_translation(dag_block)
if dag_updated:
flow_circ_block = dag_to_circuit(dag_block)
else:
flow_circ_block = block
flow_blocks.append(flow_circ_block)
node.op = node.op.replace_blocks(flow_blocks)
continue
if (
node_qargs in self._qargs_with_non_global_operation
and node.name in self._qargs_with_non_global_operation[node_qargs]
):
continue

bound_target_dag = circuit_to_dag(target_circuit)
else:
bound_target_dag = target_dag

if len(bound_target_dag.op_nodes()) == 1 and len(
bound_target_dag.op_nodes()[0].qargs
) == len(node.qargs):
dag_op = bound_target_dag.op_nodes()[0].op
# dag_op may be the same instance as other ops in the dag,
# so if there is a condition, need to copy
if node.op.condition:
dag_op = dag_op.copy()
dag.substitute_node(node, dag_op, inplace=True)

if bound_target_dag.global_phase:
dag.global_phase += bound_target_dag.global_phase
if dag.has_calibration_for(node):
continue
if qubit_set in extra_instr_map:
self._replace_node(dag, node, extra_instr_map[qubit_set])
elif (node.op.name, node.op.num_qubits) in instr_map:
self._replace_node(dag, node, instr_map)
else:
dag.substitute_node_with_dag(node, bound_target_dag)

if qubit_set in extra_instr_map:
replace_node(node, extra_instr_map[qubit_set])
elif (node.op.name, node.op.num_qubits) in instr_map:
replace_node(node, instr_map)
else:
raise TranspilerError(f"BasisTranslator did not map {node.name}.")
raise TranspilerError(f"BasisTranslator did not map {node.name}.")
dag_updated = True
return dag_updated

apply_translation(dag)
replace_end_time = time.time()
logger.info(
"Basis translation instructions replaced in %.3fs.",
Expand All @@ -295,6 +248,110 @@ def replace_node(node, instr_map):

return dag

def _replace_node(self, dag, node, instr_map):
target_params, target_dag = instr_map[node.op.name, node.op.num_qubits]
if len(node.op.params) != len(target_params):
raise TranspilerError(
"Translation num_params not equal to op num_params."
"Op: {} {} Translation: {}\n{}".format(
node.op.params, node.op.name, target_params, target_dag
)
)

if node.op.params:
# Convert target to circ and back to assign_parameters, since
# DAGCircuits won't have a ParameterTable.
target_circuit = dag_to_circuit(target_dag)

target_circuit.assign_parameters(
dict(zip_longest(target_params, node.op.params)), inplace=True
)

bound_target_dag = circuit_to_dag(target_circuit)
else:
bound_target_dag = target_dag

if len(bound_target_dag.op_nodes()) == 1 and len(
bound_target_dag.op_nodes()[0].qargs
) == len(node.qargs):
dag_op = bound_target_dag.op_nodes()[0].op
# dag_op may be the same instance as other ops in the dag,
# so if there is a condition, need to copy
if node.op.condition:
dag_op = dag_op.copy()
dag.substitute_node(node, dag_op, inplace=True)

if bound_target_dag.global_phase:
dag.global_phase += bound_target_dag.global_phase
else:
dag.substitute_node_with_dag(node, bound_target_dag)

def _extract_basis_target(
self, dag, qarg_indices, source_basis=None, qargs_local_source_basis=None
):
if source_basis is None:
source_basis = set()
if qargs_local_source_basis is None:
qargs_local_source_basis = defaultdict(set)
for node in dag.op_nodes():
qargs = tuple(qarg_indices[bit] for bit in node.qargs)
if dag.has_calibration_for(node):
continue
# Treat the instruction as on an incomplete basis if the qargs are in the
# qargs_with_non_global_operation dictionary or if any of the qubits in qargs
# are a superset for a non-local operation. For example, if the qargs
# are (0, 1) and that's a global (ie no non-local operations on (0, 1)
# operation but there is a non-local operation on (1,) we need to
# do an extra non-local search for this op to ensure we include any
# single qubit operation for (1,) as valid. This pattern also holds
# true for > 2q ops too (so for 4q operations we need to check for 3q, 2q,
# and 1q operations in the same manner)
if qargs in self._qargs_with_non_global_operation or any(
frozenset(qargs).issuperset(incomplete_qargs)
for incomplete_qargs in self._qargs_with_non_global_operation
):
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits))
else:
source_basis.add((node.name, node.op.num_qubits))
if isinstance(node.op, ControlFlowOp):
for block in node.op.blocks:
block_dag = circuit_to_dag(block)
source_basis, qargs_local_source_basis = self._extract_basis_target(
block_dag,
qarg_indices,
source_basis=source_basis,
qargs_local_source_basis=qargs_local_source_basis,
)
return source_basis, qargs_local_source_basis


# this could be singledispatchmethod and included in above class when minimum
# supported python version=3.8.
@singledispatch
def _extract_basis(circuit):
return circuit


@_extract_basis.register
def _(dag: DAGCircuit):
for node in dag.op_nodes():
if not dag.has_calibration_for(node):
yield (node.name, node.op.num_qubits)
if isinstance(node.op, ControlFlowOp):
for block in node.op.blocks:
yield from _extract_basis(block)


@_extract_basis.register
def _(circ: QuantumCircuit):
for instr_context in circ.data:
instr, _, _ = instr_context
if not circ.has_calibration_for(instr_context):
yield (instr.name, instr.num_qubits)
if isinstance(instr, ControlFlowOp):
for block in instr.blocks:
yield from _extract_basis(block)


class StopIfBasisRewritable(Exception):
"""Custom exception that signals `retworkx.dijkstra_search` to stop."""
Expand Down Expand Up @@ -486,8 +543,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
source_basis but not affected by basis_transforms will be included
as a key mapping to itself.
"""

example_gates = {(node.op.name, node.op.num_qubits): node.op for node in source_dag.op_nodes()}
example_gates = _get_example_gates(source_dag)
mapped_instrs = {}

for gate_name, gate_num_qubits in source_basis:
Expand Down Expand Up @@ -523,7 +579,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
]

if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
from qiskit.converters import dag_to_circuit

logger.debug(
"Updating transform for mapped instr %s %s from \n%s",
Expand All @@ -533,7 +588,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
)

for node in doomed_nodes:
from qiskit.converters import circuit_to_dag

replacement = equiv.assign_parameters(
dict(zip_longest(equiv_params, node.op.params))
Expand All @@ -544,7 +598,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
dag.substitute_node_with_dag(node, replacement_dag)

if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
from qiskit.converters import dag_to_circuit

logger.debug(
"Updated transform for mapped instr %s %s to\n%s",
Expand All @@ -554,3 +607,16 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
)

return mapped_instrs


def _get_example_gates(source_dag):
def recurse(dag, example_gates=None):
example_gates = example_gates or {}
for node in dag.op_nodes():
example_gates[(node.op.name, node.op.num_qubits)] = node.op
if isinstance(node.op, ControlFlowOp):
for block in node.op.blocks:
example_gates = recurse(circuit_to_dag(block), example_gates)
return example_gates

return recurse(source_dag)
Loading

0 comments on commit f28f383

Please sign in to comment.