Skip to content

Commit

Permalink
Merge pull request #1 from dewert99/egraph_nodes
Browse files Browse the repository at this point in the history
Egraph nodes
  • Loading branch information
dewert99 authored Mar 20, 2024
2 parents 9b8c9f3 + eb5b548 commit 8d1ec3c
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 205 deletions.
10 changes: 5 additions & 5 deletions src/eclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ pub struct EClass<L, D> {
/// Modifying this field will _not_ cause changes to propagate through the e-graph.
/// Prefer [`EGraph::set_analysis_data`] instead.
pub data: D,
/// The parent enodes and their original Ids.
pub(crate) parents: Vec<(L, Id)>,
/// The original Ids of parent enodes.
pub(crate) parents: Vec<Id>,
}

impl<L, D> EClass<L, D> {
Expand All @@ -37,9 +37,9 @@ impl<L, D> EClass<L, D> {
self.nodes.iter()
}

/// Iterates over the parent enodes of this eclass.
pub fn parents(&self) -> impl ExactSizeIterator<Item = (&L, Id)> {
self.parents.iter().map(|(node, id)| (node, *id))
/// Iterates over the non-canonical ids of parent enodes of this eclass.
pub fn parents(&self) -> impl ExactSizeIterator<Item = Id> + '_ {
self.parents.iter().copied()
}
}

