From cff062866b2b028a83343782b8f867bcc169d131 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 3 Jan 2024 10:21:33 -0800 Subject: [PATCH 1/5] Added `nodes` field to `EGraph` to avoid storing nodes in `analysis` and `analysis_pending` --- src/eclass.rs | 10 +++++----- src/egraph.rs | 41 +++++++++++++++++++++++++---------------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/eclass.rs b/src/eclass.rs index 5f74b2c2..640dea63 100644 --- a/src/eclass.rs +++ b/src/eclass.rs @@ -17,8 +17,8 @@ pub struct EClass { /// Modifying this field will _not_ cause changes to propagate through the e-graph. /// Prefer [`EGraph::set_analysis_data`] instead. pub data: D, - /// The parent enodes and their original Ids. - pub(crate) parents: Vec<(L, Id)>, + /// The original Ids of parent enodes. + pub(crate) parents: Vec, } impl EClass { @@ -37,9 +37,9 @@ impl EClass { self.nodes.iter() } - /// Iterates over the parent enodes of this eclass. - pub fn parents(&self) -> impl ExactSizeIterator { - self.parents.iter().map(|(node, id)| (node, *id)) + /// Iterates over the non-canonical ids of parent enodes of this eclass. + pub fn parents(&self) -> impl ExactSizeIterator + '_ { + self.parents.iter().copied() } } diff --git a/src/egraph.rs b/src/egraph.rs index 6af452b2..f05456de 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -57,6 +57,8 @@ pub struct EGraph> { /// The `Explain` used to explain equivalences in this `EGraph`. pub(crate) explain: Option>, unionfind: UnionFind, + /// Stores the original node represented by each non-canonical id + nodes: Vec, /// Stores each enode's `Id`, not the `Id` of the eclass. /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new /// unions can cause them to become out of date. @@ -64,8 +66,8 @@ pub struct EGraph> { memo: HashMap, /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, /// not the canonical id of the eclass. - pending: Vec<(L, Id)>, - analysis_pending: UniqueQueue<(L, Id)>, + pending: Vec, + analysis_pending: UniqueQueue, #[cfg_attr( feature = "serde-1", serde(bound( @@ -114,6 +116,7 @@ impl> EGraph { analysis, classes: Default::default(), unionfind: Default::default(), + nodes: Default::default(), clean: false, explain: None, pending: Default::default(), @@ -769,7 +772,9 @@ impl> EGraph { *existing_explain } else { let new_id = self.unionfind.make_set(); - explain.add(original, new_id, new_id); + explain.add(original.clone(), new_id, new_id); + self.nodes.push(original); + debug_assert_eq!(Id::from(self.nodes.len()), new_id); self.unionfind.union(id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); new_id @@ -778,7 +783,7 @@ impl> EGraph { existing_id } } else { - let id = self.make_new_eclass(enode); + let id = self.make_new_eclass(enode, original.clone()); if let Some(explain) = self.explain.as_mut() { explain.add(original, id, id); } @@ -791,24 +796,26 @@ impl> EGraph { } /// This function makes a new eclass in the egraph (but doesn't touch explanations) - fn make_new_eclass(&mut self, enode: L) -> Id { + fn make_new_eclass(&mut self, enode: L, original: L) -> Id { let id = self.unionfind.make_set(); log::trace!(" ...adding to {}", id); let class = EClass { id, nodes: vec![enode.clone()], - data: N::make(self, &enode), + data: N::make(self, &original), parents: Default::default(), }; + self.nodes.push(original); + debug_assert_eq!(Id::from(self.nodes.len()), id); + // add this enode to the parent lists of its children enode.for_each(|child| { - let tup = (enode.clone(), id); - self[child].parents.push(tup); + self[child].parents.push(id); }); // TODO is this needed? - self.pending.push((enode.clone(), id)); + self.pending.push(id); self.classes.insert(id, class); assert!(self.memo.insert(enode, id).is_none()); @@ -943,13 +950,13 @@ impl> EGraph { let class1 = self.classes.get_mut(&id1).unwrap(); assert_eq!(id1, class1.id); - self.pending.extend(class2.parents.iter().cloned()); + self.pending.extend(class2.parents.iter().copied()); let did_merge = self.analysis.merge(&mut class1.data, class2.data); if did_merge.0 { - self.analysis_pending.extend(class1.parents.iter().cloned()); + self.analysis_pending.extend(class1.parents.iter().copied()); } if did_merge.1 { - self.analysis_pending.extend(class2.parents.iter().cloned()); + self.analysis_pending.extend(class2.parents.iter().copied()); } concat_vecs(&mut class1.nodes, class2.nodes); @@ -968,7 +975,7 @@ impl> EGraph { let id = self.find_mut(id); let class = self.classes.get_mut(&id).unwrap(); class.data = new_data; - self.analysis_pending.extend(class.parents.iter().cloned()); + self.analysis_pending.extend(class.parents.iter().copied()); N::modify(self, id) } @@ -1103,7 +1110,8 @@ impl> EGraph { let mut n_unions = 0; while !self.pending.is_empty() || !self.analysis_pending.is_empty() { - while let Some((mut node, class)) = self.pending.pop() { + while let Some(class) = self.pending.pop() { + let mut node = self.nodes[usize::from(class)].clone(); node.update_children(|id| self.find_mut(id)); if let Some(memo_class) = self.memo.insert(node, class) { let did_something = self.perform_union( @@ -1116,14 +1124,15 @@ impl> EGraph { } } - while let Some((node, class_id)) = self.analysis_pending.pop() { + while let Some(class_id) = self.analysis_pending.pop() { + let node = self.nodes[usize::from(class_id)].clone(); let class_id = self.find_mut(class_id); let node_data = N::make(self, &node); let class = self.classes.get_mut(&class_id).unwrap(); let did_merge = self.analysis.merge(&mut class.data, node_data); if did_merge.0 { - self.analysis_pending.extend(class.parents.iter().cloned()); + self.analysis_pending.extend(class.parents.iter().copied()); N::modify(self, class_id) } } From 608d584aa4939e522cd2070bc94c1d9332158d74 Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 3 Jan 2024 11:43:49 -0800 Subject: [PATCH 2/5] eliminated `node` field of `ExplainNode` (used `EGraph.nodes` instead) --- src/egraph.rs | 107 ++++++++++++++------ src/explain.rs | 260 ++++++++++++++++++++----------------------------- 2 files changed, 186 insertions(+), 181 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index f05456de..3e0d8225 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -217,12 +217,11 @@ impl> EGraph { /// Make a copy of the egraph with the same nodes, but no unions between them. pub fn copy_without_unions(&self, analysis: N) -> Self { - if let Some(explain) = &self.explain { - let egraph = Self::new(analysis); - explain.populate_enodes(egraph) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions"); + let mut egraph = Self::new(analysis); + for node in &self.nodes { + egraph.add(node.clone()); } + egraph } /// Performs the union between two egraphs. @@ -342,20 +341,33 @@ impl> EGraph { /// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical), /// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical)) pub fn id_to_expr(&self, id: Id) -> RecExpr { - if let Some(explain) = &self.explain { - explain.node_to_recexpr(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); + let mut res = Default::default(); + let mut cache = Default::default(); + self.id_to_expr_internal(&mut res, id, &mut cache); + res + } + + fn id_to_expr_internal( + &self, + res: &mut RecExpr, + node_id: Id, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; } + let new_node = self + .id_to_node(node_id) + .clone() + .map_children(|child| self.id_to_expr_internal(res, child, cache)); + let res_id = res.add(new_node); + cache.insert(node_id, res_id); + res_id } /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep pub fn id_to_node(&self, id: Id) -> &L { - if let Some(explain) = &self.explain { - explain.node(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); - } + &self.nodes[usize::from(id)] } /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term. @@ -363,11 +375,36 @@ impl> EGraph { /// It also adds this variable and the corresponding Id value to the resulting [`Subst`] /// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr). pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap) -> (Pattern, Subst) { - if let Some(explain) = &self.explain { - explain.node_to_pattern(id, substitutions) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique patterns per id"); + let mut res = Default::default(); + let mut subst = Default::default(); + let mut cache = Default::default(); + self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache); + (Pattern::new(res), subst) + } + + fn id_to_pattern_internal( + &self, + res: &mut PatternAst, + node_id: Id, + var_substitutions: &HashMap, + subst: &mut Subst, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; } + let res_id = if let Some(existing) = var_substitutions.get(&node_id) { + let var = format!("?{}", node_id).parse().unwrap(); + subst.insert(var, *existing); + res.add(ENodeOrVar::Var(var)) + } else { + let new_node = self.id_to_node(node_id).clone().map_children(|child| { + self.id_to_pattern_internal(res, child, var_substitutions, subst, cache) + }); + res.add(ENodeOrVar::ENode(new_node)) + }; + cache.insert(node_id, res_id); + res_id } /// Get all the unions ever found in the egraph in terms of enode ids. @@ -393,8 +430,10 @@ impl> EGraph { /// Get the number of congruences between nodes in the egraph. /// Only available when explanations are enabled. pub fn get_num_congr(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_congr::(&self.classes, &self.unionfind) + if let Some(explain) = &mut self.explain { + explain + .with_nodes(&self.nodes) + .get_num_congr::(&self.classes, &self.unionfind) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -402,8 +441,8 @@ impl> EGraph { /// Get the number of nodes in the egraph used for explanations. pub fn get_explanation_num_nodes(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_nodes() + if let Some(explain) = &mut self.explain { + explain.with_nodes(&self.nodes).get_num_nodes() } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -441,7 +480,12 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain.with_nodes(&self.nodes).explain_equivalence::( + left, + right, + &mut self.unionfind, + &self.classes, + ) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -464,7 +508,7 @@ impl> EGraph { /// but more efficient fn explain_existance_id(&mut self, id: Id) -> Explanation { if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_nodes(&self.nodes).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -478,7 +522,7 @@ impl> EGraph { ) -> Explanation { let id = self.add_instantiation_noncanonical(pattern, subst); if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_nodes(&self.nodes).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -501,7 +545,12 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain.with_nodes(&self.nodes).explain_equivalence::( + left, + right, + &mut self.unionfind, + &self.classes, + ) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations."); } @@ -1213,9 +1262,9 @@ impl> EGraph { n_unions } - pub(crate) fn check_each_explain(&self, rules: &[&Rewrite]) -> bool { - if let Some(explain) = &self.explain { - explain.check_each_explain(rules) + pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite]) -> bool { + if let Some(explain) = &mut self.explain { + explain.with_nodes(&self.nodes).check_each_explain(rules) } else { panic!("Can't check explain when explanations are off"); } diff --git a/src/explain.rs b/src/explain.rs index 187aecfc..59315615 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1,12 +1,13 @@ use crate::Symbol; use crate::{ - util::pretty_print, Analysis, EClass, EGraph, ENodeOrVar, FromOp, HashMap, HashSet, Id, - Language, Pattern, PatternAst, RecExpr, Rewrite, Subst, UnionFind, Var, + util::pretty_print, Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, + PatternAst, RecExpr, Rewrite, UnionFind, Var, }; use saturating::Saturating; use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; +use std::ops::{Deref, DerefMut}; use std::rc::Rc; use symbolic_expressions::Sexp; @@ -38,8 +39,7 @@ struct Connection { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -struct ExplainNode { - node: L, +struct ExplainNode { // neighbors includes parent connections neighbors: Vec, parent_connection: Connection, @@ -54,7 +54,7 @@ struct ExplainNode { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] pub struct Explain { - explainfind: Vec>, + explainfind: Vec, #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] pub uncanon_memo: HashMap, /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted. @@ -69,6 +69,11 @@ pub struct Explain { shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>, } +pub(crate) struct ExplainNodes<'a, L: Language> { + explain: &'a mut Explain, + nodes: &'a [L], +} + #[derive(Default)] struct DistanceMemo { parent_distance: Vec<(Id, ProofCost)>, @@ -883,97 +888,6 @@ impl PartialOrd for HeapState { } impl Explain { - pub(crate) fn node(&self, node_id: Id) -> &L { - &self.explainfind[usize::from(node_id)].node - } - fn node_to_explanation( - &self, - node_id: Id, - cache: &mut NodeExplanationCache, - ) -> Rc> { - if let Some(existing) = cache.get(&node_id) { - existing.clone() - } else { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(vec![self.node_to_explanation(child, cache)]); - sofar - }); - let res = Rc::new(TreeTerm::new(node, children)); - cache.insert(node_id, res.clone()); - res - } - } - - pub(crate) fn node_to_recexpr(&self, node_id: Id) -> RecExpr { - let mut res = Default::default(); - let mut cache = Default::default(); - self.node_to_recexpr_internal(&mut res, node_id, &mut cache); - res - } - fn node_to_recexpr_internal( - &self, - res: &mut RecExpr, - node_id: Id, - cache: &mut HashMap, - ) { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_recexpr_internal(res, child, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(new_node); - } - - pub(crate) fn node_to_pattern( - &self, - node_id: Id, - substitutions: &HashMap, - ) -> (Pattern, Subst) { - let mut res = Default::default(); - let mut subst = Default::default(); - let mut cache = Default::default(); - self.node_to_pattern_internal(&mut res, node_id, substitutions, &mut subst, &mut cache); - (Pattern::new(res), subst) - } - - fn node_to_pattern_internal( - &self, - res: &mut PatternAst, - node_id: Id, - var_substitutions: &HashMap, - subst: &mut Subst, - cache: &mut HashMap, - ) { - if let Some(existing) = var_substitutions.get(&node_id) { - let var = format!("?{}", node_id).parse().unwrap(); - res.add(ENodeOrVar::Var(var)); - subst.insert(var, *existing); - } else { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_pattern_internal(res, child, var_substitutions, subst, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(ENodeOrVar::ENode(new_node)); - } - } - - fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(self.node_to_flat_explanation(child)); - sofar - }); - FlatTerm::new(node, children) - } - fn make_rule_table<'a, N: Analysis>( rules: &[&'a Rewrite], ) -> HashMap> { @@ -983,52 +897,6 @@ impl Explain { } table } - - pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { - let rule_table = Explain::make_rule_table(rules); - for i in 0..self.explainfind.len() { - let explain_node = &self.explainfind[i]; - - // check that explanation reasons never form a cycle - let mut existance = i; - let mut seen_existance: HashSet = Default::default(); - loop { - seen_existance.insert(existance); - let next = usize::from(self.explainfind[existance].existance_node); - if existance == next { - break; - } - existance = next; - if seen_existance.contains(&existance) { - panic!("Cycle in existance!"); - } - } - - if explain_node.parent_connection.next != Id::from(i) { - let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); - let mut next_explanation = - self.node_to_flat_explanation(explain_node.parent_connection.next); - if let Justification::Rule(rule_name) = - &explain_node.parent_connection.justification - { - if let Some(rule) = rule_table.get(rule_name) { - if !explain_node.parent_connection.is_rewrite_forward { - std::mem::swap(&mut current_explanation, &mut next_explanation); - } - if !Explanation::check_rewrite( - ¤t_explanation, - &next_explanation, - rule, - ) { - return false; - } - } - } - } - } - true - } - pub fn new() -> Self { Explain { explainfind: vec![], @@ -1046,7 +914,6 @@ impl Explain { assert_eq!(self.explainfind.len(), usize::from(set)); self.uncanon_memo.insert(node.clone(), set); self.explainfind.push(ExplainNode { - node, neighbors: vec![], parent_connection: Connection { justification: Justification::Congruence, @@ -1119,7 +986,7 @@ impl Explain { new_rhs: bool, ) { if let Justification::Congruence = justification { - assert!(self.node(node1).matches(self.node(node2))); + // assert!(self.node(node1).matches(self.node(node2))); } if new_rhs { self.set_existance_reason(node2, node1) @@ -1155,7 +1022,6 @@ impl Explain { .push(other_pconnection); self.explainfind[usize::from(node1)].parent_connection = pconnection; } - pub(crate) fn get_union_equalities(&self) -> UnionEqualities { let mut equalities = vec![]; for node in &self.explainfind { @@ -1170,13 +1036,103 @@ impl Explain { equalities } - pub(crate) fn populate_enodes>(&self, mut egraph: EGraph) -> EGraph { - for i in 0..self.explainfind.len() { - let node = &self.explainfind[i]; - egraph.add(node.node.clone()); + pub(crate) fn with_nodes<'a>(&'a mut self, nodes: &'a [L]) -> ExplainNodes<'a, L> { + ExplainNodes { + explain: self, + nodes, } + } +} + +impl<'a, L: Language> Deref for ExplainNodes<'a, L> { + type Target = Explain; + + fn deref(&self) -> &Self::Target { + self.explain + } +} + +impl<'a, L: Language> DerefMut for ExplainNodes<'a, L> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.explain + } +} + +impl<'x, L: Language> ExplainNodes<'x, L> { + pub(crate) fn node(&self, node_id: Id) -> &L { + &self.nodes[usize::from(node_id)] + } + fn node_to_explanation( + &self, + node_id: Id, + cache: &mut NodeExplanationCache, + ) -> Rc> { + if let Some(existing) = cache.get(&node_id) { + existing.clone() + } else { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(vec![self.node_to_explanation(child, cache)]); + sofar + }); + let res = Rc::new(TreeTerm::new(node, children)); + cache.insert(node_id, res.clone()); + res + } + } + + fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(self.node_to_flat_explanation(child)); + sofar + }); + FlatTerm::new(node, children) + } + + pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { + let rule_table = Explain::make_rule_table(rules); + for i in 0..self.explainfind.len() { + let explain_node = &self.explainfind[i]; + + // check that explanation reasons never form a cycle + let mut existance = i; + let mut seen_existance: HashSet = Default::default(); + loop { + seen_existance.insert(existance); + let next = usize::from(self.explainfind[existance].existance_node); + if existance == next { + break; + } + existance = next; + if seen_existance.contains(&existance) { + panic!("Cycle in existance!"); + } + } - egraph + if explain_node.parent_connection.next != Id::from(i) { + let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); + let mut next_explanation = + self.node_to_flat_explanation(explain_node.parent_connection.next); + if let Justification::Rule(rule_name) = + &explain_node.parent_connection.justification + { + if let Some(rule) = rule_table.get(rule_name) { + if !explain_node.parent_connection.is_rewrite_forward { + std::mem::swap(&mut current_explanation, &mut next_explanation); + } + if !Explanation::check_rewrite( + ¤t_explanation, + &next_explanation, + rule, + ) { + return false; + } + } + } + } + } + true } pub(crate) fn explain_equivalence>( @@ -1328,7 +1284,7 @@ impl Explain { let mut new_rest_of_proof = (*self.node_to_explanation(existance, enode_cache)).clone(); let mut index_of_child = 0; let mut found = false; - existance_node.node.for_each(|child| { + self.node(existance).for_each(|child| { if found { return; } @@ -2092,7 +2048,7 @@ mod tests { #[test] fn simple_explain_union_trusted() { - use crate::SymbolLang; + use crate::{EGraph, SymbolLang}; crate::init_logger(); let mut egraph = EGraph::new(()).with_explanations_enabled(); From 3b34138857b683b925cd8c65068322425ec5760b Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 3 Jan 2024 12:34:37 -0800 Subject: [PATCH 3/5] serde --- src/explain.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/explain.rs b/src/explain.rs index 59315615..a2d0a2b2 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -56,6 +56,10 @@ struct ExplainNode { pub struct Explain { explainfind: Vec, #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] + #[cfg_attr( + feature = "serde-1", + serde(bound(serialize = "L: Serialize", deserialize = "L: for<'a> Deserialize<'a>",)) + )] pub uncanon_memo: HashMap, /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted. pub optimize_explanation_lengths: bool, @@ -912,7 +916,7 @@ impl Explain { pub(crate) fn add(&mut self, node: L, set: Id, existance_node: Id) -> Id { assert_eq!(self.explainfind.len(), usize::from(set)); - self.uncanon_memo.insert(node.clone(), set); + self.uncanon_memo.insert(node, set); self.explainfind.push(ExplainNode { neighbors: vec![], parent_connection: Connection { From a7423dab7dc76fdacf7bd81c75d6e0b14b76edfc Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 3 Jan 2024 13:08:41 -0800 Subject: [PATCH 4/5] serde --- src/explain.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/explain.rs b/src/explain.rs index a2d0a2b2..9de2a17e 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -58,7 +58,10 @@ pub struct Explain { #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] #[cfg_attr( feature = "serde-1", - serde(bound(serialize = "L: Serialize", deserialize = "L: for<'a> Deserialize<'a>",)) + serde(bound( + serialize = "L: serde::Serialize", + deserialize = "L: serde::Deserialize<'de>", + )) )] pub uncanon_memo: HashMap, /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted. From eb5b54846e6d9dc40d4da53f38116274e0e27f0f Mon Sep 17 00:00:00 2001 From: dewert99 Date: Wed, 3 Jan 2024 19:22:11 -0800 Subject: [PATCH 5/5] Clarify `id_to_expr` and prevent `copy_with_unions` when explanations are disabled --- src/egraph.rs | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 3e0d8225..3f292460 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -217,6 +217,9 @@ impl> EGraph { /// Make a copy of the egraph with the same nodes, but no unions between them. pub fn copy_without_unions(&self, analysis: N) -> Self { + if self.explain.is_none() { + panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions"); + } let mut egraph = Self::new(analysis); for node in &self.nodes { egraph.add(node.clone()); @@ -638,7 +641,7 @@ impl> EGraph { /// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical /// - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` + /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled pub fn add_expr_uncanonical(&mut self, expr: &RecExpr) -> Id { let nodes = expr.as_ref(); let mut new_ids = Vec::with_capacity(nodes.len()); @@ -676,7 +679,7 @@ impl> EGraph { /// canonical /// /// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an corrispond to the + /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an correspond to the /// instantiation of the pattern fn add_instantiation_noncanonical(&mut self, pat: &PatternAst, subst: &Subst) -> Id { let nodes = pat.as_ref(); @@ -796,7 +799,7 @@ impl> EGraph { /// When explanations are enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will /// correspond to the parameter `enode` /// - /// # Example + /// ## Example /// ``` /// # use egg::*; /// let mut egraph: EGraph = EGraph::default().with_explanations_enabled(); @@ -811,6 +814,25 @@ impl> EGraph { /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); /// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap()); /// ``` + /// + /// When explanations are not enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will + /// produce an expression with equivalent but not necessarily identical children + /// + /// # Example + /// ``` + /// # use egg::*; + /// let mut egraph: EGraph = EGraph::default().with_explanations_disabled(); + /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + /// egraph.union(a, b); + /// egraph.rebuild(); + /// + /// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a])); + /// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b])); + /// + /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); + /// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap()); + /// ``` pub fn add_uncanonical(&mut self, mut enode: L) -> Id { let original = enode.clone(); if let Some(existing_id) = self.lookup_internal(&mut enode) { @@ -822,8 +844,8 @@ impl> EGraph { } else { let new_id = self.unionfind.make_set(); explain.add(original.clone(), new_id, new_id); - self.nodes.push(original); debug_assert_eq!(Id::from(self.nodes.len()), new_id); + self.nodes.push(original); self.unionfind.union(id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); new_id @@ -855,8 +877,8 @@ impl> EGraph { parents: Default::default(), }; - self.nodes.push(original); debug_assert_eq!(Id::from(self.nodes.len()), id); + self.nodes.push(original); // add this enode to the parent lists of its children enode.for_each(|child| {