Skip to content

Commit

Permalink
Use retworkx for substitute_node_with_dag (#6302)
Browse files Browse the repository at this point in the history
* Use retworkx for substitute_node_with_dag

This commit leverage the substitute_node_with_subgraph method being
added Qiskit/rustworkx#312 for the dagcircuit method
substitute_node_with_dag.

* DNM: Add retworkx from PR branch to CI

* Fix handling of input dag with direct edge from input to output

* Update requirements for testing

* Run black

* Avoid node copy

* Avoid op func call overhead

* Fix lint

* Expand substitute tests

* Make failing test deterministic

* Fix qasm example

* Fix lint

* Update retworkx source path

* Fix rebase issues

* Bump minimum retworkx version to latest release

* Reduce iterations building wire maps

* Use a plain list comprehension instead of a lambda map

* Apply suggestions from code review

Co-authored-by: Kevin Krsulich <[email protected]>

* Improve code comments

* Add reno touting performance benefits

Co-authored-by: Kevin Krsulich <[email protected]>
Co-authored-by: Jake Lishman <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Sep 7, 2021
1 parent 4fd83bb commit 1eb6681
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 96 deletions.
175 changes: 81 additions & 94 deletions qiskit/dagcircuit/dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,60 +866,6 @@ def _check_wires_list(self, wires, node):
if len(wires) != wire_tot:
raise DAGCircuitError("expected %d wires, got %d" % (wire_tot, len(wires)))

def _make_pred_succ_maps(self, node):
"""Return predecessor and successor dictionaries.
Args:
node (DAGOpNode): reference to multi_graph node
Returns:
tuple(dict): tuple(predecessor_map, successor_map)
These map from wire (Register, int) to the node ids for the
predecessor (successor) nodes of the input node.
"""

pred_map = {e[2]: e[0] for e in self._multi_graph.in_edges(node._node_id)}
succ_map = {e[2]: e[1] for e in self._multi_graph.out_edges(node._node_id)}
return pred_map, succ_map

def _full_pred_succ_maps(self, pred_map, succ_map, input_circuit, wire_map):
"""Map all wires of the input circuit.
Map all wires of the input circuit to predecessor and
successor nodes in self, keyed on wires in self.
Args:
pred_map (dict): comes from _make_pred_succ_maps
succ_map (dict): comes from _make_pred_succ_maps
input_circuit (DAGCircuit): the input circuit
wire_map (dict): the map from wires of input_circuit to wires of self
Returns:
tuple: full_pred_map, full_succ_map (dict, dict)
Raises:
DAGCircuitError: if more than one predecessor for output nodes
"""
full_pred_map = {}
full_succ_map = {}
for w in input_circuit.input_map:
# If w is wire mapped, find the corresponding predecessor
# of the node
if w in wire_map:
full_pred_map[wire_map[w]] = pred_map[wire_map[w]]
full_succ_map[wire_map[w]] = succ_map[wire_map[w]]
else:
# Otherwise, use the corresponding output nodes of self
# and compute the predecessor.
full_succ_map[w] = self.output_map[w]
full_pred_map[w] = self._multi_graph.predecessors(self.output_map[w])[0]
if len(self._multi_graph.predecessors(self.output_map[w])) != 1:
raise DAGCircuitError(
"too many predecessors for %s[%d] output node" % (w.register, w.index)
)

return full_pred_map, full_succ_map

def __eq__(self, other):
# Try to convert to float, but in case of unbound ParameterExpressions
# a TypeError will be raise, fallback to normal equality in those
Expand Down Expand Up @@ -1022,7 +968,7 @@ def substitute_node_with_dag(self, node, input_dag, wires=None):

if wires is None:
wires = in_dag.wires

wire_set = set(wires)
self._check_wires_list(wires, node)

# Create a proxy wire_map to identify fragments and duplicates
Expand All @@ -1044,12 +990,14 @@ def substitute_node_with_dag(self, node, input_dag, wires=None):

condition_bit_list = self._bits_in_condition(node.op.condition)

wire_map = dict(zip(wires, list(node.qargs) + list(node.cargs) + list(condition_bit_list)))
new_wires = list(node.qargs) + list(node.cargs) + list(condition_bit_list)

wire_map = {}
reverse_wire_map = {}
for wire, new_wire in zip(wires, new_wires):
wire_map[wire] = new_wire
reverse_wire_map[new_wire] = wire
self._check_wiremap_validity(wire_map, wires, self.input_map)
pred_map, succ_map = self._make_pred_succ_maps(node)
full_pred_map, full_succ_map = self._full_pred_succ_maps(
pred_map, succ_map, in_dag, wire_map
)

if condition_bit_list:
# If we are replacing a conditional node, map input dag through
Expand All @@ -1065,40 +1013,79 @@ def substitute_node_with_dag(self, node, input_dag, wires=None):
"Mapped DAG would alter clbits on which it would be conditioned."
)

