Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use retworkx for substitute_node_with_dag #6302

Merged
merged 35 commits into from
Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a32fef8
Use retworkx for substitute_node_with_dag
mtreinish Apr 25, 2021
acee0ab
DNM: Add retworkx from PR branch to CI
mtreinish Apr 26, 2021
ed9f2ec
Merge branch 'main' into perf-twaeks
mtreinish Apr 26, 2021
773b7b1
Merge branch 'main' into perf-twaeks
mtreinish Jun 2, 2021
2ee3216
Fix handling of input dag with direct edge from input to output
mtreinish Jun 3, 2021
baa8f91
Update requirements for testing
mtreinish Jun 3, 2021
d26cf7d
Run black
mtreinish Jun 3, 2021
ad51c04
Avoid node copy
mtreinish Jun 3, 2021
0372583
Avoid op func call overhead
mtreinish Jun 3, 2021
ca42178
Fix lint
mtreinish Jun 3, 2021
b4e2efa
Expand substitute tests
mtreinish Jun 3, 2021
193e224
Make failing test deterministic
mtreinish Jun 3, 2021
df153d5
Fix qasm example
mtreinish Jun 4, 2021
c0fc655
Fix lint
mtreinish Jun 4, 2021
d9f887a
Merge branch 'main' into perf-twaeks
mtreinish Jun 4, 2021
1763563
Merge branch 'main' into perf-twaeks
mtreinish Jun 4, 2021
4eb83c6
Merge branch 'main' into perf-twaeks
mtreinish Jun 10, 2021
4014afe
Merge branch 'main' into perf-twaeks
mtreinish Jun 14, 2021
85cd3b1
Merge branch 'main' into perf-twaeks
mtreinish Jun 22, 2021
bcd65ed
Merge branch 'main' into perf-twaeks
mtreinish Jun 23, 2021
46dcdb7
Merge branch 'main' into perf-twaeks
mtreinish Jun 25, 2021
0b5d4ac
Update retworkx source path
mtreinish Jun 25, 2021
7b39b7e
Merge remote-tracking branch 'origin/main' into perf-twaeks
mtreinish Aug 5, 2021
84d03a0
Fix rebase issues
mtreinish Aug 5, 2021
8a64b39
Merge branch 'main' into perf-twaeks
mtreinish Aug 26, 2021
4e926cc
Bump minimum retworkx version to latest release
mtreinish Aug 26, 2021
f571561
Reduce iterations building wire maps
mtreinish Aug 26, 2021
3f8f1ca
Merge branch 'main' into perf-twaeks
mtreinish Aug 27, 2021
9e0118b
Use a plain list comprehension instead of a lambda map
mtreinish Aug 27, 2021
a9a0633
Merge branch 'perf-twaeks' of github.com:mtreinish/qiskit-core into p…
mtreinish Aug 27, 2021
9f50236
Apply suggestions from code review
mtreinish Aug 30, 2021
7a12eac
Improve code comments
mtreinish Aug 30, 2021
22d9393
Merge remote-tracking branch 'origin/main' into perf-twaeks
mtreinish Aug 30, 2021
c8d43f5
Add reno touting performance benefits
jakelishman Aug 31, 2021
2088453
Merge branch 'main' into perf-twaeks
mergify[bot] Sep 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 81 additions & 94 deletions qiskit/dagcircuit/dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,60 +866,6 @@ def _check_wires_list(self, wires, node):
if len(wires) != wire_tot:
raise DAGCircuitError("expected %d wires, got %d" % (wire_tot, len(wires)))

def _make_pred_succ_maps(self, node):
"""Return predecessor and successor dictionaries.

Args:
node (DAGOpNode): reference to multi_graph node

Returns:
tuple(dict): tuple(predecessor_map, successor_map)
These map from wire (Register, int) to the node ids for the
predecessor (successor) nodes of the input node.
"""

