From fe8dfa7abda9a80b3d1129117638118ce71c1f51 Mon Sep 17 00:00:00 2001 From: Raynel Sanchez <87539502+raynelfss@users.noreply.github.com> Date: Mon, 22 Jul 2024 23:34:57 -0400 Subject: [PATCH] Fix: Wrong return value for `basis_search` --- .../basis/basis_translator/basis_search.rs | 79 +++++++++---------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/crates/accelerate/src/basis/basis_translator/basis_search.rs b/crates/accelerate/src/basis/basis_translator/basis_search.rs index 3cbe2f3e81c0..5675d354d3e1 100644 --- a/crates/accelerate/src/basis/basis_translator/basis_search.rs +++ b/crates/accelerate/src/basis/basis_translator/basis_search.rs @@ -31,14 +31,15 @@ pub(crate) fn py_basis_search( equiv_lib: &mut EquivalenceLibrary, source_basis: &Bound, target_basis: &Bound, -) -> PyResult, CircuitRep)>>> { - let source_basis: PyResult> = +) -> PyResult { + let new_source_basis: PyResult> = source_basis.iter().map(|item| item.extract()).collect(); - let target_basis: PyResult> = + let new_target_basis: PyResult> = target_basis.iter().map(|item| item.extract()).collect(); - Ok(basis_search(equiv_lib, source_basis?, target_basis?)) + Ok(basis_search(equiv_lib, new_source_basis?, new_target_basis?).into_py(source_basis.py())) } +type BasisTransforms = Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)>; /// Search for a set of transformations from source_basis to target_basis. /// Args: /// equiv_lib (EquivalenceLibrary): Source of valid translations @@ -54,7 +55,7 @@ pub(crate) fn basis_search( equiv_lib: &mut EquivalenceLibrary, source_basis: IndexSet<(String, u32), RandomState>, target_basis: IndexSet, -) -> Option, CircuitRep)>> { +) -> Option { // Build the visitor attributes: let mut num_gates_remaining_for_rule: IndexMap = IndexMap::default(); let predecessors: IndexMap<(&str, u32), Equivalence, RandomState> = IndexMap::default(); @@ -67,7 +68,6 @@ pub(crate) fn basis_search( // Initialize visitor attributes: initialize_num_gates_remain_for_rule(&equiv_lib.graph, &mut num_gates_remaining_for_rule); - // println!("{:#?}", num_gates_remaining_for_rule); // TODO: Logs let mut source_basis_remain: IndexSet = source_basis @@ -96,7 +96,8 @@ pub(crate) fn basis_search( .cloned() .filter(|key| target_basis.contains(&key.name)) .collect(); - // println!("Target basis keys {:#?}", target_basis_keys); + + // Dummy node is inserted in the graph. Which is where the search will start let dummy: NodeIndex = equiv_lib.graph.add_node(NodeData { equivs: vec![], key: Key { @@ -105,17 +106,18 @@ pub(crate) fn basis_search( }, }); + // Extract indices for the target_basis gates, to avoid borrowing from graph. let target_basis_indices: Vec = target_basis_keys .iter() .map(|key| equiv_lib.node_index(key)) .collect(); + // Connect each edge in the target_basis to the dummy node. target_basis_indices.iter().for_each(|node| { equiv_lib.graph.add_edge(dummy, *node, None); }); - // Build visitor methods - + // Edge cost function for Visitor let edge_weight = |edge: EdgeReference>| -> Result { if edge.weight().is_none() { @@ -131,7 +133,7 @@ pub(crate) fn basis_search( - borrowed_cost[&(edge_data.source.name.as_str(), edge_data.source.num_qubits)]) }; - dijkstra_search( + let basis_transforms = match dijkstra_search( &equiv_lib.graph, [dummy], edge_weight, @@ -161,44 +163,41 @@ pub(crate) fn basis_search( return Control::Break(()); } } - DijkstraEvent::EdgeRelaxed(_, target, edata) => { - if let Some(edata) = edata { - let gate = &equiv_lib.graph[target].key; - predecessors_cell - .borrow_mut() - .entry((gate.name.as_str(), gate.num_qubits)) - .and_modify(|value| *value = edata.rule.clone()) - .or_insert(edata.rule.clone()); - } + DijkstraEvent::EdgeRelaxed(_, target, Some(edata)) => { + let gate = &equiv_lib.graph[target].key; + predecessors_cell + .borrow_mut() + .entry((gate.name.as_str(), gate.num_qubits)) + .and_modify(|value| *value = edata.rule.clone()) + .or_insert(edata.rule.clone()); } - DijkstraEvent::ExamineEdge(_, target, edata) => { - if edata.is_some() { - let edata = edata.as_ref().unwrap(); - num_gates_remaining_for_rule - .entry(edata.index) - .and_modify(|val| *val -= 1) - .or_insert(0); - let target = &equiv_lib.graph[target].key; - - if num_gates_remaining_for_rule[&edata.index] > 0 - || target_basis_keys.contains(target) - { - return Control::Prune; - } + DijkstraEvent::ExamineEdge(_, target, Some(edata)) => { + num_gates_remaining_for_rule + .entry(edata.index) + .and_modify(|val| *val -= 1) + .or_insert(0); + let target = &equiv_lib.graph[target].key; + + // If there are gates in this `rule` that we have not yet generated, we can't apply + // this `rule`. if `target` is already in basis, it's not beneficial to use this rule. + if num_gates_remaining_for_rule[&edata.index] > 0 + || target_basis_keys.contains(target) + { + return Control::Prune; } } _ => {} }; Control::Continue }, - ) - .unwrap(); - // Values will have to be cloned in order for the dummy node to be removed. - // Will be revised - drop(opt_cost_map_cell); - drop(predecessors_cell); + ) { + Ok(Control::Break(_)) => Some(basis_transforms), + _ => None, + }; + + // TODO: Values will have to be cloned in order for the dummy node to be removed. equiv_lib.graph.remove_node(dummy); - Some(basis_transforms) + basis_transforms } fn initialize_num_gates_remain_for_rule(