# Now that we know the connections, delete node
self._multi_graph.remove_node(node._node_id)

# Iterate over nodes of input_circuit
for sorted_node in in_dag.topological_op_nodes():
# Insert a new node
condition = self._map_condition(wire_map, sorted_node.op.condition, self.cregs.values())
m_qargs = list(map(lambda x: wire_map.get(x, x), sorted_node.qargs))
m_cargs = list(map(lambda x: wire_map.get(x, x), sorted_node.cargs))
node_index = self._add_op_node(sorted_node.op, m_qargs, m_cargs)

# Add edges from predecessor nodes to new node
# and update predecessor nodes that change
all_cbits = self._bits_in_condition(condition)
all_cbits.extend(m_cargs)
al = [m_qargs, all_cbits]
for q in itertools.chain(*al):
self._multi_graph.add_edge(full_pred_map[q], node_index, q)
full_pred_map[q] = node_index

# Connect all predecessors and successors, and remove
# residual edges between input and output nodes
for w in full_pred_map:
self._multi_graph.add_edge(full_pred_map[w], full_succ_map[w], w)
o_pred = self._multi_graph.predecessors(self.output_map[w]._node_id)
if len(o_pred) > 1:
if len(o_pred) != 2:
raise DAGCircuitError("expected 2 predecessors here")

p = [x for x in o_pred if x != full_pred_map[w]]
if len(p) != 1:
raise DAGCircuitError("expected 1 predecessor to pass filter")

self._multi_graph.remove_edge(p[0], self.output_map[w])
# Add wire from pred to succ if no ops on mapped wire on ``in_dag``
# retworkx's substitute_node_with_subgraph lacks the DAGCircuit
# context to know what to do in this case (the method won't even see
# these nodes because they're filtered) so we manually retain the
# edges prior to calling substitute_node_with_subgraph and set the
# edge_map_fn callback kwarg to skip these edges when they're
# encountered.
for wire in wires:
input_node = in_dag.input_map[wire]
output_node = in_dag.output_map[wire]
if in_dag._multi_graph.has_edge(input_node._node_id, output_node._node_id):
self_wire = wire_map[wire]
pred = self._multi_graph.find_predecessors_by_edge(
node._node_id, lambda edge, wire=self_wire: edge == wire
)[0]
succ = self._multi_graph.find_successors_by_edge(
node._node_id, lambda edge, wire=self_wire: edge == wire
)[0]
self._multi_graph.add_edge(pred._node_id, succ._node_id, self_wire)

# Exlude any nodes from in_dag that are not a DAGOpNode or are on
# bits outside the set specified by the wires kwarg
def filter_fn(node):
if not isinstance(node, DAGOpNode):
return False
for qarg in node.qargs:
if qarg not in wire_set:
return False
return True

# Map edges into and out of node to the appropriate node from in_dag
def edge_map_fn(source, _target, self_wire):
wire = reverse_wire_map[self_wire]
# successor edge
if source == node._node_id:
wire_output_id = in_dag.output_map[wire]._node_id
out_index = in_dag._multi_graph.predecessor_indices(wire_output_id)[0]
# Edge directly from from input nodes to output nodes in in_dag are
# already handled prior to calling retworkx. Don't map these edges
# in retworkx.
if not isinstance(in_dag._multi_graph[out_index], DAGOpNode):
return None
# predecessor edge
else:
wire_input_id = in_dag.input_map[wire]._node_id
out_index = in_dag._multi_graph.successor_indices(wire_input_id)[0]
# Edge directly from from input nodes to output nodes in in_dag are
# already handled prior to calling retworkx. Don't map these edges
# in retworkx.
if not isinstance(in_dag._multi_graph[out_index], DAGOpNode):
return None
return out_index

# Adjust edge weights from in_dag
def edge_weight_map(wire):
return wire_map[wire]