pred_map = {e[2]: e[0] for e in self._multi_graph.in_edges(node._node_id)}
succ_map = {e[2]: e[1] for e in self._multi_graph.out_edges(node._node_id)}
return pred_map, succ_map

def _full_pred_succ_maps(self, pred_map, succ_map, input_circuit, wire_map):
"""Map all wires of the input circuit.

Map all wires of the input circuit to predecessor and
successor nodes in self, keyed on wires in self.

Args:
pred_map (dict): comes from _make_pred_succ_maps
succ_map (dict): comes from _make_pred_succ_maps
input_circuit (DAGCircuit): the input circuit
wire_map (dict): the map from wires of input_circuit to wires of self

Returns:
tuple: full_pred_map, full_succ_map (dict, dict)

Raises:
DAGCircuitError: if more than one predecessor for output nodes
"""
full_pred_map = {}
full_succ_map = {}
for w in input_circuit.input_map:
# If w is wire mapped, find the corresponding predecessor
# of the node
if w in wire_map:
full_pred_map[wire_map[w]] = pred_map[wire_map[w]]
full_succ_map[wire_map[w]] = succ_map[wire_map[w]]
else:
# Otherwise, use the corresponding output nodes of self
# and compute the predecessor.
full_succ_map[w] = self.output_map[w]
full_pred_map[w] = self._multi_graph.predecessors(self.output_map[w])[0]
if len(self._multi_graph.predecessors(self.output_map[w])) != 1:
raise DAGCircuitError(
"too many predecessors for %s[%d] output node" % (w.register, w.index)
)

return full_pred_map, full_succ_map

def __eq__(self, other):
# Try to convert to float, but in case of unbound ParameterExpressions
# a TypeError will be raise, fallback to normal equality in those
Expand Down Expand Up @@ -1022,7 +968,7 @@ def substitute_node_with_dag(self, node, input_dag, wires=None):

if wires is None:
wires = in_dag.wires

wire_set = set(wires)
self._check_wires_list(wires, node)

# Create a proxy wire_map to identify fragments and duplicates
Expand All @@ -1044,12 +990,14 @@ def substitute_node_with_dag(self, node, input_dag, wires=None):

condition_bit_list = self._bits_in_condition(node.op.condition)

wire_map = dict(zip(wires, list(node.qargs) + list(node.cargs) + list(condition_bit_list)))
new_wires = list(node.qargs) + list(node.cargs) + list(condition_bit_list)

wire_map = {}
reverse_wire_map = {}
for wire, new_wire in zip(wires, new_wires):
wire_map[wire] = new_wire
reverse_wire_map[new_wire] = wire
self._check_wiremap_validity(wire_map, wires, self.input_map)
pred_map, succ_map = self._make_pred_succ_maps(node)
full_pred_map, full_succ_map = self._full_pred_succ_maps(
pred_map, succ_map, in_dag, wire_map
)

if condition_bit_list:
# If we are replacing a conditional node, map input dag through
Expand All @@ -1065,40 +1013,79 @@ def substitute_node_with_dag(self, node, input_dag, wires=None):
"Mapped DAG would alter clbits on which it would be conditioned."
)

# Now that we know the connections, delete node
self._multi_graph.remove_node(node._node_id)

# Iterate over nodes of input_circuit
for sorted_node in in_dag.topological_op_nodes():
# Insert a new node
condition = self._map_condition(wire_map, sorted_node.op.condition, self.cregs.values())
m_qargs = list(map(lambda x: wire_map.get(x, x), sorted_node.qargs))
m_cargs = list(map(lambda x: wire_map.get(x, x), sorted_node.cargs))
node_index = self._add_op_node(sorted_node.op, m_qargs, m_cargs)

# Add edges from predecessor nodes to new node
# and update predecessor nodes that change
all_cbits = self._bits_in_condition(condition)
all_cbits.extend(m_cargs)
al = [m_qargs, all_cbits]
for q in itertools.chain(*al):
self._multi_graph.add_edge(full_pred_map[q], node_index, q)
full_pred_map[q] = node_index

