Skip to content

Commit

Permalink
Fully port FilterOpNodes to Rust
Browse files Browse the repository at this point in the history
This commit ports the FilterOpNodes pass to rust. This pass is
exceedingly simple and just runs a filter function over all the op
nodes and removes nodes that match the filter. However, the API for
the class exposes that filter function interface as a user provided
Python callable. So for the current pass we need to retain that python
callback. This limits the absolute performance of this pass because
we're bottlenecked by calling python.

Looking to the future, this commit adds a rust native method to
DAGCircuit to perform this filtering with a rust predicate FnMut. It
isn't leveraged by the python implementation because of layer mismatch
for the efficient rust interface and Python working with `DAGOpNode`
objects. A function using that interface is added to filter labeled
nodes. In the preset pass manager we only use FilterOpNodes to remove
nodes with a specific label (which is used to identify temporary
barriers created by qiskit). In a follow up we should consider
leveraging this new function to build a new pass specifically for
this use case.

Fixes Qiskit#12263
Part of Qiskit#12208
  • Loading branch information
mtreinish committed Aug 28, 2024
1 parent cc2edc9 commit 081c4bc
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 10 deletions.
65 changes: 65 additions & 0 deletions crates/accelerate/src/filter_op_nodes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use std::convert::Infallible;

use pyo3::prelude::*;
use pyo3::wrap_pyfunction;

use qiskit_circuit::dag_circuit::DAGCircuit;
use qiskit_circuit::packed_instruction::PackedInstruction;
use rustworkx_core::petgraph::stable_graph::NodeIndex;

#[pyfunction]
#[pyo3(name = "filter_op_nodes")]
pub fn py_filter_op_nodes(
py: Python,
dag: &mut DAGCircuit,
predicate: &Bound<PyAny>,
) -> PyResult<()> {
let callable = |node: NodeIndex| -> PyResult<bool> {
let dag_op_node = dag.get_node(py, node)?;
predicate.call1((dag_op_node,))?.extract()
};
let mut remove_nodes: Vec<NodeIndex> = Vec::new();
for node in dag.op_nodes(true) {
if !callable(node)? {
remove_nodes.push(node);
}
}
for node in remove_nodes {
dag.remove_op_node(node);
}
Ok(())
}

/// Remove any nodes that have the provided label set
///
/// Args:
/// dag (DAGCircuit): The dag circuit to filter the ops from
/// label (str): The label to filter nodes on
#[pyfunction]
pub fn filter_labelled_op(dag: &mut DAGCircuit, label: String) {
let predicate = |node: &PackedInstruction| -> Result<bool, Infallible> {
match node.label() {
Some(inst_label) => Ok(inst_label != label),
None => Ok(false),
}
};
let _ = dag.filter_op_nodes(predicate);
}

