diff --git a/crates/accelerate/src/commutation_analysis.rs b/crates/accelerate/src/commutation_analysis.rs new file mode 100644 index 000000000000..08fa1dda5ec9 --- /dev/null +++ b/crates/accelerate/src/commutation_analysis.rs @@ -0,0 +1,192 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2024 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use pyo3::exceptions::PyValueError; +use pyo3::prelude::PyModule; +use pyo3::{pyfunction, pymodule, wrap_pyfunction, Bound, PyResult, Python}; +use qiskit_circuit::Qubit; + +use crate::commutation_checker::CommutationChecker; +use hashbrown::HashMap; +use pyo3::prelude::*; + +use pyo3::types::{PyDict, PyList}; +use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType, Wire}; +use rustworkx_core::petgraph::stable_graph::NodeIndex; + +// Custom types to store the commutation sets and node indices, +// see the docstring below for more information. +type CommutationSet = HashMap>>; +type NodeIndices = HashMap<(NodeIndex, Wire), usize>; + +// the maximum number of qubits we check commutativity for +const MAX_NUM_QUBITS: u32 = 3; + +/// Compute the commutation sets for a given DAG. +/// +/// We return two HashMaps: +/// * {wire: commutation_sets}: For each wire, we keep a vector of index sets, where each index +/// set contains mutually commuting nodes. Note that these include the input and output nodes +/// which do not commute with anything. +/// * {(node, wire): index}: For each (node, wire) pair we store the index indicating in which +/// commutation set the node appears on a given wire. +/// +/// For example, if we have a circuit +/// +/// |0> -- X -- SX -- Z (out) +/// 0 2 3 4 1 <-- node indices including input (0) and output (1) nodes +/// +/// Then we would have +/// +/// commutation_set = {0: [[0], [2, 3], [4], [1]]} +/// node_indices = {(0, 0): 0, (1, 0): 3, (2, 0): 1, (3, 0): 1, (4, 0): 2} +/// +fn analyze_commutations_inner( + py: Python, + dag: &mut DAGCircuit, + commutation_checker: &mut CommutationChecker, +) -> PyResult<(CommutationSet, NodeIndices)> { + let mut commutation_set: CommutationSet = HashMap::new(); + let mut node_indices: NodeIndices = HashMap::new(); + + for qubit in 0..dag.num_qubits() { + let wire = Wire::Qubit(Qubit(qubit as u32)); + + for current_gate_idx in dag.nodes_on_wire(py, &wire, false) { + // get the commutation set associated with the current wire, or create a new + // index set containing the current gate + let commutation_entry = commutation_set + .entry(wire.clone()) + .or_insert_with(|| vec![vec![current_gate_idx]]); + + // we can unwrap as we know the commutation entry has at least one element + let last = commutation_entry.last_mut().unwrap(); + + // if the current gate index is not in the set, check whether it commutes with + // the previous nodes -- if yes, add it to the commutation set + if !last.contains(¤t_gate_idx) { + let mut all_commute = true; + + for prev_gate_idx in last.iter() { + // if the node is an input/output node, they do not commute, so we only + // continue if the nodes are operation nodes + if let (NodeType::Operation(packed_inst0), NodeType::Operation(packed_inst1)) = + (&dag.dag[current_gate_idx], &dag.dag[*prev_gate_idx]) + { + let op1 = packed_inst0.op.view(); + let op2 = packed_inst1.op.view(); + let params1 = packed_inst0.params_view(); + let params2 = packed_inst1.params_view(); + let qargs1 = dag.get_qargs(packed_inst0.qubits); + let qargs2 = dag.get_qargs(packed_inst1.qubits); + let cargs1 = dag.get_cargs(packed_inst0.clbits); + let cargs2 = dag.get_cargs(packed_inst1.clbits); + + all_commute = commutation_checker.commute_inner( + py, + &op1, + params1, + packed_inst0.extra_attrs.as_deref(), + qargs1, + cargs1, + &op2, + params2, + packed_inst1.extra_attrs.as_deref(), + qargs2, + cargs2, + MAX_NUM_QUBITS, + )?; + if !all_commute { + break; + } + } else { + all_commute = false; + break; + } + } + + if all_commute { + // all commute, add to current list + last.push(current_gate_idx); + } else { + // does not commute, create new list + commutation_entry.push(vec![current_gate_idx]); + } + } + + node_indices.insert( + (current_gate_idx, wire.clone()), + commutation_entry.len() - 1, + ); + } + } + + Ok((commutation_set, node_indices)) +} + +#[pyfunction] +#[pyo3(signature = (dag, commutation_checker))] +pub(crate) fn analyze_commutations( + py: Python, + dag: &mut DAGCircuit, + commutation_checker: &mut CommutationChecker, +) -> PyResult> { + // This returns two HashMaps: + // * The commuting nodes per wire: {wire: [commuting_nodes_1, commuting_nodes_2, ...]} + // * The index in which commutation set a given node is located on a wire: {(node, wire): index} + // The Python dict will store both of these dictionaries in one. + let (commutation_set, node_indices) = analyze_commutations_inner(py, dag, commutation_checker)?; + + let out_dict = PyDict::new_bound(py); + + // First set the {wire: [commuting_nodes_1, ...]} bit + for (wire, commutations) in commutation_set { + // we know all wires are of type Wire::Qubit, since in analyze_commutations_inner + // we only iterater over the qubits + let py_wire = match wire { + Wire::Qubit(q) => dag.qubits.get(q).unwrap().to_object(py), + _ => return Err(PyValueError::new_err("Unexpected wire type.")), + }; + + out_dict.set_item( + py_wire, + PyList::new_bound( + py, + commutations.iter().map(|inner| { + PyList::new_bound( + py, + inner + .iter() + .map(|node_index| dag.get_node(py, *node_index).unwrap()), + ) + }), + ), + )?; + } + + // Then we add the {(node, wire): index} dictionary + for ((node_index, wire), index) in node_indices { + let py_wire = match wire { + Wire::Qubit(q) => dag.qubits.get(q).unwrap().to_object(py), + _ => return Err(PyValueError::new_err("Unexpected wire type.")), + }; + out_dict.set_item((dag.get_node(py, node_index)?, py_wire), index)?; + } + + Ok(out_dict.unbind()) +} + +#[pymodule] +pub fn commutation_analysis(m: &Bound) -> PyResult<()> { + m.add_wrapped(wrap_pyfunction!(analyze_commutations))?; + Ok(()) +} diff --git a/crates/accelerate/src/commutation_checker.rs b/crates/accelerate/src/commutation_checker.rs index 6fd9d58a7c23..b00d7c624c92 100644 --- a/crates/accelerate/src/commutation_checker.rs +++ b/crates/accelerate/src/commutation_checker.rs @@ -69,7 +69,7 @@ where /// lookups. It's not meant to be a public facing Python object though and only used /// internally by the Python class. #[pyclass(module = "qiskit._accelerate.commutation_checker")] -struct CommutationChecker { +pub struct CommutationChecker { library: CommutationLibrary, cache_max_entries: usize, cache: HashMap<(String, String), CommutationCacheEntry>, @@ -227,7 +227,7 @@ impl CommutationChecker { impl CommutationChecker { #[allow(clippy::too_many_arguments)] - fn commute_inner( + pub fn commute_inner( &mut self, py: Python, op1: &OperationRef, diff --git a/crates/accelerate/src/lib.rs b/crates/accelerate/src/lib.rs index 3a8c4bf51071..6561dd258614 100644 --- a/crates/accelerate/src/lib.rs +++ b/crates/accelerate/src/lib.rs @@ -15,6 +15,7 @@ use std::env; use pyo3::import_exception; pub mod circuit_library; +pub mod commutation_analysis; pub mod commutation_checker; pub mod convert_2q_block_matrix; pub mod dense_layout; diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index afe6797596d9..4d18f1967cf6 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -5523,7 +5523,7 @@ impl DAGCircuit { /// Get the nodes on the given wire. /// /// Note: result is empty if the wire is not in the DAG. - fn nodes_on_wire(&self, py: Python, wire: &Wire, only_ops: bool) -> Vec { + pub fn nodes_on_wire(&self, py: Python, wire: &Wire, only_ops: bool) -> Vec { let mut nodes = Vec::new(); let mut current_node = match wire { Wire::Qubit(qubit) => self.qubit_io_map.get(qubit.0 as usize).map(|x| x[0]), diff --git a/crates/pyext/src/lib.rs b/crates/pyext/src/lib.rs index 24a4badf6539..fdb2bff9a21d 100644 --- a/crates/pyext/src/lib.rs +++ b/crates/pyext/src/lib.rs @@ -13,15 +13,15 @@ use pyo3::prelude::*; use qiskit_accelerate::{ - circuit_library::circuit_library, commutation_checker::commutation_checker, - convert_2q_block_matrix::convert_2q_block_matrix, dense_layout::dense_layout, - error_map::error_map, euler_one_qubit_decomposer::euler_one_qubit_decomposer, - isometry::isometry, nlayout::nlayout, optimize_1q_gates::optimize_1q_gates, - pauli_exp_val::pauli_expval, results::results, sabre::sabre, sampled_exp_val::sampled_exp_val, - sparse_pauli_op::sparse_pauli_op, star_prerouting::star_prerouting, - stochastic_swap::stochastic_swap, synthesis::synthesis, target_transpiler::target, - two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate, utils::utils, - vf2_layout::vf2_layout, + circuit_library::circuit_library, commutation_analysis::commutation_analysis, + commutation_checker::commutation_checker, convert_2q_block_matrix::convert_2q_block_matrix, + dense_layout::dense_layout, error_map::error_map, + euler_one_qubit_decomposer::euler_one_qubit_decomposer, isometry::isometry, nlayout::nlayout, + optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval, results::results, + sabre::sabre, sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op, + star_prerouting::star_prerouting, stochastic_swap::stochastic_swap, synthesis::synthesis, + target_transpiler::target, two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate, + utils::utils, vf2_layout::vf2_layout, }; #[inline(always)] @@ -62,5 +62,6 @@ fn _accelerate(m: &Bound) -> PyResult<()> { add_submodule(m, utils, "utils")?; add_submodule(m, vf2_layout, "vf2_layout")?; add_submodule(m, commutation_checker, "commutation_checker")?; + add_submodule(m, commutation_analysis, "commutation_analysis")?; Ok(()) } diff --git a/qiskit/__init__.py b/qiskit/__init__.py index 67384b38971a..38a9f5952425 100644 --- a/qiskit/__init__.py +++ b/qiskit/__init__.py @@ -87,6 +87,7 @@ sys.modules["qiskit._accelerate.synthesis.linear"] = _accelerate.synthesis.linear sys.modules["qiskit._accelerate.synthesis.clifford"] = _accelerate.synthesis.clifford sys.modules["qiskit._accelerate.commutation_checker"] = _accelerate.commutation_checker +sys.modules["qiskit._accelerate.commutation_analysis"] = _accelerate.commutation_analysis sys.modules["qiskit._accelerate.synthesis.linear_phase"] = _accelerate.synthesis.linear_phase from qiskit.exceptions import QiskitError, MissingOptionalLibraryError diff --git a/qiskit/transpiler/passes/optimization/commutation_analysis.py b/qiskit/transpiler/passes/optimization/commutation_analysis.py index 12ed7145eec7..d801e4775937 100644 --- a/qiskit/transpiler/passes/optimization/commutation_analysis.py +++ b/qiskit/transpiler/passes/optimization/commutation_analysis.py @@ -12,11 +12,9 @@ """Analysis pass to find commutation relations between DAG nodes.""" -from collections import defaultdict - from qiskit.circuit.commutation_library import SessionCommutationChecker as scc -from qiskit.dagcircuit import DAGOpNode from qiskit.transpiler.basepasses import AnalysisPass +from qiskit._accelerate.commutation_analysis import analyze_commutations class CommutationAnalysis(AnalysisPass): @@ -33,6 +31,7 @@ def __init__(self, *, _commutation_checker=None): # do not care about commutations of all gates, but just a subset if _commutation_checker is None: _commutation_checker = scc + self.comm_checker = _commutation_checker def run(self, dag): @@ -42,49 +41,4 @@ def run(self, dag): into the ``property_set``. """ # Initiate the commutation set - self.property_set["commutation_set"] = defaultdict(list) - - # Build a dictionary to keep track of the gates on each qubit - # The key with format (wire) will store the lists of commutation sets - # The key with format (node, wire) will store the index of the commutation set - # on the specified wire, thus, for example: - # self.property_set['commutation_set'][wire][(node, wire)] will give the - # commutation set that contains node. - - for wire in dag.qubits: - self.property_set["commutation_set"][wire] = [] - - # Add edges to the dictionary for each qubit - for node in dag.topological_op_nodes(): - for _, _, edge_wire in dag.edges(node): - self.property_set["commutation_set"][(node, edge_wire)] = -1 - - # Construct the commutation set - for wire in dag.qubits: - - for current_gate in dag.nodes_on_wire(wire): - - current_comm_set = self.property_set["commutation_set"][wire] - if not current_comm_set: - current_comm_set.append([current_gate]) - - if current_gate not in current_comm_set[-1]: - does_commute = True - - # Check if the current gate commutes with all the gates in the current block - for prev_gate in current_comm_set[-1]: - does_commute = ( - isinstance(current_gate, DAGOpNode) - and isinstance(prev_gate, DAGOpNode) - and self.comm_checker.commute_nodes(current_gate, prev_gate) - ) - if not does_commute: - break - - if does_commute: - current_comm_set[-1].append(current_gate) - else: - current_comm_set.append([current_gate]) - - temp_len = len(current_comm_set) - self.property_set["commutation_set"][(current_gate, wire)] = temp_len - 1 + self.property_set["commutation_set"] = analyze_commutations(dag, self.comm_checker.cc) diff --git a/releasenotes/notes/oxidize-commutation-analysis-d2fc81feb6ca80aa.yaml b/releasenotes/notes/oxidize-commutation-analysis-d2fc81feb6ca80aa.yaml new file mode 100644 index 000000000000..967a65bcebab --- /dev/null +++ b/releasenotes/notes/oxidize-commutation-analysis-d2fc81feb6ca80aa.yaml @@ -0,0 +1,4 @@ +--- +features_transpiler: + - | + Added a Rust implementation of :class:`.CommutationAnalysis` in :func:`.analyze_commutations`.