# Connect all predecessors and successors, and remove
# residual edges between input and output nodes
for w in full_pred_map:
self._multi_graph.add_edge(full_pred_map[w], full_succ_map[w], w)
o_pred = self._multi_graph.predecessors(self.output_map[w]._node_id)
if len(o_pred) > 1:
if len(o_pred) != 2:
raise DAGCircuitError("expected 2 predecessors here")

p = [x for x in o_pred if x != full_pred_map[w]]
if len(p) != 1:
raise DAGCircuitError("expected 1 predecessor to pass filter")

self._multi_graph.remove_edge(p[0], self.output_map[w])
# Add wire from pred to succ if no ops on mapped wire on ``in_dag``
# retworkx's substitute_node_with_subgraph lacks the DAGCircuit
# context to know what to do in this case (the method won't even see
# these nodes because they're filtered) so we manually retain the
# edges prior to calling substitute_node_with_subgraph and set the
# edge_map_fn callback kwarg to skip these edges when they're
# encountered.
for wire in wires:
input_node = in_dag.input_map[wire]
output_node = in_dag.output_map[wire]
if in_dag._multi_graph.has_edge(input_node._node_id, output_node._node_id):
self_wire = wire_map[wire]
pred = self._multi_graph.find_predecessors_by_edge(
node._node_id, lambda edge, wire=self_wire: edge == wire
)[0]
Comment on lines +1028 to +1030
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To check, here you're just using the second argument of lambda edge, wire=self_wire: edge == wire purely for the normal safe scoping within the lambda, right? e.g. it's only ever intended to be called with one argument, like self_wire.__eq__ but with conversion of NotImplemented to errors.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I did this solely for scoping. IIRC when I wrote this originally I didn't do this and lint or something complained. The filter_fn argument only gets passed a single arg which is the weight/data payload object for each edge: https://qiskit.org/documentation/retworkx/dev/stubs/retworkx.PyDiGraph.find_predecessors_by_edge.html#retworkx.PyDiGraph.find_predecessors_by_edge

succ = self._multi_graph.find_successors_by_edge(
node._node_id, lambda edge, wire=self_wire: edge == wire
)[0]
self._multi_graph.add_edge(pred._node_id, succ._node_id, self_wire)

# Exlude any nodes from in_dag that are not a DAGOpNode or are on
# bits outside the set specified by the wires kwarg
def filter_fn(node):
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(node, DAGOpNode):
return False
for qarg in node.qargs:
if qarg not in wire_set:
return False
return True

# Map edges into and out of node to the appropriate node from in_dag
def edge_map_fn(source, _target, self_wire):
wire = reverse_wire_map[self_wire]
# successor edge
if source == node._node_id:
wire_output_id = in_dag.output_map[wire]._node_id
out_index = in_dag._multi_graph.predecessor_indices(wire_output_id)[0]
# Edge directly from from input nodes to output nodes in in_dag are
# already handled prior to calling retworkx. Don't map these edges
# in retworkx.
if not isinstance(in_dag._multi_graph[out_index], DAGOpNode):
return None
# predecessor edge
else:
wire_input_id = in_dag.input_map[wire]._node_id
out_index = in_dag._multi_graph.successor_indices(wire_input_id)[0]
# Edge directly from from input nodes to output nodes in in_dag are
# already handled prior to calling retworkx. Don't map these edges
# in retworkx.
if not isinstance(in_dag._multi_graph[out_index], DAGOpNode):
return None
return out_index

# Adjust edge weights from in_dag
def edge_weight_map(wire):
return wire_map[wire]

node_map = self._multi_graph.substitute_node_with_subgraph(
node._node_id, in_dag._multi_graph, edge_map_fn, filter_fn, edge_weight_map
)