node_map = self._multi_graph.substitute_node_with_subgraph(
node._node_id, in_dag._multi_graph, edge_map_fn, filter_fn, edge_weight_map
)

# Iterate over nodes of input_circuit and update wiires in node objects migrated
# from in_dag
for old_node_index, new_node_index in node_map.items():
# update node attributes
old_node = in_dag._multi_graph[old_node_index]
condition = self._map_condition(wire_map, old_node.op.condition, self.cregs.values())
m_qargs = [wire_map.get(x, x) for x in old_node.qargs]
m_cargs = [wire_map.get(x, x) for x in old_node.cargs]
new_node = DAGOpNode(old_node.op, qargs=m_qargs, cargs=m_cargs)
new_node._node_id = new_node_index
new_node.op.condition = condition
self._multi_graph[new_node_index] = new_node

def substitute_node(self, node, op, inplace=False):
"""Replace an DAGOpNode with a single instruction. qargs, cargs and
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
features:
- |
Various transpilation internals now use new features in `retworkx
<https://github.com/Qiskit/retworkx>`__ 0.10 when operating on the internal
circuit representation. This can often result in speedups in calls to
:obj:`~qiskit.transpile` of around 10-40%, with greater effects at higher
optimisation levels. See `#6302
<https://github.com/Qiskit/qiskit-terra/pull/6302>`__ for more details.
50 changes: 48 additions & 2 deletions test/python/dagcircuit/test_dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,14 +1186,60 @@ def test_substitute_circuit_one_middle(self):
self.dag.substitute_node_with_dag(cx_node, flipped_cx_circuit, wires=[v[0], v[1]])

self.assertEqual(self.dag.count_ops()["h"], 5)
expected = DAGCircuit()
qreg = QuantumRegister(3, "qr")
creg = ClassicalRegister(2, "cr")
expected.add_qreg(qreg)
expected.add_creg(creg)
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[1]], [])
expected.apply_operation_back(CXGate(), [qreg[1], qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[1]], [])
expected.apply_operation_back(XGate(), [qreg[1]], [])
self.assertEqual(self.dag, expected)

def test_substitute_circuit_one_front(self):
"""The method substitute_node_with_dag() replaces a leaf-in-the-front node with a DAG."""
pass
circuit = DAGCircuit()
v = QuantumRegister(1, "v")
circuit.add_qreg(v)
circuit.apply_operation_back(HGate(), [v[0]], [])
circuit.apply_operation_back(XGate(), [v[0]], [])

self.dag.substitute_node_with_dag(next(self.dag.topological_op_nodes()), circuit)
expected = DAGCircuit()
qreg = QuantumRegister(3, "qr")
creg = ClassicalRegister(2, "cr")
expected.add_qreg(qreg)
expected.add_creg(creg)
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(XGate(), [qreg[0]], [])
expected.apply_operation_back(CXGate(), [qreg[0], qreg[1]], [])
expected.apply_operation_back(XGate(), [qreg[1]], [])
self.assertEqual(self.dag, expected)

def test_substitute_circuit_one_back(self):
"""The method substitute_node_with_dag() replaces a leaf-in-the-back node with a DAG."""
pass
circuit = DAGCircuit()
v = QuantumRegister(1, "v")
circuit.add_qreg(v)
circuit.apply_operation_back(HGate(), [v[0]], [])
circuit.apply_operation_back(XGate(), [v[0]], [])

self.dag.substitute_node_with_dag(list(self.dag.topological_op_nodes())[2], circuit)
expected = DAGCircuit()
qreg = QuantumRegister(3, "qr")
creg = ClassicalRegister(2, "cr")
expected.add_qreg(qreg)
expected.add_creg(creg)
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(CXGate(), [qreg[0], qreg[1]], [])
expected.apply_operation_back(HGate(), [qreg[1]], [])
expected.apply_operation_back(XGate(), [qreg[1]], [])

self.assertEqual(self.dag, expected)

def test_raise_if_substituting_dag_modifies_its_conditional(self):
"""Verify that we raise if the input dag modifies any of the bits in node.op.condition."""
Expand Down
1 change: 1 addition & 0 deletions test/python/transpiler/test_basis_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@


"""Test the BasisTranslator pass"""

import os

from numpy import pi
Expand Down

0 comments on commit 1eb6681

Please sign in to comment.