pub fn filter_op_nodes_mod(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(py_filter_op_nodes))?;
m.add_wrapped(wrap_pyfunction!(filter_labelled_op))?;
Ok(())
}
1 change: 1 addition & 0 deletions crates/accelerate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub mod dense_layout;
pub mod edge_collections;
pub mod error_map;
pub mod euler_one_qubit_decomposer;
pub mod filter_op_nodes;
pub mod isometry;
pub mod nlayout;
pub mod optimize_1q_gates;
Expand Down
22 changes: 21 additions & 1 deletion crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5604,7 +5604,7 @@ impl DAGCircuit {
/// Remove an operation node n.
///
/// Add edges from predecessors to successors.
fn remove_op_node(&mut self, index: NodeIndex) {
pub fn remove_op_node(&mut self, index: NodeIndex) {
let mut edge_list: Vec<(NodeIndex, NodeIndex, Wire)> = Vec::new();
for (source, in_weight) in self
.dag
Expand Down Expand Up @@ -5785,6 +5785,26 @@ impl DAGCircuit {
}
}

// Filter any nodes that don't match a given predicate function
pub fn filter_op_nodes<F, E>(&mut self, mut predicate: F) -> Result<(), E>
where
F: FnMut(&PackedInstruction) -> Result<bool, E>,
{
let mut remove_nodes: Vec<NodeIndex> = Vec::new();
for node in self.op_nodes(true) {
let NodeType::Operation(op) = &self.dag[node] else {
unreachable!()
};
if !predicate(op)? {
remove_nodes.push(node);
}
}
for node in remove_nodes {
self.remove_op_node(node);
}
Ok(())
}

pub fn op_nodes_by_py_type<'a>(
&'a self,
op: &'a Bound<PyType>,
Expand Down
7 changes: 7 additions & 0 deletions crates/circuit/src/packed_instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,13 @@ impl PackedInstruction {
.and_then(|extra| extra.condition.as_ref())
}

#[inline]
pub fn label(&self) -> Option<&str> {
self.extra_attrs
.as_ref()
.and_then(|extra| extra.label.as_deref())
}

/// Build a reference to the Python-space operation object (the `Gate`, etc) packed into this
/// instruction. This may construct the reference if the `PackedInstruction` is a standard
/// gate with no already stored operation.
Expand Down
14 changes: 8 additions & 6 deletions crates/pyext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ use pyo3::prelude::*;
use qiskit_accelerate::{
circuit_library::circuit_library, convert_2q_block_matrix::convert_2q_block_matrix,
dense_layout::dense_layout, error_map::error_map,
euler_one_qubit_decomposer::euler_one_qubit_decomposer, isometry::isometry, nlayout::nlayout,
optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval, results::results,
sabre::sabre, sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op,
star_prerouting::star_prerouting, stochastic_swap::stochastic_swap, synthesis::synthesis,
target_transpiler::target, two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate,
utils::utils, vf2_layout::vf2_layout,
euler_one_qubit_decomposer::euler_one_qubit_decomposer, filter_op_nodes::filter_op_nodes_mod,
isometry::isometry, nlayout::nlayout, optimize_1q_gates::optimize_1q_gates,
pauli_exp_val::pauli_expval, results::results, sabre::sabre, sampled_exp_val::sampled_exp_val,
sparse_pauli_op::sparse_pauli_op, star_prerouting::star_prerouting,
stochastic_swap::stochastic_swap, synthesis::synthesis, target_transpiler::target,
two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate, utils::utils,
vf2_layout::vf2_layout,
};

#[inline(always)]
Expand All @@ -44,6 +45,7 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
add_submodule(m, dense_layout, "dense_layout")?;
add_submodule(m, error_map, "error_map")?;
add_submodule(m, euler_one_qubit_decomposer, "euler_one_qubit_decomposer")?;
add_submodule(m, filter_op_nodes_mod, "filter_op_nodes")?;
add_submodule(m, isometry, "isometry")?;
add_submodule(m, nlayout, "nlayout")?;
add_submodule(m, optimize_1q_gates, "optimize_1q_gates")?;
Expand Down
1 change: 1 addition & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
sys.modules["qiskit._accelerate.synthesis.linear"] = _accelerate.synthesis.linear
sys.modules["qiskit._accelerate.synthesis.clifford"] = _accelerate.synthesis.clifford
sys.modules["qiskit._accelerate.synthesis.linear_phase"] = _accelerate.synthesis.linear_phase
sys.modules["qiskit._accelerate.filter_op_nodes"] = _accelerate.filter_op_nodes

from qiskit.exceptions import QiskitError, MissingOptionalLibraryError

Expand Down
6 changes: 3 additions & 3 deletions qiskit/transpiler/passes/utils/filter_op_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.passes.utils import control_flow

from qiskit._accelerate.filter_op_nodes import filter_op_nodes


class FilterOpNodes(TransformationPass):
"""Remove all operations that match a filter function
Expand Down Expand Up @@ -59,7 +61,5 @@ def __init__(self, predicate: Callable[[DAGOpNode], bool]):
@control_flow.trivial_recurse
def run(self, dag: DAGCircuit) -> DAGCircuit:
"""Run the RemoveBarriers pass on `dag`."""
for node in dag.op_nodes():
if not self.predicate(node):
dag.remove_op_node(node)
filter_op_nodes(dag, self.predicate)
return dag

0 comments on commit 081c4bc

Please sign in to comment.