diff --git a/src/eclass.rs b/src/eclass.rs index 5f74b2c2..640dea63 100644 --- a/src/eclass.rs +++ b/src/eclass.rs @@ -17,8 +17,8 @@ pub struct EClass { /// Modifying this field will _not_ cause changes to propagate through the e-graph. /// Prefer [`EGraph::set_analysis_data`] instead. pub data: D, - /// The parent enodes and their original Ids. - pub(crate) parents: Vec<(L, Id)>, + /// The original Ids of parent enodes. + pub(crate) parents: Vec, } impl EClass { @@ -37,9 +37,9 @@ impl EClass { self.nodes.iter() } - /// Iterates over the parent enodes of this eclass. - pub fn parents(&self) -> impl ExactSizeIterator { - self.parents.iter().map(|(node, id)| (node, *id)) + /// Iterates over the non-canonical ids of parent enodes of this eclass. + pub fn parents(&self) -> impl ExactSizeIterator + '_ { + self.parents.iter().copied() } } diff --git a/src/egraph.rs b/src/egraph.rs index 472defe7..5c15c0a2 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -58,6 +58,8 @@ pub struct EGraph> { /// The `Explain` used to explain equivalences in this `EGraph`. pub(crate) explain: Option>, unionfind: UnionFind, + /// Stores the original node represented by each non-canonical id + nodes: Vec, /// Stores each enode's `Id`, not the `Id` of the eclass. /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new /// unions can cause them to become out of date. @@ -65,8 +67,8 @@ pub struct EGraph> { memo: HashMap, /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, /// not the canonical id of the eclass. - pending: Vec<(L, Id)>, - analysis_pending: UniqueQueue<(L, Id)>, + pending: Vec, + analysis_pending: UniqueQueue, #[cfg_attr( feature = "serde-1", serde(bound( @@ -115,6 +117,7 @@ impl> EGraph { analysis, classes: Default::default(), unionfind: Default::default(), + nodes: Default::default(), clean: false, explain: None, pending: Default::default(), @@ -215,12 +218,14 @@ impl> EGraph { /// Make a copy of the egraph with the same nodes, but no unions between them. pub fn copy_without_unions(&self, analysis: N) -> Self { - if let Some(explain) = &self.explain { - let egraph = Self::new(analysis); - explain.populate_enodes(egraph) - } else { + if self.explain.is_none() { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions"); } + let mut egraph = Self::new(analysis); + for node in &self.nodes { + egraph.add(node.clone()); + } + egraph } /// Performs the union between two egraphs. @@ -340,20 +345,33 @@ impl> EGraph { /// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical), /// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical)) pub fn id_to_expr(&self, id: Id) -> RecExpr { - if let Some(explain) = &self.explain { - explain.node_to_recexpr(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); + let mut res = Default::default(); + let mut cache = Default::default(); + self.id_to_expr_internal(&mut res, id, &mut cache); + res + } + + fn id_to_expr_internal( + &self, + res: &mut RecExpr, + node_id: Id, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; } + let new_node = self + .id_to_node(node_id) + .clone() + .map_children(|child| self.id_to_expr_internal(res, child, cache)); + let res_id = res.add(new_node); + cache.insert(node_id, res_id); + res_id } /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep pub fn id_to_node(&self, id: Id) -> &L { - if let Some(explain) = &self.explain { - explain.node(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); - } + &self.nodes[usize::from(id)] } /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term. @@ -361,11 +379,36 @@ impl> EGraph { /// It also adds this variable and the corresponding Id value to the resulting [`Subst`] /// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr). pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap) -> (Pattern, Subst) { - if let Some(explain) = &self.explain { - explain.node_to_pattern(id, substitutions) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique patterns per id"); + let mut res = Default::default(); + let mut subst = Default::default(); + let mut cache = Default::default(); + self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache); + (Pattern::new(res), subst) + } + + fn id_to_pattern_internal( + &self, + res: &mut PatternAst, + node_id: Id, + var_substitutions: &HashMap, + subst: &mut Subst, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; } + let res_id = if let Some(existing) = var_substitutions.get(&node_id) { + let var = format!("?{}", node_id).parse().unwrap(); + subst.insert(var, *existing); + res.add(ENodeOrVar::Var(var)) + } else { + let new_node = self.id_to_node(node_id).clone().map_children(|child| { + self.id_to_pattern_internal(res, child, var_substitutions, subst, cache) + }); + res.add(ENodeOrVar::ENode(new_node)) + }; + cache.insert(node_id, res_id); + res_id } /// Get all the unions ever found in the egraph in terms of enode ids. @@ -391,8 +434,10 @@ impl> EGraph { /// Get the number of congruences between nodes in the egraph. /// Only available when explanations are enabled. pub fn get_num_congr(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_congr::(&self.classes, &self.unionfind) + if let Some(explain) = &mut self.explain { + explain + .with_nodes(&self.nodes) + .get_num_congr::(&self.classes, &self.unionfind) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -400,8 +445,8 @@ impl> EGraph { /// Get the number of nodes in the egraph used for explanations. pub fn get_explanation_num_nodes(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_nodes() + if let Some(explain) = &mut self.explain { + explain.with_nodes(&self.nodes).get_num_nodes() } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -439,7 +484,12 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain.with_nodes(&self.nodes).explain_equivalence::( + left, + right, + &mut self.unionfind, + &self.classes, + ) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -462,7 +512,7 @@ impl> EGraph { /// but more efficient fn explain_existance_id(&mut self, id: Id) -> Explanation { if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_nodes(&self.nodes).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -476,7 +526,7 @@ impl> EGraph { ) -> Explanation { let id = self.add_instantiation_noncanonical(pattern, subst); if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_nodes(&self.nodes).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -499,7 +549,12 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain.with_nodes(&self.nodes).explain_equivalence::( + left, + right, + &mut self.unionfind, + &self.classes, + ) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations."); } @@ -583,11 +638,7 @@ where .map(|l| self.map_node(l)) .collect(), data: self.map_data(src_eclass.data), - parents: src_eclass - .parents - .into_iter() - .map(|(l, id)| (self.map_node(l), id)) - .collect(), + parents: src_eclass.parents, } } @@ -599,12 +650,13 @@ where explain: None, unionfind: src_egraph.unionfind, memo: src_egraph.memo.into_iter().map(kv_map).collect(), - pending: src_egraph.pending.into_iter().map(kv_map).collect(), - analysis_pending: src_egraph - .analysis_pending + pending: src_egraph.pending, + nodes: src_egraph + .nodes .into_iter() - .map(kv_map) + .map(|x| self.map_node(x)) .collect(), + analysis_pending: src_egraph.analysis_pending, classes: src_egraph .classes .into_iter() @@ -794,7 +846,7 @@ impl> EGraph { /// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical /// - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` + /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled pub fn add_expr_uncanonical(&mut self, expr: &RecExpr) -> Id { let nodes = expr.as_ref(); let mut new_ids = Vec::with_capacity(nodes.len()); @@ -832,7 +884,7 @@ impl> EGraph { /// canonical /// /// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an corrispond to the + /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an correspond to the /// instantiation of the pattern fn add_instantiation_noncanonical(&mut self, pat: &PatternAst, subst: &Subst) -> Id { let nodes = pat.as_ref(); @@ -952,7 +1004,7 @@ impl> EGraph { /// When explanations are enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will /// correspond to the parameter `enode` /// - /// # Example + /// ## Example /// ``` /// # use egg::*; /// let mut egraph: EGraph = EGraph::default().with_explanations_enabled(); @@ -967,6 +1019,25 @@ impl> EGraph { /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); /// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap()); /// ``` + /// + /// When explanations are not enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will + /// produce an expression with equivalent but not necessarily identical children + /// + /// # Example + /// ``` + /// # use egg::*; + /// let mut egraph: EGraph = EGraph::default().with_explanations_disabled(); + /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + /// egraph.union(a, b); + /// egraph.rebuild(); + /// + /// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a])); + /// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b])); + /// + /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); + /// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap()); + /// ``` pub fn add_uncanonical(&mut self, mut enode: L) -> Id { let original = enode.clone(); if let Some(existing_id) = self.lookup_internal(&mut enode) { @@ -977,7 +1048,9 @@ impl> EGraph { *existing_explain } else { let new_id = self.unionfind.make_set(); - explain.add(original, new_id, new_id); + explain.add(original.clone(), new_id, new_id); + debug_assert_eq!(Id::from(self.nodes.len()), new_id); + self.nodes.push(original); self.unionfind.union(id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); new_id @@ -986,7 +1059,7 @@ impl> EGraph { existing_id } } else { - let id = self.make_new_eclass(enode); + let id = self.make_new_eclass(enode, original.clone()); if let Some(explain) = self.explain.as_mut() { explain.add(original, id, id); } @@ -999,24 +1072,26 @@ impl> EGraph { } /// This function makes a new eclass in the egraph (but doesn't touch explanations) - fn make_new_eclass(&mut self, enode: L) -> Id { + fn make_new_eclass(&mut self, enode: L, original: L) -> Id { let id = self.unionfind.make_set(); log::trace!(" ...adding to {}", id); let class = EClass { id, nodes: vec![enode.clone()], - data: N::make(self, &enode), + data: N::make(self, &original), parents: Default::default(), }; + debug_assert_eq!(Id::from(self.nodes.len()), id); + self.nodes.push(original); + // add this enode to the parent lists of its children enode.for_each(|child| { - let tup = (enode.clone(), id); - self[child].parents.push(tup); + self[child].parents.push(id); }); // TODO is this needed? - self.pending.push((enode.clone(), id)); + self.pending.push(id); self.classes.insert(id, class); assert!(self.memo.insert(enode, id).is_none()); @@ -1151,13 +1226,13 @@ impl> EGraph { let class1 = self.classes.get_mut(&id1).unwrap(); assert_eq!(id1, class1.id); - self.pending.extend(class2.parents.iter().cloned()); + self.pending.extend(class2.parents.iter().copied()); let did_merge = self.analysis.merge(&mut class1.data, class2.data); if did_merge.0 { - self.analysis_pending.extend(class1.parents.iter().cloned()); + self.analysis_pending.extend(class1.parents.iter().copied()); } if did_merge.1 { - self.analysis_pending.extend(class2.parents.iter().cloned()); + self.analysis_pending.extend(class2.parents.iter().copied()); } concat_vecs(&mut class1.nodes, class2.nodes); @@ -1176,7 +1251,7 @@ impl> EGraph { let id = self.find_mut(id); let class = self.classes.get_mut(&id).unwrap(); class.data = new_data; - self.analysis_pending.extend(class.parents.iter().cloned()); + self.analysis_pending.extend(class.parents.iter().copied()); N::modify(self, id) } @@ -1311,7 +1386,8 @@ impl> EGraph { let mut n_unions = 0; while !self.pending.is_empty() || !self.analysis_pending.is_empty() { - while let Some((mut node, class)) = self.pending.pop() { + while let Some(class) = self.pending.pop() { + let mut node = self.nodes[usize::from(class)].clone(); node.update_children(|id| self.find_mut(id)); if let Some(memo_class) = self.memo.insert(node, class) { let did_something = self.perform_union( @@ -1324,14 +1400,15 @@ impl> EGraph { } } - while let Some((node, class_id)) = self.analysis_pending.pop() { + while let Some(class_id) = self.analysis_pending.pop() { + let node = self.nodes[usize::from(class_id)].clone(); let class_id = self.find_mut(class_id); let node_data = N::make(self, &node); let class = self.classes.get_mut(&class_id).unwrap(); let did_merge = self.analysis.merge(&mut class.data, node_data); if did_merge.0 { - self.analysis_pending.extend(class.parents.iter().cloned()); + self.analysis_pending.extend(class.parents.iter().copied()); N::modify(self, class_id) } } @@ -1412,9 +1489,9 @@ impl> EGraph { n_unions } - pub(crate) fn check_each_explain(&self, rules: &[&Rewrite]) -> bool { - if let Some(explain) = &self.explain { - explain.check_each_explain(rules) + pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite]) -> bool { + if let Some(explain) = &mut self.explain { + explain.with_nodes(&self.nodes).check_each_explain(rules) } else { panic!("Can't check explain when explanations are off"); } diff --git a/src/explain.rs b/src/explain.rs index 187aecfc..9de2a17e 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1,12 +1,13 @@ use crate::Symbol; use crate::{ - util::pretty_print, Analysis, EClass, EGraph, ENodeOrVar, FromOp, HashMap, HashSet, Id, - Language, Pattern, PatternAst, RecExpr, Rewrite, Subst, UnionFind, Var, + util::pretty_print, Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, + PatternAst, RecExpr, Rewrite, UnionFind, Var, }; use saturating::Saturating; use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; +use std::ops::{Deref, DerefMut}; use std::rc::Rc; use symbolic_expressions::Sexp; @@ -38,8 +39,7 @@ struct Connection { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -struct ExplainNode { - node: L, +struct ExplainNode { // neighbors includes parent connections neighbors: Vec, parent_connection: Connection, @@ -54,8 +54,15 @@ struct ExplainNode { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] pub struct Explain { - explainfind: Vec>, + explainfind: Vec, #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] + #[cfg_attr( + feature = "serde-1", + serde(bound( + serialize = "L: serde::Serialize", + deserialize = "L: serde::Deserialize<'de>", + )) + )] pub uncanon_memo: HashMap, /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted. pub optimize_explanation_lengths: bool, @@ -69,6 +76,11 @@ pub struct Explain { shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>, } +pub(crate) struct ExplainNodes<'a, L: Language> { + explain: &'a mut Explain, + nodes: &'a [L], +} + #[derive(Default)] struct DistanceMemo { parent_distance: Vec<(Id, ProofCost)>, @@ -883,97 +895,6 @@ impl PartialOrd for HeapState { } impl Explain { - pub(crate) fn node(&self, node_id: Id) -> &L { - &self.explainfind[usize::from(node_id)].node - } - fn node_to_explanation( - &self, - node_id: Id, - cache: &mut NodeExplanationCache, - ) -> Rc> { - if let Some(existing) = cache.get(&node_id) { - existing.clone() - } else { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(vec![self.node_to_explanation(child, cache)]); - sofar - }); - let res = Rc::new(TreeTerm::new(node, children)); - cache.insert(node_id, res.clone()); - res - } - } - - pub(crate) fn node_to_recexpr(&self, node_id: Id) -> RecExpr { - let mut res = Default::default(); - let mut cache = Default::default(); - self.node_to_recexpr_internal(&mut res, node_id, &mut cache); - res - } - fn node_to_recexpr_internal( - &self, - res: &mut RecExpr, - node_id: Id, - cache: &mut HashMap, - ) { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_recexpr_internal(res, child, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(new_node); - } - - pub(crate) fn node_to_pattern( - &self, - node_id: Id, - substitutions: &HashMap, - ) -> (Pattern, Subst) { - let mut res = Default::default(); - let mut subst = Default::default(); - let mut cache = Default::default(); - self.node_to_pattern_internal(&mut res, node_id, substitutions, &mut subst, &mut cache); - (Pattern::new(res), subst) - } - - fn node_to_pattern_internal( - &self, - res: &mut PatternAst, - node_id: Id, - var_substitutions: &HashMap, - subst: &mut Subst, - cache: &mut HashMap, - ) { - if let Some(existing) = var_substitutions.get(&node_id) { - let var = format!("?{}", node_id).parse().unwrap(); - res.add(ENodeOrVar::Var(var)); - subst.insert(var, *existing); - } else { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_pattern_internal(res, child, var_substitutions, subst, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(ENodeOrVar::ENode(new_node)); - } - } - - fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(self.node_to_flat_explanation(child)); - sofar - }); - FlatTerm::new(node, children) - } - fn make_rule_table<'a, N: Analysis>( rules: &[&'a Rewrite], ) -> HashMap> { @@ -983,52 +904,6 @@ impl Explain { } table } - - pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { - let rule_table = Explain::make_rule_table(rules); - for i in 0..self.explainfind.len() { - let explain_node = &self.explainfind[i]; - - // check that explanation reasons never form a cycle - let mut existance = i; - let mut seen_existance: HashSet = Default::default(); - loop { - seen_existance.insert(existance); - let next = usize::from(self.explainfind[existance].existance_node); - if existance == next { - break; - } - existance = next; - if seen_existance.contains(&existance) { - panic!("Cycle in existance!"); - } - } - - if explain_node.parent_connection.next != Id::from(i) { - let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); - let mut next_explanation = - self.node_to_flat_explanation(explain_node.parent_connection.next); - if let Justification::Rule(rule_name) = - &explain_node.parent_connection.justification - { - if let Some(rule) = rule_table.get(rule_name) { - if !explain_node.parent_connection.is_rewrite_forward { - std::mem::swap(&mut current_explanation, &mut next_explanation); - } - if !Explanation::check_rewrite( - ¤t_explanation, - &next_explanation, - rule, - ) { - return false; - } - } - } - } - } - true - } - pub fn new() -> Self { Explain { explainfind: vec![], @@ -1044,9 +919,8 @@ impl Explain { pub(crate) fn add(&mut self, node: L, set: Id, existance_node: Id) -> Id { assert_eq!(self.explainfind.len(), usize::from(set)); - self.uncanon_memo.insert(node.clone(), set); + self.uncanon_memo.insert(node, set); self.explainfind.push(ExplainNode { - node, neighbors: vec![], parent_connection: Connection { justification: Justification::Congruence, @@ -1119,7 +993,7 @@ impl Explain { new_rhs: bool, ) { if let Justification::Congruence = justification { - assert!(self.node(node1).matches(self.node(node2))); + // assert!(self.node(node1).matches(self.node(node2))); } if new_rhs { self.set_existance_reason(node2, node1) @@ -1155,7 +1029,6 @@ impl Explain { .push(other_pconnection); self.explainfind[usize::from(node1)].parent_connection = pconnection; } - pub(crate) fn get_union_equalities(&self) -> UnionEqualities { let mut equalities = vec![]; for node in &self.explainfind { @@ -1170,13 +1043,103 @@ impl Explain { equalities } - pub(crate) fn populate_enodes>(&self, mut egraph: EGraph) -> EGraph { - for i in 0..self.explainfind.len() { - let node = &self.explainfind[i]; - egraph.add(node.node.clone()); + pub(crate) fn with_nodes<'a>(&'a mut self, nodes: &'a [L]) -> ExplainNodes<'a, L> { + ExplainNodes { + explain: self, + nodes, } + } +} + +impl<'a, L: Language> Deref for ExplainNodes<'a, L> { + type Target = Explain; + + fn deref(&self) -> &Self::Target { + self.explain + } +} + +impl<'a, L: Language> DerefMut for ExplainNodes<'a, L> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.explain + } +} + +impl<'x, L: Language> ExplainNodes<'x, L> { + pub(crate) fn node(&self, node_id: Id) -> &L { + &self.nodes[usize::from(node_id)] + } + fn node_to_explanation( + &self, + node_id: Id, + cache: &mut NodeExplanationCache, + ) -> Rc> { + if let Some(existing) = cache.get(&node_id) { + existing.clone() + } else { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(vec![self.node_to_explanation(child, cache)]); + sofar + }); + let res = Rc::new(TreeTerm::new(node, children)); + cache.insert(node_id, res.clone()); + res + } + } + + fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(self.node_to_flat_explanation(child)); + sofar + }); + FlatTerm::new(node, children) + } + + pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { + let rule_table = Explain::make_rule_table(rules); + for i in 0..self.explainfind.len() { + let explain_node = &self.explainfind[i]; + + // check that explanation reasons never form a cycle + let mut existance = i; + let mut seen_existance: HashSet = Default::default(); + loop { + seen_existance.insert(existance); + let next = usize::from(self.explainfind[existance].existance_node); + if existance == next { + break; + } + existance = next; + if seen_existance.contains(&existance) { + panic!("Cycle in existance!"); + } + } - egraph + if explain_node.parent_connection.next != Id::from(i) { + let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); + let mut next_explanation = + self.node_to_flat_explanation(explain_node.parent_connection.next); + if let Justification::Rule(rule_name) = + &explain_node.parent_connection.justification + { + if let Some(rule) = rule_table.get(rule_name) { + if !explain_node.parent_connection.is_rewrite_forward { + std::mem::swap(&mut current_explanation, &mut next_explanation); + } + if !Explanation::check_rewrite( + ¤t_explanation, + &next_explanation, + rule, + ) { + return false; + } + } + } + } + } + true } pub(crate) fn explain_equivalence>( @@ -1328,7 +1291,7 @@ impl Explain { let mut new_rest_of_proof = (*self.node_to_explanation(existance, enode_cache)).clone(); let mut index_of_child = 0; let mut found = false; - existance_node.node.for_each(|child| { + self.node(existance).for_each(|child| { if found { return; } @@ -2092,7 +2055,7 @@ mod tests { #[test] fn simple_explain_union_trusted() { - use crate::SymbolLang; + use crate::{EGraph, SymbolLang}; crate::init_logger(); let mut egraph = EGraph::new(()).with_explanations_enabled(); diff --git a/src/language.rs b/src/language.rs index 6414c63a..79f32c4d 100644 --- a/src/language.rs +++ b/src/language.rs @@ -708,6 +708,8 @@ pub trait Analysis: Sized { /// It is **not** `make`'s responsiblity to insert the e-node; /// the e-node is "being inserted" when this function is called. /// Doing so will create an infinite loop. + /// + /// Note that `enode`'s children may not be canonical fn make(egraph: &mut EGraph, enode: &L) -> Self::Data; /// An optional hook that allows inspection before a [`union`] occurs.