Skip to content

Commit

Permalink
Fix: Wrong return value for basis_search
Browse files Browse the repository at this point in the history
  • Loading branch information
raynelfss committed Jul 25, 2024
1 parent 7006e12 commit fe8dfa7
Showing 1 changed file with 39 additions and 40 deletions.
79 changes: 39 additions & 40 deletions crates/accelerate/src/basis/basis_translator/basis_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ pub(crate) fn py_basis_search(
equiv_lib: &mut EquivalenceLibrary,
source_basis: &Bound<PySet>,
target_basis: &Bound<PySet>,
) -> PyResult<Option<Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)>>> {
let source_basis: PyResult<IndexSet<(String, u32), RandomState>> =
) -> PyResult<PyObject> {
let new_source_basis: PyResult<IndexSet<(String, u32), RandomState>> =
source_basis.iter().map(|item| item.extract()).collect();
let target_basis: PyResult<IndexSet<String, RandomState>> =
let new_target_basis: PyResult<IndexSet<String, RandomState>> =
target_basis.iter().map(|item| item.extract()).collect();
Ok(basis_search(equiv_lib, source_basis?, target_basis?))
Ok(basis_search(equiv_lib, new_source_basis?, new_target_basis?).into_py(source_basis.py()))
}

type BasisTransforms = Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)>;
/// Search for a set of transformations from source_basis to target_basis.
/// Args:
/// equiv_lib (EquivalenceLibrary): Source of valid translations
Expand All @@ -54,7 +55,7 @@ pub(crate) fn basis_search(
equiv_lib: &mut EquivalenceLibrary,
source_basis: IndexSet<(String, u32), RandomState>,
target_basis: IndexSet<String, RandomState>,
) -> Option<Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)>> {
) -> Option<BasisTransforms> {
// Build the visitor attributes:
let mut num_gates_remaining_for_rule: IndexMap<usize, usize, RandomState> = IndexMap::default();
let predecessors: IndexMap<(&str, u32), Equivalence, RandomState> = IndexMap::default();
Expand All @@ -67,7 +68,6 @@ pub(crate) fn basis_search(

// Initialize visitor attributes:
initialize_num_gates_remain_for_rule(&equiv_lib.graph, &mut num_gates_remaining_for_rule);
// println!("{:#?}", num_gates_remaining_for_rule);

// TODO: Logs
let mut source_basis_remain: IndexSet<Key> = source_basis
Expand Down Expand Up @@ -96,7 +96,8 @@ pub(crate) fn basis_search(
.cloned()
.filter(|key| target_basis.contains(&key.name))
.collect();
// println!("Target basis keys {:#?}", target_basis_keys);

// Dummy node is inserted in the graph. Which is where the search will start
let dummy: NodeIndex = equiv_lib.graph.add_node(NodeData {
equivs: vec![],
key: Key {
Expand All @@ -105,17 +106,18 @@ pub(crate) fn basis_search(
},
});

// Extract indices for the target_basis gates, to avoid borrowing from graph.
let target_basis_indices: Vec<NodeIndex> = target_basis_keys
.iter()
.map(|key| equiv_lib.node_index(key))
.collect();

// Connect each edge in the target_basis to the dummy node.
target_basis_indices.iter().for_each(|node| {
equiv_lib.graph.add_edge(dummy, *node, None);
});

// Build visitor methods

// Edge cost function for Visitor
let edge_weight =
|edge: EdgeReference<Option<EdgeData>>| -> Result<u32, ()> {
if edge.weight().is_none() {
Expand All @@ -131,7 +133,7 @@ pub(crate) fn basis_search(
- borrowed_cost[&(edge_data.source.name.as_str(), edge_data.source.num_qubits)])
};

dijkstra_search(
let basis_transforms = match dijkstra_search(
&equiv_lib.graph,
[dummy],
edge_weight,
Expand Down Expand Up @@ -161,44 +163,41 @@ pub(crate) fn basis_search(
return Control::Break(());
}
}
DijkstraEvent::EdgeRelaxed(_, target, edata) => {
if let Some(edata) = edata {
let gate = &equiv_lib.graph[target].key;
predecessors_cell
.borrow_mut()
.entry((gate.name.as_str(), gate.num_qubits))
.and_modify(|value| *value = edata.rule.clone())
.or_insert(edata.rule.clone());
}
DijkstraEvent::EdgeRelaxed(_, target, Some(edata)) => {
let gate = &equiv_lib.graph[target].key;
predecessors_cell
.borrow_mut()
.entry((gate.name.as_str(), gate.num_qubits))
.and_modify(|value| *value = edata.rule.clone())
.or_insert(edata.rule.clone());
}
DijkstraEvent::ExamineEdge(_, target, edata) => {
if edata.is_some() {
let edata = edata.as_ref().unwrap();
num_gates_remaining_for_rule
.entry(edata.index)
.and_modify(|val| *val -= 1)
.or_insert(0);
let target = &equiv_lib.graph[target].key;

if num_gates_remaining_for_rule[&edata.index] > 0
|| target_basis_keys.contains(target)
{
return Control::Prune;
}
DijkstraEvent::ExamineEdge(_, target, Some(edata)) => {
num_gates_remaining_for_rule
.entry(edata.index)
.and_modify(|val| *val -= 1)
.or_insert(0);
let target = &equiv_lib.graph[target].key;

// If there are gates in this `rule` that we have not yet generated, we can't apply
// this `rule`. if `target` is already in basis, it's not beneficial to use this rule.
if num_gates_remaining_for_rule[&edata.index] > 0
|| target_basis_keys.contains(target)
{
return Control::Prune;
}
}
_ => {}
};
Control::Continue
},
)
.unwrap();
// Values will have to be cloned in order for the dummy node to be removed.
// Will be revised
drop(opt_cost_map_cell);
drop(predecessors_cell);
) {
Ok(Control::Break(_)) => Some(basis_transforms),
_ => None,
};

// TODO: Values will have to be cloned in order for the dummy node to be removed.
equiv_lib.graph.remove_node(dummy);
Some(basis_transforms)
basis_transforms
}

fn initialize_num_gates_remain_for_rule(
Expand Down

0 comments on commit fe8dfa7

Please sign in to comment.