From 3a12ea68dc208d70a77b7f253b402f1a136e68af Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Sat, 29 Jun 2024 16:28:31 +0200 Subject: [PATCH] Apply a Horner scheme to the evaluation tree --- examples/nested_evaluation.rs | 7 +- src/evaluate.rs | 149 ++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 2 deletions(-) diff --git a/examples/nested_evaluation.rs b/examples/nested_evaluation.rs index 65ad9a7..d49ec12 100644 --- a/examples/nested_evaluation.rs +++ b/examples/nested_evaluation.rs @@ -13,7 +13,7 @@ fn main() { let g = Atom::parse("i(y+7)+x*i(y+7)*(y-1)").unwrap(); let h = Atom::parse("y*(1+x*(1+x^2)) + y^2*(1+x*(1+x^2))^2 + 3*(1+x^2)").unwrap(); let i = Atom::parse("y - 1").unwrap(); - let k = Atom::parse("x+8").unwrap(); + let k = Atom::parse("3*x^3 + 4*x^2 + 6*x +8").unwrap(); let mut const_map = HashMap::default(); @@ -63,7 +63,10 @@ fn main() { let params = vec![Atom::parse("x").unwrap()]; - let tree = e.as_view().to_eval_tree(|r| r.clone(), &const_map, ¶ms); + let mut tree = e.as_view().to_eval_tree(|r| r.clone(), &const_map, ¶ms); + + tree.horner_scheme(); // optimize the tree using an occurrence-order Horner scheme + let t2 = tree.map_coeff::(&|r| r.into()); println!("{}", t2.export_cpp()); // print C++ code diff --git a/src/evaluate.rs b/src/evaluate.rs index 4211102..c2514d1 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -412,6 +412,155 @@ impl EvalTree { } } +impl EvalTree { + fn apply_horner_scheme(&mut self, scheme: &[EvalTree]) { + if scheme.is_empty() { + return; + } + + let EvalTree::Add(a) = self else { + return; + }; + + // TODO: find power to extract, now we do just one + + let mut contains = vec![]; + let mut rest = vec![]; + + for x in a { + let mut found = false; + if let EvalTree::Mul(m) = x { + for (p, y) in m.iter_mut().enumerate() { + if let EvalTree::Pow(p) = y { + if p.0 == scheme[0] { + found = true; + if p.1 == 2 { + *y = p.0.clone(); // TODO: prevent clone + } else { + p.1 -= 1; + } + } + } else if y == &scheme[0] { + found = true; + // remove from prod + m.remove(p); + if m.len() == 1 { + *x = m[0].clone(); + } + break; + } + } + } else if x == &scheme[0] { + found = true; + *x = EvalTree::Const(Rational::one()); + } + + if found { + contains.push(x.clone()); + } else { + rest.push(x.clone()); + } + } + + if contains.is_empty() { + *self = EvalTree::Add(rest); + self.apply_horner_scheme(&scheme[1..]); + } else { + let mut c = EvalTree::Mul(vec![EvalTree::Add(contains), scheme[0].clone()]); + c.apply_horner_scheme(&scheme[1..]); + + if rest.is_empty() { + *self = c; + } else { + let mut r = EvalTree::Add(rest); + r.apply_horner_scheme(&scheme[1..]); + + *self = EvalTree::Add(vec![c, r]); + } + } + } + + /// Apply a simple occurrence-order Horner scheme to every addition. + pub fn horner_scheme(&mut self) { + match self { + EvalTree::Const(_) | EvalTree::Parameter(_) | EvalTree::ReadArg(_, _) => {} + EvalTree::Eval(_, _, _, ae, f) => { + for arg in ae { + arg.horner_scheme(); + } + f.horner_scheme(); + } + EvalTree::Add(a) => { + for arg in &mut *a { + arg.horner_scheme(); + } + + let mut occurrence = HashMap::default(); + + for arg in &*a { + match arg { + EvalTree::Mul(m) => { + for aa in m { + if let EvalTree::Pow(p) = aa { + occurrence + .entry(p.0.clone()) + .and_modify(|x| *x += 1) + .or_insert(1); + } else { + occurrence + .entry(aa.clone()) + .and_modify(|x| *x += 1) + .or_insert(1); + } + } + } + x => { + if let EvalTree::Pow(p) = x { + occurrence + .entry(p.0.clone()) + .and_modify(|x| *x += 1) + .or_insert(1); + } else { + occurrence + .entry(x.clone()) + .and_modify(|x| *x += 1) + .or_insert(1); + } + } + } + } + + occurrence.retain(|_, v| *v > 1); + let mut order: Vec<_> = occurrence.into_iter().collect(); + order.sort_by_key(|k| k.1); + let scheme = order.into_iter().map(|(k, _)| k).collect::>(); + + self.apply_horner_scheme(&scheme); + } + EvalTree::Mul(a) => { + for arg in a { + arg.horner_scheme(); + } + } + EvalTree::Pow(p) => { + p.0.horner_scheme(); + } + EvalTree::Powf(p) => { + p.0.horner_scheme(); + p.1.horner_scheme(); + } + EvalTree::BuiltinFun(_, a) => { + a.horner_scheme(); + } + EvalTree::SubExpression(_, r) => { + let mut rr = r.as_ref().clone(); + rr.horner_scheme(); + *r = Rc::new(rr); + } + } + } +} + impl EvalTree { fn extract_subexpressions(&mut self) { let mut h = HashMap::default();