# Iterate over nodes of input_circuit and update wiires in node objects migrated
# from in_dag
for old_node_index, new_node_index in node_map.items():
# update node attributes
old_node = in_dag._multi_graph[old_node_index]
condition = self._map_condition(wire_map, old_node.op.condition, self.cregs.values())
m_qargs = [wire_map.get(x, x) for x in old_node.qargs]
m_cargs = [wire_map.get(x, x) for x in old_node.cargs]
new_node = DAGOpNode(old_node.op, qargs=m_qargs, cargs=m_cargs)
new_node._node_id = new_node_index
new_node.op.condition = condition
self._multi_graph[new_node_index] = new_node

def substitute_node(self, node, op, inplace=False):
"""Replace an DAGOpNode with a single instruction. qargs, cargs and
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
features:
- |
Various transpilation internals now use new features in `retworkx
<https://github.com/Qiskit/retworkx>`__ 0.10 when operating on the internal
circuit representation. This can often result in speedups in calls to
:obj:`~qiskit.transpile` of around 10-40%, with greater effects at higher
optimisation levels. See `#6302
<https://github.com/Qiskit/qiskit-terra/pull/6302>`__ for more details.
50 changes: 48 additions & 2 deletions test/python/dagcircuit/test_dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,14 +1186,60 @@ def test_substitute_circuit_one_middle(self):
self.dag.substitute_node_with_dag(cx_node, flipped_cx_circuit, wires=[v[0], v[1]])

self.assertEqual(self.dag.count_ops()["h"], 5)
expected = DAGCircuit()
qreg = QuantumRegister(3, "qr")
creg = ClassicalRegister(2, "cr")
expected.add_qreg(qreg)
expected.add_creg(creg)
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[1]], [])
expected.apply_operation_back(CXGate(), [qreg[1], qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[1]], [])
expected.apply_operation_back(XGate(), [qreg[1]], [])
self.assertEqual(self.dag, expected)

def test_substitute_circuit_one_front(self):
"""The method substitute_node_with_dag() replaces a leaf-in-the-front node with a DAG."""
pass
circuit = DAGCircuit()
v = QuantumRegister(1, "v")
circuit.add_qreg(v)
circuit.apply_operation_back(HGate(), [v[0]], [])
circuit.apply_operation_back(XGate(), [v[0]], [])

self.dag.substitute_node_with_dag(next(self.dag.topological_op_nodes()), circuit)
expected = DAGCircuit()
qreg = QuantumRegister(3, "qr")
creg = ClassicalRegister(2, "cr")
expected.add_qreg(qreg)
expected.add_creg(creg)
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(XGate(), [qreg[0]], [])
expected.apply_operation_back(CXGate(), [qreg[0], qreg[1]], [])
expected.apply_operation_back(XGate(), [qreg[1]], [])
self.assertEqual(self.dag, expected)

def test_substitute_circuit_one_back(self):
"""The method substitute_node_with_dag() replaces a leaf-in-the-back node with a DAG."""
pass
circuit = DAGCircuit()
v = QuantumRegister(1, "v")
circuit.add_qreg(v)
circuit.apply_operation_back(HGate(), [v[0]], [])
circuit.apply_operation_back(XGate(), [v[0]], [])

self.dag.substitute_node_with_dag(list(self.dag.topological_op_nodes())[2], circuit)
expected = DAGCircuit()
qreg = QuantumRegister(3, "qr")
creg = ClassicalRegister(2, "cr")
expected.add_qreg(qreg)
expected.add_creg(creg)
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(CXGate(), [qreg[0], qreg[1]], [])
expected.apply_operation_back(HGate(), [qreg[1]], [])
expected.apply_operation_back(XGate(), [qreg[1]], [])

self.assertEqual(self.dag, expected)

def test_raise_if_substituting_dag_modifies_its_conditional(self):
"""Verify that we raise if the input dag modifies any of the bits in node.op.condition."""
Expand Down
1 change: 1 addition & 0 deletions test/python/transpiler/test_basis_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@


"""Test the BasisTranslator pass"""

import os

from numpy import pi
Expand Down