Expand Down
174 changes: 127 additions & 47 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,17 @@ pub struct EGraph<L: Language, N: Analysis<L>> {
/// The `Explain` used to explain equivalences in this `EGraph`.
pub(crate) explain: Option<Explain<L>>,
unionfind: UnionFind,
/// Stores the original node represented by each non-canonical id
nodes: Vec<L>,
/// Stores each enode's `Id`, not the `Id` of the eclass.
/// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new
/// unions can cause them to become out of date.
#[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
memo: HashMap<L, Id>,
/// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode,
/// not the canonical id of the eclass.
pending: Vec<(L, Id)>,
analysis_pending: UniqueQueue<(L, Id)>,
pending: Vec<Id>,
analysis_pending: UniqueQueue<Id>,
#[cfg_attr(
feature = "serde-1",
serde(bound(
Expand Down Expand Up @@ -114,6 +116,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
analysis,
classes: Default::default(),
unionfind: Default::default(),
nodes: Default::default(),
clean: false,
explain: None,
pending: Default::default(),
Expand Down Expand Up @@ -214,12 +217,14 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

/// Make a copy of the egraph with the same nodes, but no unions between them.
pub fn copy_without_unions(&self, analysis: N) -> Self {
if let Some(explain) = &self.explain {
let egraph = Self::new(analysis);
explain.populate_enodes(egraph)
} else {
if self.explain.is_none() {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions");
}
let mut egraph = Self::new(analysis);
for node in &self.nodes {
egraph.add(node.clone());
}
egraph
}

/// Performs the union between two egraphs.
Expand Down Expand Up @@ -339,32 +344,70 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical),
/// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical))
pub fn id_to_expr(&self, id: Id) -> RecExpr<L> {
if let Some(explain) = &self.explain {
explain.node_to_recexpr(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id");
let mut res = Default::default();
let mut cache = Default::default();
self.id_to_expr_internal(&mut res, id, &mut cache);
res
}

fn id_to_expr_internal(
&self,
res: &mut RecExpr<L>,
node_id: Id,
cache: &mut HashMap<Id, Id>,
) -> Id {
if let Some(existing) = cache.get(&node_id) {
return *existing;
}
let new_node = self
.id_to_node(node_id)
.clone()
.map_children(|child| self.id_to_expr_internal(res, child, cache));
let res_id = res.add(new_node);
cache.insert(node_id, res_id);
res_id
}

/// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep
pub fn id_to_node(&self, id: Id) -> &L {
if let Some(explain) = &self.explain {
explain.node(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id");
}
&self.nodes[usize::from(id)]
}

/// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term.
/// When an eclass listed in the given substitutions is found, it creates a variable.
/// It also adds this variable and the corresponding Id value to the resulting [`Subst`]
/// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr).
pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap<Id, Id>) -> (Pattern<L>, Subst) {
if let Some(explain) = &self.explain {
explain.node_to_pattern(id, substitutions)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique patterns per id");
let mut res = Default::default();
let mut subst = Default::default();
let mut cache = Default::default();
self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache);
(Pattern::new(res), subst)
}

fn id_to_pattern_internal(
&self,
res: &mut PatternAst<L>,
node_id: Id,
var_substitutions: &HashMap<Id, Id>,
subst: &mut Subst,
cache: &mut HashMap<Id, Id>,
) -> Id {
if let Some(existing) = cache.get(&node_id) {
return *existing;
}
let res_id = if let Some(existing) = var_substitutions.get(&node_id) {
let var = format!("?{}", node_id).parse().unwrap();
subst.insert(var, *existing);
res.add(ENodeOrVar::Var(var))
} else {
let new_node = self.id_to_node(node_id).clone().map_children(|child| {
self.id_to_pattern_internal(res, child, var_substitutions, subst, cache)
});
res.add(ENodeOrVar::ENode(new_node))
};
cache.insert(node_id, res_id);
res_id
}

/// Get all the unions ever found in the egraph in terms of enode ids.
Expand All @@ -390,17 +433,19 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Get the number of congruences between nodes in the egraph.
/// Only available when explanations are enabled.
pub fn get_num_congr(&mut self) -> usize {
if let Some(explain) = &self.explain {
explain.get_num_congr::<N>(&self.classes, &self.unionfind)
if let Some(explain) = &mut self.explain {
explain
.with_nodes(&self.nodes)
.get_num_congr::<N>(&self.classes, &self.unionfind)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
}

/// Get the number of nodes in the egraph used for explanations.
pub fn get_explanation_num_nodes(&mut self) -> usize {
if let Some(explain) = &self.explain {
explain.get_num_nodes()
if let Some(explain) = &mut self.explain {
explain.with_nodes(&self.nodes).get_num_nodes()
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand Down Expand Up @@ -438,7 +483,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
);
}
if let Some(explain) = &mut self.explain {
explain.explain_equivalence::<N>(left, right, &mut self.unionfind, &self.classes)
explain.with_nodes(&self.nodes).explain_equivalence::<N>(
left,
right,
&mut self.unionfind,
&self.classes,
)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand All @@ -461,7 +511,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// but more efficient
fn explain_existance_id(&mut self, id: Id) -> Explanation<L> {
if let Some(explain) = &mut self.explain {
explain.explain_existance(id)
explain.with_nodes(&self.nodes).explain_existance(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand All @@ -475,7 +525,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
) -> Explanation<L> {
let id = self.add_instantiation_noncanonical(pattern, subst);
if let Some(explain) = &mut self.explain {
explain.explain_existance(id)
explain.with_nodes(&self.nodes).explain_existance(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand All @@ -498,7 +548,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
);
}
if let Some(explain) = &mut self.explain {
explain.explain_equivalence::<N>(left, right, &mut self.unionfind, &self.classes)
explain.with_nodes(&self.nodes).explain_equivalence::<N>(
left,
right,
&mut self.unionfind,
&self.classes,
)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.");
}
Expand Down Expand Up @@ -586,7 +641,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

/// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical
///
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr`
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled
pub fn add_expr_uncanonical(&mut self, expr: &RecExpr<L>) -> Id {
let nodes = expr.as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
Expand Down Expand Up @@ -624,7 +679,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// canonical
///
/// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an corrispond to the
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an correspond to the
/// instantiation of the pattern
fn add_instantiation_noncanonical(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
let nodes = pat.as_ref();
Expand Down Expand Up @@ -744,7 +799,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// When explanations are enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will
/// correspond to the parameter `enode`
///
/// # Example
/// ## Example
/// ```
/// # use egg::*;
/// let mut egraph: EGraph<SymbolLang, ()> = EGraph::default().with_explanations_enabled();
Expand All @@ -759,6 +814,25 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap());
/// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap());
/// ```
///
/// When explanations are not enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will
/// produce an expression with equivalent but not necessarily identical children
///
/// # Example
/// ```
/// # use egg::*;
/// let mut egraph: EGraph<SymbolLang, ()> = EGraph::default().with_explanations_disabled();
/// let a = egraph.add_uncanonical(SymbolLang::leaf("a"));
/// let b = egraph.add_uncanonical(SymbolLang::leaf("b"));
/// egraph.union(a, b);
/// egraph.rebuild();
///
/// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a]));
/// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b]));
///
/// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap());
/// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap());
/// ```
pub fn add_uncanonical(&mut self, mut enode: L) -> Id {
let original = enode.clone();
if let Some(existing_id) = self.lookup_internal(&mut enode) {
Expand All @@ -769,7 +843,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
*existing_explain
} else {
let new_id = self.unionfind.make_set();
explain.add(original, new_id, new_id);
explain.add(original.clone(), new_id, new_id);
debug_assert_eq!(Id::from(self.nodes.len()), new_id);
self.nodes.push(original);
self.unionfind.union(id, new_id);
explain.union(existing_id, new_id, Justification::Congruence, true);
new_id
Expand All @@ -778,7 +854,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
existing_id
}
} else {
let id = self.make_new_eclass(enode);
let id = self.make_new_eclass(enode, original.clone());
if let Some(explain) = self.explain.as_mut() {
explain.add(original, id, id);
}
Expand All @@ -791,24 +867,26 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}

/// This function makes a new eclass in the egraph (but doesn't touch explanations)
fn make_new_eclass(&mut self, enode: L) -> Id {
fn make_new_eclass(&mut self, enode: L, original: L) -> Id {
let id = self.unionfind.make_set();
log::trace!(" ...adding to {}", id);
let class = EClass {
id,
nodes: vec![enode.clone()],
data: N::make(self, &enode),
data: N::make(self, &original),
parents: Default::default(),
};

debug_assert_eq!(Id::from(self.nodes.len()), id);
self.nodes.push(original);

// add this enode to the parent lists of its children
enode.for_each(|child| {
let tup = (enode.clone(), id);
self[child].parents.push(tup);
self[child].parents.push(id);
});

// TODO is this needed?
self.pending.push((enode.clone(), id));
self.pending.push(id);

self.classes.insert(id, class);
assert!(self.memo.insert(enode, id).is_none());
Expand Down Expand Up @@ -943,13 +1021,13 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
let class1 = self.classes.get_mut(&id1).unwrap();
assert_eq!(id1, class1.id);

self.pending.extend(class2.parents.iter().cloned());
self.pending.extend(class2.parents.iter().copied());
let did_merge = self.analysis.merge(&mut class1.data, class2.data);
if did_merge.0 {
self.analysis_pending.extend(class1.parents.iter().cloned());
self.analysis_pending.extend(class1.parents.iter().copied());
}
if did_merge.1 {
self.analysis_pending.extend(class2.parents.iter().cloned());
self.analysis_pending.extend(class2.parents.iter().copied());
}

concat_vecs(&mut class1.nodes, class2.nodes);
Expand All @@ -968,7 +1046,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
let id = self.find_mut(id);
let class = self.classes.get_mut(&id).unwrap();
class.data = new_data;
self.analysis_pending.extend(class.parents.iter().cloned());
self.analysis_pending.extend(class.parents.iter().copied());
N::modify(self, id)
}

Expand Down Expand Up @@ -1103,7 +1181,8 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
let mut n_unions = 0;

while !self.pending.is_empty() || !self.analysis_pending.is_empty() {
while let Some((mut node, class)) = self.pending.pop() {
while let Some(class) = self.pending.pop() {
let mut node = self.nodes[usize::from(class)].clone();
node.update_children(|id| self.find_mut(id));
if let Some(memo_class) = self.memo.insert(node, class) {
let did_something = self.perform_union(
Expand All @@ -1116,14 +1195,15 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}
}

while let Some((node, class_id)) = self.analysis_pending.pop() {
while let Some(class_id) = self.analysis_pending.pop() {
let node = self.nodes[usize::from(class_id)].clone();
let class_id = self.find_mut(class_id);
let node_data = N::make(self, &node);
let class = self.classes.get_mut(&class_id).unwrap();

let did_merge = self.analysis.merge(&mut class.data, node_data);
if did_merge.0 {
self.analysis_pending.extend(class.parents.iter().cloned());
self.analysis_pending.extend(class.parents.iter().copied());
N::modify(self, class_id)
}
}
Expand Down Expand Up @@ -1204,9 +1284,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
n_unions
}

pub(crate) fn check_each_explain(&self, rules: &[&Rewrite<L, N>]) -> bool {
if let Some(explain) = &self.explain {
explain.check_each_explain(rules)
pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite<L, N>]) -> bool {
if let Some(explain) = &mut self.explain {
explain.with_nodes(&self.nodes).check_each_explain(rules)
} else {
panic!("Can't check explain when explanations are off");
}
Expand Down
Loading

0 comments on commit 8d1ec3c

Please sign in to comment.