From de11355f73aba6322c3380f6c8a98cd7c51b91f5 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Wed, 3 Jul 2024 09:16:54 +0200 Subject: [PATCH] Add common pair removal - Refactoring --- examples/nested_evaluation.rs | 18 +- src/evaluate.rs | 1246 ++++++++++++++++++++++----------- 2 files changed, 871 insertions(+), 393 deletions(-) diff --git a/examples/nested_evaluation.rs b/examples/nested_evaluation.rs index dc593e0..f39de57 100644 --- a/examples/nested_evaluation.rs +++ b/examples/nested_evaluation.rs @@ -65,11 +65,21 @@ fn main() { let mut tree = e.as_view().to_eval_tree(|r| r.clone(), &const_map, ¶ms); - tree.horner_scheme(); // optimize the tree using an occurrence-order Horner scheme + // optimize the tree using an occurrence-order Horner scheme + println!("Op original {:?}", tree.count_operations()); + tree.horner_scheme(); + println!("Op horner {:?}", tree.count_operations()); + // the compiler seems to do this as well + tree.common_subexpression_elimination(); + println!("op CSSE {:?}", tree.count_operations()); + + let cpp = tree.export_cpp(); + println!("{}", cpp); // print C++ code - let t2 = tree.map_coeff::(&|r| r.into()); + tree.common_pair_elimination(); + println!("op CPE {:?}", tree.count_operations()); - let cpp = t2.export_cpp(); + let cpp = tree.export_cpp(); println!("{}", cpp); // print C++ code std::fs::write("nested_evaluation.cpp", cpp).unwrap(); @@ -78,6 +88,7 @@ fn main() { .arg("-shared") .arg("-fPIC") .arg("-O3") + .arg("-ffastmath") .arg("-o") .arg("libneval.so") .arg("nested_evaluation.cpp") @@ -101,6 +112,7 @@ fn main() { println!("C++ time {:#?}", t.elapsed()); }; + let t2 = tree.map_coeff::(&|r| r.into()); let mut evaluator: ExpressionEvaluator = t2.linearize(params.len()); println!("Eval: {}", evaluator.evaluate(&[5.])); diff --git a/src/evaluate.rs b/src/evaluate.rs index 9ad2be3..aab1fd8 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -1,11 +1,12 @@ -use std::rc::Rc; - -use ahash::{HashMap, HashSet}; +use ahash::HashMap; use crate::{ atom::{representation::InlineVar, Atom, AtomOrView, AtomView, Symbol}, coefficient::CoefficientView, - domains::{float::Real, rational::Rational}, + domains::{ + float::{NumericalFloatLike, Real}, + rational::Rational, + }, state::State, }; @@ -54,24 +55,28 @@ impl Atom { } } -#[derive(Debug, Clone, Hash, PartialEq, PartialOrd, Ord, Eq)] -pub enum EvalTree { +pub struct ExpressionWithSubexpressions { + pub tree: Expression, + pub subexpressions: Vec>, +} + +pub struct EvalTree { + functions: Vec<(Symbol, Vec, ExpressionWithSubexpressions)>, + expressions: ExpressionWithSubexpressions, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum Expression { Const(T), Parameter(usize), - Eval( - Vec, // a buffer for the evaluated arguments - Symbol, // function name - Vec, // function argument names - Vec>, - Box>, - ), - Add(Vec>), - Mul(Vec>), - Pow(Box<(EvalTree, i64)>), - Powf(Box<(EvalTree, EvalTree)>), - ReadArg(Symbol, usize), // read nth function argument, also store the name for codegen - BuiltinFun(Symbol, Box>), - SubExpression(usize, Rc>), // a reference to a subexpression + Eval(usize, Vec>), + Add(Vec>), + Mul(Vec>), + Pow(Box<(Expression, i64)>), + Powf(Box<(Expression, Expression)>), + ReadArg(usize), // read nth function argument + BuiltinFun(Symbol, Box>), + SubExpression(usize), } pub struct ExpressionEvaluator { @@ -218,43 +223,102 @@ enum Instr { Mul(usize, Vec), Pow(usize, usize, i64), Powf(usize, usize, usize), - BuiltinFun(usize, Symbol, usize), // support function call too? that would be a jump in the instr table here? + BuiltinFun(usize, Symbol, usize), } -impl EvalTree { - pub fn map_coeff T2>(&self, f: &F) -> EvalTree { +impl ExpressionWithSubexpressions { + pub fn map_coeff T2>(&self, f: &F) -> ExpressionWithSubexpressions { + ExpressionWithSubexpressions { + tree: self.tree.map_coeff(f), + subexpressions: self.subexpressions.iter().map(|x| x.map_coeff(f)).collect(), + } + } +} + +impl Expression { + pub fn map_coeff T2>(&self, f: &F) -> Expression { match self { - EvalTree::Const(c) => EvalTree::Const(f(c)), - EvalTree::Parameter(p) => EvalTree::Parameter(*p), - EvalTree::Eval(arg_buf, name, arg_names, e_args, ff) => { - let new_args = e_args.iter().map(|x| x.map_coeff(f)).collect(); - EvalTree::Eval( - arg_buf.iter().map(|x| f(x)).collect(), - *name, - arg_names.clone(), - new_args, - Box::new(ff.map_coeff(f)), - ) + Expression::Const(c) => Expression::Const(f(c)), + Expression::Parameter(p) => Expression::Parameter(*p), + Expression::Eval(id, e_args) => { + Expression::Eval(*id, e_args.iter().map(|x| x.map_coeff(f)).collect()) } - EvalTree::Add(a) => { + Expression::Add(a) => { let new_args = a.iter().map(|x| x.map_coeff(f)).collect(); - EvalTree::Add(new_args) + Expression::Add(new_args) } - EvalTree::Mul(m) => { + Expression::Mul(m) => { let new_args = m.iter().map(|x| x.map_coeff(f)).collect(); - EvalTree::Mul(new_args) + Expression::Mul(new_args) } - EvalTree::Pow(p) => { + Expression::Pow(p) => { let (b, e) = &**p; - EvalTree::Pow(Box::new((b.map_coeff(f), *e))) + Expression::Pow(Box::new((b.map_coeff(f), *e))) } - EvalTree::Powf(p) => { + Expression::Powf(p) => { let (b, e) = &**p; - EvalTree::Powf(Box::new((b.map_coeff(f), e.map_coeff(f)))) + Expression::Powf(Box::new((b.map_coeff(f), e.map_coeff(f)))) } - EvalTree::ReadArg(s, i) => EvalTree::ReadArg(*s, *i), - EvalTree::BuiltinFun(s, a) => EvalTree::BuiltinFun(*s, Box::new(a.map_coeff(f))), - EvalTree::SubExpression(i, e) => EvalTree::SubExpression(*i, Rc::new(e.map_coeff(f))), + Expression::ReadArg(s) => Expression::ReadArg(*s), + Expression::BuiltinFun(s, a) => Expression::BuiltinFun(*s, Box::new(a.map_coeff(f))), + Expression::SubExpression(i) => Expression::SubExpression(*i), + } + } + + fn strip_constants(&mut self, stack: &mut Vec, param_len: usize) { + match self { + Expression::Const(t) => { + if let Some(p) = stack.iter().skip(param_len).position(|x| x == t) { + *self = Expression::Parameter(param_len + p); + } else { + stack.push(t.clone()); + *self = Expression::Parameter(stack.len() - 1); + } + } + Expression::Parameter(_) => {} + Expression::Eval(_, e_args) => { + for a in e_args { + a.strip_constants(stack, param_len); + } + } + Expression::Add(a) | Expression::Mul(a) => { + for arg in a { + arg.strip_constants(stack, param_len); + } + } + Expression::Pow(p) => { + p.0.strip_constants(stack, param_len); + } + Expression::Powf(p) => { + p.0.strip_constants(stack, param_len); + p.1.strip_constants(stack, param_len); + } + Expression::ReadArg(_) => {} + Expression::BuiltinFun(_, a) => { + a.strip_constants(stack, param_len); + } + Expression::SubExpression(_) => {} + } + } +} + +impl EvalTree { + pub fn map_coeff T2>(&self, f: &F) -> EvalTree { + EvalTree { + expressions: ExpressionWithSubexpressions { + tree: self.expressions.tree.map_coeff(f), + subexpressions: self + .expressions + .subexpressions + .iter() + .map(|x| x.map_coeff(f)) + .collect(), + }, + functions: self + .functions + .iter() + .map(|(s, a, e)| (*s, a.clone(), e.map_coeff(f))) + .collect(), } } @@ -268,8 +332,14 @@ impl EvalTree { let mut sub_expr_pos = HashMap::default(); let mut instructions = vec![]; - let result_index = - self.linearize_impl(&mut stack, &mut instructions, &mut sub_expr_pos, &[]); + let result_index = self.linearize_impl( + &self.expressions.tree, + &self.expressions.subexpressions, + &mut stack, + &mut instructions, + &mut sub_expr_pos, + &[], + ); let mut e = ExpressionEvaluator { stack, @@ -283,42 +353,15 @@ impl EvalTree { } fn strip_constants(&mut self, stack: &mut Vec, param_len: usize) { - match self { - EvalTree::Const(t) => { - if let Some(p) = stack.iter().skip(param_len).position(|x| x == t) { - *self = EvalTree::Parameter(param_len + p); - } else { - stack.push(t.clone()); - *self = EvalTree::Parameter(stack.len() - 1); - } - } - EvalTree::Parameter(_) => {} - EvalTree::Eval(_, _, _, e_args, f) => { - for a in e_args { - a.strip_constants(stack, param_len); - } - f.strip_constants(stack, param_len); - } - EvalTree::Add(a) | EvalTree::Mul(a) => { - for arg in a { - arg.strip_constants(stack, param_len); - } - } - EvalTree::Pow(p) => { - p.0.strip_constants(stack, param_len); - } - EvalTree::Powf(p) => { - p.0.strip_constants(stack, param_len); - p.1.strip_constants(stack, param_len); - } - EvalTree::ReadArg(_, _) => {} - EvalTree::BuiltinFun(_, a) => { - a.strip_constants(stack, param_len); - } - EvalTree::SubExpression(_, t) => { - let mut t2 = t.as_ref().clone(); - t2.strip_constants(stack, param_len); - *t = Rc::new(t2); + self.expressions.tree.strip_constants(stack, param_len); + for e in &mut self.expressions.subexpressions { + e.strip_constants(stack, param_len); + } + + for (_, _, e) in &mut self.functions { + e.tree.strip_constants(stack, param_len); + for e in &mut e.subexpressions { + e.strip_constants(stack, param_len); } } } @@ -326,30 +369,44 @@ impl EvalTree { // Yields the stack index that contains the output. fn linearize_impl( &self, + tree: &Expression, + subexpressions: &[Expression], stack: &mut Vec, instr: &mut Vec, sub_expr_pos: &mut HashMap, args: &[usize], ) -> usize { - match self { - EvalTree::Const(t) => { + match tree { + Expression::Const(t) => { stack.push(t.clone()); // TODO: do once and recycle, this messes with the logic as there is no associated instruction stack.len() - 1 } - EvalTree::Parameter(i) => *i, - EvalTree::Eval(_, _, _, e_args, f) => { + Expression::Parameter(i) => *i, + Expression::Eval(id, e_args) => { // inline the function let new_args: Vec<_> = e_args .iter() - .map(|x| x.linearize_impl(stack, instr, sub_expr_pos, args)) + .map(|x| { + self.linearize_impl(x, subexpressions, stack, instr, sub_expr_pos, args) + }) .collect(); - f.linearize_impl(stack, instr, sub_expr_pos, &new_args) + let func = &self.functions[*id].2; + self.linearize_impl( + &func.tree, + &func.subexpressions, + stack, + instr, + sub_expr_pos, + &new_args, + ) } - EvalTree::Add(a) => { + Expression::Add(a) => { let args = a .iter() - .map(|x| x.linearize_impl(stack, instr, sub_expr_pos, args)) + .map(|x| { + self.linearize_impl(x, subexpressions, stack, instr, sub_expr_pos, args) + }) .collect(); stack.push(T::default()); @@ -360,10 +417,12 @@ impl EvalTree { res } - EvalTree::Mul(m) => { + Expression::Mul(m) => { let args = m .iter() - .map(|x| x.linearize_impl(stack, instr, sub_expr_pos, args)) + .map(|x| { + self.linearize_impl(x, subexpressions, stack, instr, sub_expr_pos, args) + }) .collect(); stack.push(T::default()); @@ -374,36 +433,43 @@ impl EvalTree { res } - EvalTree::Pow(p) => { - let b = p.0.linearize_impl(stack, instr, sub_expr_pos, args); + Expression::Pow(p) => { + let b = self.linearize_impl(&p.0, subexpressions, stack, instr, sub_expr_pos, args); stack.push(T::default()); let res = stack.len() - 1; instr.push(Instr::Pow(res, b, p.1)); res } - EvalTree::Powf(p) => { - let b = p.0.linearize_impl(stack, instr, sub_expr_pos, args); - let e = p.1.linearize_impl(stack, instr, sub_expr_pos, args); + Expression::Powf(p) => { + let b = self.linearize_impl(&p.0, subexpressions, stack, instr, sub_expr_pos, args); + let e = self.linearize_impl(&p.1, subexpressions, stack, instr, sub_expr_pos, args); stack.push(T::default()); let res = stack.len() - 1; instr.push(Instr::Powf(res, b, e)); res } - EvalTree::ReadArg(_, a) => args[*a], - EvalTree::BuiltinFun(s, v) => { - let arg = v.linearize_impl(stack, instr, sub_expr_pos, args); + Expression::ReadArg(a) => args[*a], + Expression::BuiltinFun(s, v) => { + let arg = self.linearize_impl(v, subexpressions, stack, instr, sub_expr_pos, args); stack.push(T::default()); let c = Instr::BuiltinFun(stack.len() - 1, *s, arg); instr.push(c); stack.len() - 1 } - EvalTree::SubExpression(id, s) => { + Expression::SubExpression(id) => { if sub_expr_pos.contains_key(id) { *sub_expr_pos.get(id).unwrap() } else { - let res = s.linearize_impl(stack, instr, sub_expr_pos, args); + let res = self.linearize_impl( + &subexpressions[*id], + subexpressions, + stack, + instr, + sub_expr_pos, + args, + ); sub_expr_pos.insert(*id, res); res } @@ -413,21 +479,39 @@ impl EvalTree { } impl EvalTree { - fn apply_horner_scheme(&mut self, scheme: &[EvalTree]) { + pub fn horner_scheme(&mut self) { + self.expressions.tree.horner_scheme(); + for e in &mut self.expressions.subexpressions { + e.horner_scheme(); + } + + for (_, _, e) in &mut self.functions { + e.tree.horner_scheme(); + for e in &mut e.subexpressions { + e.horner_scheme(); + } + } + } +} + +impl Expression { + fn apply_horner_scheme(&mut self, scheme: &[Expression]) { if scheme.is_empty() { return; } - let EvalTree::Add(a) = self else { + let Expression::Add(a) = self else { return; }; + a.sort(); + let mut max_pow: Option = None; for x in &*a { - if let EvalTree::Mul(m) = x { + if let Expression::Mul(m) = x { let mut pow_counter = 0; for y in m { - if let EvalTree::Pow(p) = y { + if let Expression::Pow(p) = y { if p.0 == scheme[0] { pow_counter += p.1; } @@ -456,11 +540,11 @@ impl EvalTree { for x in a { let mut found = false; - if let EvalTree::Mul(m) = x { + if let Expression::Mul(m) = x { let mut pow_counter = 0; m.retain(|y| { - if let EvalTree::Pow(p) = y { + if let Expression::Pow(p) = y { if p.0 == scheme[0] { pow_counter += p.1; false @@ -477,7 +561,7 @@ impl EvalTree { if pow_counter > max_pow { if pow_counter > max_pow + 1 { - m.push(EvalTree::Pow(Box::new(( + m.push(Expression::Pow(Box::new(( scheme[0].clone(), pow_counter - max_pow, )))); @@ -489,7 +573,7 @@ impl EvalTree { } if m.is_empty() { - *x = EvalTree::Const(Rational::one()); + *x = Expression::Const(Rational::one()); } else if m.len() == 1 { *x = m.pop().unwrap(); } @@ -497,7 +581,7 @@ impl EvalTree { found = pow_counter > 0; } else if x == &scheme[0] { found = true; - *x = EvalTree::Const(Rational::one()); + *x = Expression::Const(Rational::one()); } if found { @@ -510,20 +594,20 @@ impl EvalTree { let extracted = if max_pow == 1 { scheme[0].clone() } else { - EvalTree::Pow(Box::new((scheme[0].clone(), max_pow))) + Expression::Pow(Box::new((scheme[0].clone(), max_pow))) }; let mut contains = if contains.len() == 1 { contains.pop().unwrap() } else { - EvalTree::Add(contains) + Expression::Add(contains) }; contains.apply_horner_scheme(&scheme); // keep trying with same variable let mut v = vec![contains, extracted]; v.sort(); - let c = EvalTree::Mul(v); + let c = Expression::Mul(v); if rest.is_empty() { *self = c; @@ -531,7 +615,7 @@ impl EvalTree { let mut r = if rest.len() == 1 { rest.pop().unwrap() } else { - EvalTree::Add(rest) + Expression::Add(rest) }; r.apply_horner_scheme(&scheme[1..]); @@ -539,21 +623,20 @@ impl EvalTree { let mut v = vec![c, r]; v.sort(); - *self = EvalTree::Add(v); + *self = Expression::Add(v); } } /// Apply a simple occurrence-order Horner scheme to every addition. pub fn horner_scheme(&mut self) { match self { - EvalTree::Const(_) | EvalTree::Parameter(_) | EvalTree::ReadArg(_, _) => {} - EvalTree::Eval(_, _, _, ae, f) => { + Expression::Const(_) | Expression::Parameter(_) | Expression::ReadArg(_) => {} + Expression::Eval(_, ae) => { for arg in ae { arg.horner_scheme(); } - f.horner_scheme(); } - EvalTree::Add(a) => { + Expression::Add(a) => { for arg in &mut *a { arg.horner_scheme(); } @@ -562,9 +645,9 @@ impl EvalTree { for arg in &*a { match arg { - EvalTree::Mul(m) => { + Expression::Mul(m) => { for aa in m { - if let EvalTree::Pow(p) = aa { + if let Expression::Pow(p) = aa { occurrence .entry(p.0.clone()) .and_modify(|x| *x += 1) @@ -578,7 +661,7 @@ impl EvalTree { } } x => { - if let EvalTree::Pow(p) = x { + if let Expression::Pow(p) = x { occurrence .entry(p.0.clone()) .and_modify(|x| *x += 1) @@ -600,177 +683,550 @@ impl EvalTree { self.apply_horner_scheme(&scheme); } - EvalTree::Mul(a) => { + Expression::Mul(a) => { for arg in a { arg.horner_scheme(); } } - EvalTree::Pow(p) => { + Expression::Pow(p) => { p.0.horner_scheme(); } - EvalTree::Powf(p) => { + Expression::Powf(p) => { p.0.horner_scheme(); p.1.horner_scheme(); } - EvalTree::BuiltinFun(_, a) => { + Expression::BuiltinFun(_, a) => { a.horner_scheme(); } - EvalTree::SubExpression(_, r) => { - let mut rr = r.as_ref().clone(); - rr.horner_scheme(); - *r = Rc::new(rr); - } + Expression::SubExpression(_) => {} + } + } +} + +impl EvalTree { + pub fn common_subexpression_elimination(&mut self) { + assert!(self.expressions.subexpressions.is_empty()); // TODO: remove this limitation + + self.expressions.subexpressions.extend( + self.expressions + .tree + .extract_subexpressions(self.expressions.subexpressions.len()), + ); + + for (_, _, e) in &mut self.functions { + e.subexpressions + .extend(e.tree.extract_subexpressions(e.subexpressions.len())); + } + } + + pub fn common_pair_elimination(&mut self) { + while self.expressions.common_pair_elimination() {} + for (_, _, e) in &mut self.functions { + while e.common_pair_elimination() {} } } + + pub fn count_operations(&self) -> (usize, usize) { + let mut add = 0; + let mut mul = 0; + for e in &self.functions { + let (ea, em) = e.2.count_operations(); + add += ea; + mul += em; + } + + let (ea, em) = self.expressions.count_operations(); + (add + ea, mul + em) + } } -impl EvalTree { - fn extract_subexpressions(&mut self) { +impl Expression { + fn extract_subexpressions(&mut self, sub_expr_start: usize) -> Vec> { let mut h = HashMap::default(); - self.find_subexpression(&mut h, 0, &mut 0); + self.find_subexpression(&mut h); h.retain(|_, v| *v > 1); for (i, v) in h.values_mut().enumerate() { - *v = i; // make the second argument a unique index of the subexpression + *v = sub_expr_start + i; // make the second argument a unique index of the subexpression } - self.replace_subexpression(&h, 0, &mut 0, &mut HashMap::default()); + self.replace_subexpression(&h); + + let mut v: Vec<_> = h.into_iter().map(|(k, v)| (v, k)).collect(); + v.sort(); + v.into_iter().map(|(_, x)| x).collect() } - fn replace_subexpression( - &mut self, - subexp: &HashMap<(usize, EvalTree), usize>, - branch_id: usize, - new_branch_id: &mut usize, - new_sub_tree: &mut HashMap>, - ) { - let key = (branch_id, self.clone()); // key before any replacements - if let Some(i) = subexp.get(&key) { - if new_sub_tree.contains_key(i) { - *self = EvalTree::SubExpression(*i, Rc::new(new_sub_tree[i].clone())); - return; - } + fn replace_subexpression(&mut self, subexp: &HashMap, usize>) { + if let Some(i) = subexp.get(&self) { + *self = Expression::SubExpression(*i); + return; } match self { - EvalTree::Const(_) | EvalTree::Parameter(_) | EvalTree::ReadArg(_, _) => {} - EvalTree::Eval(_, _, _, ae, f) => { + Expression::Const(_) | Expression::Parameter(_) | Expression::ReadArg(_) => {} + Expression::Eval(_, ae) => { for arg in &mut *ae { - arg.replace_subexpression(subexp, branch_id, new_branch_id, new_sub_tree); + arg.replace_subexpression(subexp); } - - *new_branch_id += 1; - f.replace_subexpression(subexp, *new_branch_id, new_branch_id, new_sub_tree); } - EvalTree::Add(a) | EvalTree::Mul(a) => { + Expression::Add(a) | Expression::Mul(a) => { for arg in a { - arg.replace_subexpression(subexp, branch_id, new_branch_id, new_sub_tree); + arg.replace_subexpression(subexp); } } - EvalTree::Pow(p) => { - p.0.replace_subexpression(subexp, branch_id, new_branch_id, new_sub_tree); + Expression::Pow(p) => { + p.0.replace_subexpression(subexp); } - EvalTree::Powf(p) => { - p.0.replace_subexpression(subexp, branch_id, new_branch_id, new_sub_tree); - p.1.replace_subexpression(subexp, branch_id, new_branch_id, new_sub_tree); + Expression::Powf(p) => { + p.0.replace_subexpression(subexp); + p.1.replace_subexpression(subexp); } - EvalTree::BuiltinFun(_, _) => {} - EvalTree::SubExpression(_, _) => { + Expression::BuiltinFun(_, _) => {} + Expression::SubExpression(_) => { unimplemented!("The expression should not already have subexpressions") } } - - if let Some(i) = subexp.get(&key) { - new_sub_tree.insert(*i, self.clone()); - *self = EvalTree::SubExpression(*i, Rc::new(self.clone())); - } } - fn find_subexpression( - &self, - subexp: &mut HashMap<(usize, EvalTree), usize>, - branch_id: usize, - new_branch_id: &mut usize, - ) { + fn find_subexpression(&self, subexp: &mut HashMap, usize>) { if matches!( self, - EvalTree::Const(_) | EvalTree::Parameter(_) | EvalTree::ReadArg(_, _) + Expression::Const(_) | Expression::Parameter(_) | Expression::ReadArg(_) ) { return; } - let key = (branch_id, self.clone()); - if let Some(i) = subexp.get_mut(&key) { + if let Some(i) = subexp.get_mut(self) { *i += 1; return; } - subexp.insert(key, 1); + subexp.insert(self.clone(), 1); match self { - EvalTree::Const(_) | EvalTree::Parameter(_) | EvalTree::ReadArg(_, _) => {} - EvalTree::Eval(_, _, _, ae, f) => { + Expression::Const(_) | Expression::Parameter(_) | Expression::ReadArg(_) => {} + Expression::Eval(_, ae) => { for arg in ae { - arg.find_subexpression(subexp, branch_id, new_branch_id); + arg.find_subexpression(subexp); + } + } + Expression::Add(a) | Expression::Mul(a) => { + for arg in a { + arg.find_subexpression(subexp); + } + } + Expression::Pow(p) => { + p.0.find_subexpression(subexp); + } + Expression::Powf(p) => { + p.0.find_subexpression(subexp); + p.1.find_subexpression(subexp); + } + Expression::BuiltinFun(_, _) => {} + Expression::SubExpression(_) => {} + } + } +} + +impl + ExpressionWithSubexpressions +{ + /// Find and extract pairs of variables that appear in more than one instruction. + /// This reduces the number of operations. Returns `true` iff an extraction could be performed. + /// + /// This function can be called multiple times such that common subexpressions that + /// are larger than pairs can also be extracted. + pub fn common_pair_elimination(&mut self) -> bool { + let mut pair_count = HashMap::default(); + + for e in &self.subexpressions { + e.find_common_pairs(&mut pair_count); + } + + self.tree.find_common_pairs(&mut pair_count); + + let mut v: Vec<_> = pair_count.into_iter().collect(); + v.retain(|x| x.1 > 1); + v.sort_by_key(|k| std::cmp::Reverse(k.1)); + + let v: Vec<_> = v + .into_iter() + .map(|((a, b, c), e)| ((a, b.clone(), c.clone()), e)) + .collect(); + + for ((is_add, l, r), _) in &v { + let id = self.subexpressions.len(); + + self.tree.replace_common_pair(*is_add, l, r, id); + + let mut first_replace = None; + for (i, e) in &mut self.subexpressions.iter_mut().enumerate() { + if e.replace_common_pair(*is_add, l, r, id) { + if first_replace.is_none() { + first_replace = Some(i); + } + } + } + + let pair = if *is_add { + Expression::Add(vec![l.clone(), r.clone()]) + } else { + Expression::Mul(vec![l.clone(), r.clone()]) + }; + + if let Some(i) = first_replace { + // all subexpressions need to be shifted + for k in i..self.subexpressions.len() { + self.subexpressions[k].shift_subexpr(i, id); + } + + self.tree.shift_subexpr(i, id); + + self.subexpressions.insert(i, pair); + } else { + self.subexpressions.push(pair); + } + + // some subexpression could be Z3=Z2 now, remove that + for i in (0..self.subexpressions.len()).rev() { + if let Expression::SubExpression(n) = &self.subexpressions[i] { + let n = *n; + self.subexpressions.remove(i); + for e in &mut self.subexpressions[i..] { + e.rename_subexpr(i, n); + } + + self.tree.rename_subexpr(i, n); + } + } + + return true; // do just one for now + } + + false + } + + pub fn count_operations(&self) -> (usize, usize) { + let mut add = 0; + let mut mul = 0; + for e in &self.subexpressions { + let (ea, em) = e.count_operations(); + add += ea; + mul += em; + } + + let (ea, em) = self.tree.count_operations(); + (add + ea, mul + em) + } +} + +impl Expression { + pub fn count_operations(&self) -> (usize, usize) { + match self { + Expression::Const(_) => (0, 0), + Expression::Parameter(_) => (0, 0), + Expression::Eval(_, args) => { + let mut add = 0; + let mut mul = 0; + for arg in args { + let (a, m) = arg.count_operations(); + add += a; + mul += m; + } + (add, mul) + } + Expression::Add(a) => { + let mut add = 0; + let mut mul = 0; + for arg in a { + let (a, m) = arg.count_operations(); + add += a; + mul += m; } + (add + a.len() - 1, mul) + } + Expression::Mul(m) => { + let mut add = 0; + let mut mul = 0; + for arg in m { + let (a, m) = arg.count_operations(); + add += a; + mul += m; + } + (add, mul + m.len() - 1) + } + Expression::Pow(p) => { + let (a, m) = p.0.count_operations(); + (a, m + p.1 as usize - 1) + } + Expression::Powf(p) => { + let (a, m) = p.0.count_operations(); + let (a2, m2) = p.1.count_operations(); + (a + a2, m + m2 + 1) // not clear how to count this + } + Expression::ReadArg(_) => (0, 0), + Expression::BuiltinFun(_, _) => (0, 0), // not clear how to count this, third arg? + Expression::SubExpression(_) => (0, 0), + } + } - *new_branch_id += 1; - f.find_subexpression(subexp, *new_branch_id, new_branch_id); + fn shift_subexpr(&mut self, pos: usize, max: usize) { + match self { + Expression::Const(_) | Expression::Parameter(_) | Expression::ReadArg(_) => {} + Expression::Eval(_, ae) => { + for arg in &mut *ae { + arg.shift_subexpr(pos, max); + } } - EvalTree::Add(a) | EvalTree::Mul(a) => { + Expression::Add(a) | Expression::Mul(a) => { for arg in a { - arg.find_subexpression(subexp, branch_id, new_branch_id); + arg.shift_subexpr(pos, max); } } - EvalTree::Pow(p) => { - p.0.find_subexpression(subexp, branch_id, new_branch_id); + Expression::Pow(p) => { + p.0.shift_subexpr(pos, max); } - EvalTree::Powf(p) => { - p.0.find_subexpression(subexp, branch_id, new_branch_id); - p.1.find_subexpression(subexp, branch_id, new_branch_id); + Expression::Powf(p) => { + p.0.shift_subexpr(pos, max); + p.1.shift_subexpr(pos, max); } - EvalTree::BuiltinFun(_, _) => {} - EvalTree::SubExpression(_, _) => { - unimplemented!("The expression should not already have subexpressions") + Expression::BuiltinFun(_, _) => {} + Expression::SubExpression(i) => { + if *i == max { + *i = pos; + } else if *i >= pos { + *i += 1; + } + } + } + } + + fn rename_subexpr(&mut self, old: usize, new: usize) { + match self { + Expression::Const(_) | Expression::Parameter(_) | Expression::ReadArg(_) => {} + Expression::Eval(_, ae) => { + for arg in &mut *ae { + arg.rename_subexpr(old, new); + } + } + Expression::Add(a) | Expression::Mul(a) => { + for arg in a { + arg.rename_subexpr(old, new); + } + } + Expression::Pow(p) => { + p.0.rename_subexpr(old, new); + } + Expression::Powf(p) => { + p.0.rename_subexpr(old, new); + p.1.rename_subexpr(old, new); + } + Expression::BuiltinFun(_, _) => {} + Expression::SubExpression(i) => { + if *i == old { + *i = new; + } else if *i > old { + *i -= 1; + } } } } + + fn find_common_pairs<'a>(&'a self, subexp: &mut HashMap<(bool, &'a Self, &'a Self), usize>) { + match self { + Expression::Const(_) | Expression::Parameter(_) | Expression::ReadArg(_) => {} + Expression::Eval(_, ae) => { + for arg in ae { + arg.find_common_pairs(subexp); + } + } + x @ Expression::Add(m) | x @ Expression::Mul(m) => { + for a in m { + a.find_common_pairs(subexp); + } + + let mut d: Vec<_> = m.iter().collect(); + d.dedup(); + let mut rep = vec![0; d.len()]; + + for (c, v) in rep.iter_mut().zip(&d) { + for v2 in m { + if *v == v2 { + *c += 1; + } + } + } + + for i in 0..d.len() { + if rep[i] > 1 { + *subexp + .entry((matches!(x, Expression::Add(_)), &d[i], &d[i])) + .or_insert(0) += rep[i] / 2; + } + + for j in i + 1..d.len() { + *subexp + .entry((matches!(x, Expression::Add(_)), &d[i], &d[j])) + .or_insert(0) += rep[i].min(rep[j]); + } + } + } + Expression::Pow(p) => { + p.0.find_common_pairs(subexp); + } + Expression::Powf(p) => { + p.0.find_common_pairs(subexp); + p.1.find_common_pairs(subexp); + } + Expression::BuiltinFun(_, _) => {} + Expression::SubExpression(_) => {} + } + } + + fn replace_common_pair(&mut self, is_add: bool, r: &Self, l: &Self, subexpr_id: usize) -> bool { + let cur_is_add = matches!(self, Expression::Add(_)); + + match self { + Expression::Const(_) | Expression::Parameter(_) | Expression::ReadArg(_) => false, + Expression::Eval(_, ae) => { + let mut replaced = false; + for arg in &mut *ae { + replaced |= arg.replace_common_pair(is_add, r, l, subexpr_id); + } + replaced + } + Expression::Add(a) | Expression::Mul(a) => { + let mut replaced = false; + for arg in &mut *a { + replaced |= arg.replace_common_pair(is_add, r, l, subexpr_id); + } + + if is_add != cur_is_add { + return replaced; + } + + if l == r { + let count = a.iter().filter(|x| *x == l).count(); + let pairs = count / 2; + if pairs > 0 { + a.retain(|x| x != l); + + if count % 2 == 1 { + a.push(l.clone()); + } + + a.extend( + std::iter::repeat(Expression::SubExpression(subexpr_id)).take(pairs), + ); + a.sort(); + + if a.len() == 1 { + *self = a.pop().unwrap(); + } + + return true; + } + } else { + let mut idx1_count = 0; + let mut idx2_count = 0; + for v in &*a { + if v == l { + idx1_count += 1; + } + if v == r { + idx2_count += 1; + } + } + + let pair_count = idx1_count.min(idx2_count); + + if pair_count > 0 { + a.retain(|x| x != l && x != r); + + // add back removed indices in cases such as idx1*idx2*idx2 + if idx1_count > pair_count { + a.extend(std::iter::repeat(l.clone()).take(idx1_count - pair_count)); + } + if idx2_count > pair_count { + a.extend(std::iter::repeat(r.clone()).take(idx2_count - pair_count)); + } + + a.extend( + std::iter::repeat(Expression::SubExpression(subexpr_id)) + .take(pair_count), + ); + a.sort(); + + if a.len() == 1 { + *self = a.pop().unwrap(); + } + + return true; + } + } + + replaced + } + Expression::Pow(p) => p.0.replace_common_pair(is_add, r, l, subexpr_id), + Expression::Powf(p) => { + let mut replaced = p.0.replace_common_pair(is_add, r, l, subexpr_id); + replaced |= p.1.replace_common_pair(is_add, r, l, subexpr_id); + replaced + } + Expression::BuiltinFun(_, _) => false, + Expression::SubExpression(_) => false, + } + } } impl EvalTree { /// Evaluate the evaluation tree. Consider converting to a linear form for repeated evaluation. pub fn evaluate(&mut self, params: &[T]) -> T { - self.evaluate_impl(params, &[]) + self.evaluate_impl( + &self.expressions.tree, + &self.expressions.subexpressions, + params, + &[], + ) } - fn evaluate_impl(&mut self, params: &[T], args: &[T]) -> T { - match self { - EvalTree::Const(c) => c.clone(), - EvalTree::Parameter(p) => params[*p].clone(), - EvalTree::Eval(arg_buf, _, _, e_args, f) => { - for (b, a) in arg_buf.iter_mut().zip(e_args.iter_mut()) { - *b = a.evaluate_impl(params, args); + fn evaluate_impl( + &self, + expr: &Expression, + subexpressions: &[Expression], + params: &[T], + args: &[T], + ) -> T { + match expr { + Expression::Const(c) => c.clone(), + Expression::Parameter(p) => params[*p].clone(), + Expression::Eval(f, e_args) => { + let mut arg_buf = vec![T::new_zero(); e_args.len()]; + for (b, a) in arg_buf.iter_mut().zip(e_args.iter()) { + *b = self.evaluate_impl(a, subexpressions, params, args); } - f.evaluate_impl(params, &arg_buf) + let func = &self.functions[*f].2; + self.evaluate_impl(&func.tree, &func.subexpressions, params, &arg_buf) } - EvalTree::Add(a) => { - let mut r = a[0].evaluate_impl(params, args); - for arg in &mut a[1..] { - r += arg.evaluate_impl(params, args); + Expression::Add(a) => { + let mut r = self.evaluate_impl(&a[0], subexpressions, params, args); + for arg in &a[1..] { + r += self.evaluate_impl(arg, subexpressions, params, args); } r } - EvalTree::Mul(m) => { - let mut r = m[0].evaluate_impl(params, args); - for arg in &mut m[1..] { - r *= arg.evaluate_impl(params, args); + Expression::Mul(m) => { + let mut r = self.evaluate_impl(&m[0], subexpressions, params, args); + for arg in &m[1..] { + r *= self.evaluate_impl(arg, subexpressions, params, args); } r } - EvalTree::Pow(p) => { - let (b, e) = &mut **p; - let b_eval = b.evaluate_impl(params, args); + Expression::Pow(p) => { + let (b, e) = &**p; + let b_eval = self.evaluate_impl(b, subexpressions, params, args); if *e >= 0 { b_eval.pow(*e as u64) @@ -778,15 +1234,15 @@ impl EvalTree { b_eval.pow(e.unsigned_abs()).inv() } } - EvalTree::Powf(p) => { - let (b, e) = &mut **p; - let b_eval = b.evaluate_impl(params, args); - let e_eval = e.evaluate_impl(params, args); + Expression::Powf(p) => { + let (b, e) = &**p; + let b_eval = self.evaluate_impl(b, subexpressions, params, args); + let e_eval = self.evaluate_impl(e, subexpressions, params, args); b_eval.powf(&e_eval) } - EvalTree::ReadArg(_, i) => args[*i].clone(), - EvalTree::BuiltinFun(s, a) => { - let arg = a.evaluate_impl(params, args); + Expression::ReadArg(i) => args[*i].clone(), + Expression::BuiltinFun(s, a) => { + let arg = self.evaluate_impl(a, subexpressions, params, args); match *s { State::EXP => arg.exp(), State::LOG => arg.log(), @@ -796,176 +1252,146 @@ impl EvalTree { _ => unreachable!(), } } - EvalTree::SubExpression(_, _) => todo!(), + Expression::SubExpression(s) => { + // TODO: cache + self.evaluate_impl(&subexpressions[*s], subexpressions, params, args) + } } } +} +impl EvalTree { pub fn export_cpp(&self) -> String { - let mut res = String::new(); - - let mut out_preamble = Vec::new(); - let mut processed_subexpr = HashSet::default(); - - let mut funcs = HashMap::default(); - res += "\treturn "; - self.export_cpp_impl( - &mut res, - &mut out_preamble, - &mut processed_subexpr, - &mut funcs, - ); - res.push_str(";\n}\n"); + let mut res = "#include \n#include \n\n".to_string(); - let mut fs = funcs.values().cloned().collect::>(); - fs.sort(); - let mut fs = fs.into_iter().map(|(_, s)| s).collect::>(); + for (name, arg_names, body) in &self.functions { + let mut args = arg_names + .iter() + .map(|x| " T ".to_string() + x.to_string().as_str()) + .collect::>(); + args.insert(0, "T* params".to_string()); - fs.push( - format!("template\nT eval(T* params) {{\n") - + out_preamble.join("").as_str() - + res.as_str(), - ); + res += &format!( + "\ntemplate\nT {}({}) {{\n", + name, + args.join(",") + ); + // our functions are all expressions so we return the expression - fs.push( - "extern \"C\" {\n\tdouble eval_double(double* params) {\n\t\t return eval(params);\n\t}\n}\n" - .to_string() - ); + for (i, s) in body.subexpressions.iter().enumerate() { + res += &format!("\tT Z{}_ = {};\n", i, self.export_cpp_impl(s, arg_names)); + } - fs.push( - "int main() {\n\tstd::cout << eval(new double[]{5.0,6.0,7.0,8.0,9.0,10.0}) << std::endl;\n\treturn 0;\n}" - .to_string(), - ); + let ret = self.export_cpp_impl(&body.tree, arg_names); + res += &format!("\treturn {};\n}}\n", ret); + } - let header = "#include \n#include \n\n"; + res += &format!("\ntemplate\nT eval(T* params) {{\n"); - header.to_string() + fs.join("\n").as_str() - } + for (i, s) in self.expressions.subexpressions.iter().enumerate() { + res += &format!("\tT Z{}_ = {};\n", i, self.export_cpp_impl(s, &[])); + } - fn export_cpp_impl( - &self, - out: &mut String, - out_preamble: &mut Vec, - processed_subexpr: &mut HashSet, - funcs: &mut HashMap, - ) { - match self { - EvalTree::Const(c) => { - out.push_str(&format!("T({})", c)); - } - EvalTree::Parameter(p) => { - out.push_str(&format!("params[{}]", p)); - } - EvalTree::Eval(_, name, arg_names, e_args, f) => { - if funcs.get(name).is_none() { - let mut out = String::new(); - let mut out_preamble = Vec::new(); - let mut processed_subexpr = HashSet::default(); - - let mut args = arg_names - .iter() - .map(|x| "T ".to_string() + x.to_string().as_str()) - .collect::>(); - args.insert(0, "T* params".to_string()); - - // our functions are all expressions so we return the expression - out.push_str("\treturn "); - f.export_cpp_impl(&mut out, &mut out_preamble, &mut processed_subexpr, funcs); - - out.push_str(";\n}\n"); - let l = funcs.len(); - funcs.insert( - name.clone(), - ( - l, - format!("template\nT {}({}) {{\n", name, args.join(",")) - + out_preamble.join("").as_str() - + out.as_str(), - ), - ); - } + let ret = self.export_cpp_impl(&self.expressions.tree, &[]); + res += &format!("\treturn {};\n}}\n", ret); + + res += "\nextern \"C\" {\n\tdouble eval_double(double* params) {\n\t\t return eval(params);\n\t}\n}\n"; - out.push_str(&format!("{}(params", name)); + res += "\nint main() {\n\tstd::cout << eval(new double[]{5.0,6.0,7.0,8.0,9.0,10.0}) << std::endl;\n\treturn 0;\n}"; + + res + } + + fn export_cpp_impl(&self, expr: &Expression, args: &[Symbol]) -> String { + match expr { + Expression::Const(c) => { + format!("T({})", c) + } + Expression::Parameter(p) => { + format!("params[{}]", p) + } + Expression::Eval(id, e_args) => { + let mut r = format!("{}(params", self.functions[*id].0); for a in e_args { - out.push_str(", "); - a.export_cpp_impl(out, out_preamble, processed_subexpr, funcs); + r.push_str(", "); + r += &self.export_cpp_impl(a, args); } - out.push_str(")"); + r.push_str(")"); + r } - EvalTree::Add(a) => { - out.push('('); - a[0].export_cpp_impl(out, out_preamble, processed_subexpr, funcs); + Expression::Add(a) => { + let mut r = "(".to_string(); + r += &self.export_cpp_impl(&a[0], args); for arg in &a[1..] { - out.push_str(" + "); - arg.export_cpp_impl(out, out_preamble, processed_subexpr, funcs); + r.push_str(" + "); + r += &self.export_cpp_impl(arg, args); } - out.push_str(")"); + r.push_str(")"); + r } - EvalTree::Mul(m) => { - out.push('('); - m[0].export_cpp_impl(out, out_preamble, processed_subexpr, funcs); + Expression::Mul(m) => { + let mut r = "(".to_string(); + r += &self.export_cpp_impl(&m[0], args); for arg in &m[1..] { - out.push_str(" * "); - arg.export_cpp_impl(out, out_preamble, processed_subexpr, funcs); + r.push_str(" * "); + r += &self.export_cpp_impl(arg, args); } - out.push(')'); - } - EvalTree::Pow(p) => { - out.push_str("pow("); - p.0.export_cpp_impl(out, out_preamble, processed_subexpr, funcs); - out.push_str(", "); - out.push_str(&p.1.to_string()); - out.push(')'); + r.push_str(")"); + r } - EvalTree::Powf(p) => { - out.push_str("powf("); - p.0.export_cpp_impl(out, out_preamble, processed_subexpr, funcs); - out.push_str(", "); - p.1.export_cpp_impl(out, out_preamble, processed_subexpr, funcs); - out.push(')'); + Expression::Pow(p) => { + let mut r = "pow(".to_string(); + r += &self.export_cpp_impl(&p.0, args); + r.push_str(", "); + r.push_str(&p.1.to_string()); + r.push(')'); + r } - EvalTree::ReadArg(s, _) => { - out.push_str(&format!("{}", s)); + Expression::Powf(p) => { + let mut r = "powf(".to_string(); + r += &self.export_cpp_impl(&p.0, args); + r.push_str(", "); + r += &self.export_cpp_impl(&p.1, args); + r.push(')'); + r } - EvalTree::BuiltinFun(s, a) => match *s { + Expression::ReadArg(s) => args[*s].to_string(), + Expression::BuiltinFun(s, a) => match *s { State::EXP => { - out.push_str("exp("); - a.export_cpp_impl(out, out_preamble, processed_subexpr, funcs); - out.push(')'); + let mut r = "exp(".to_string(); + r += &self.export_cpp_impl(a, args); + r.push(')'); + r } State::LOG => { - out.push_str("log("); - a.export_cpp_impl(out, out_preamble, processed_subexpr, funcs); - out.push(')'); + let mut r = "log(".to_string(); + r += &self.export_cpp_impl(a, args); + r.push(')'); + r } State::SIN => { - out.push_str("sin("); - a.export_cpp_impl(out, out_preamble, processed_subexpr, funcs); - out.push(')'); + let mut r = "sin(".to_string(); + r += &self.export_cpp_impl(a, args); + r.push(')'); + r } State::COS => { - out.push_str("cos("); - a.export_cpp_impl(out, out_preamble, processed_subexpr, funcs); - out.push(')'); + let mut r = "cos(".to_string(); + r += &self.export_cpp_impl(a, args); + r.push(')'); + r } State::SQRT => { - out.push_str("sqrt("); - a.export_cpp_impl(out, out_preamble, processed_subexpr, funcs); - out.push(')'); + let mut r = "sqrt(".to_string(); + r += &self.export_cpp_impl(a, args); + r.push(')'); + r } _ => unreachable!(), }, - EvalTree::SubExpression(id, s) => { - if processed_subexpr.contains(id) { - out.push_str(&format!("s{}_", id)); - } else { - processed_subexpr.insert(*id); - let mut sub_out = String::new(); - s.export_cpp_impl(&mut sub_out, out_preamble, processed_subexpr, funcs); - - out_preamble.push(format!("\tT s{}_ = {};\n", id, sub_out)); - out.push_str(&format!("s{}_", id)); - } + Expression::SubExpression(id) => { + format!("Z{}_", id) } } } @@ -973,15 +1399,25 @@ impl EvalTree { impl<'a> AtomView<'a> { /// Convert nested expressions to a tree. - pub fn to_eval_tree T + Copy>( + pub fn to_eval_tree< + T: Clone + Default + std::fmt::Debug + Eq + std::hash::Hash + Ord, + F: Fn(&Rational) -> T + Copy, + >( &self, coeff_map: F, const_map: &HashMap>, params: &[Atom], ) -> EvalTree { - let mut t = self.to_eval_tree_impl(coeff_map, const_map, params, &[]); - t.extract_subexpressions(); - t + let mut funcs = vec![]; + let t = self.to_eval_tree_impl(coeff_map, const_map, params, &[], &mut funcs); + + EvalTree { + expressions: ExpressionWithSubexpressions { + tree: t, + subexpressions: vec![], + }, + functions: funcs, + } } fn to_eval_tree_impl T + Copy>( @@ -990,14 +1426,15 @@ impl<'a> AtomView<'a> { const_map: &HashMap>, params: &[Atom], args: &[Symbol], - ) -> EvalTree { + funcs: &mut Vec<(Symbol, Vec, ExpressionWithSubexpressions)>, + ) -> Expression { if let Some(p) = params.iter().position(|a| a.as_view() == *self) { - return EvalTree::Parameter(p); + return Expression::Parameter(p); } if let Some(c) = const_map.get(&self.into()) { return match c { - ConstOrExpr::Const(c) => EvalTree::Const(c.clone()), + ConstOrExpr::Const(c) => Expression::Const(c.clone()), ConstOrExpr::Expr(name, args, v) => { if !args.is_empty() { panic!( @@ -1007,8 +1444,20 @@ impl<'a> AtomView<'a> { ); } - let r = v.to_eval_tree_impl(coeff_map, const_map, params, args); - EvalTree::Eval(vec![], *name, args.clone(), vec![], Box::new(r)) + if let Some(pos) = funcs.iter().position(|f| f.0 == *name) { + Expression::Eval(pos, vec![]) + } else { + let r = v.to_eval_tree_impl(coeff_map, const_map, params, args, funcs); + funcs.push(( + *name, + args.clone(), + ExpressionWithSubexpressions { + tree: r.clone(), + subexpressions: vec![], + }, + )); + Expression::Eval(funcs.len() - 1, vec![]) + } } }; } @@ -1016,14 +1465,14 @@ impl<'a> AtomView<'a> { match self { AtomView::Num(n) => match n.get_coeff_view() { CoefficientView::Natural(n, d) => { - EvalTree::Const(coeff_map(&Rational::Natural(n, d))) + Expression::Const(coeff_map(&Rational::Natural(n, d))) } CoefficientView::Large(l) => { - EvalTree::Const(coeff_map(&Rational::Large(l.to_rat()))) + Expression::Const(coeff_map(&Rational::Large(l.to_rat()))) } CoefficientView::Float(f) => { // TODO: converting back to rational is slow - EvalTree::Const(coeff_map(&f.to_float().to_rational())) + Expression::Const(coeff_map(&f.to_float().to_rational())) } CoefficientView::FiniteField(_, _) => { unimplemented!("Finite field not yet supported for evaluation") @@ -1038,7 +1487,7 @@ impl<'a> AtomView<'a> { let name = v.get_symbol(); if let Some(p) = args.iter().position(|s| *s == name) { - return EvalTree::ReadArg(name, p); + return Expression::ReadArg(p); } panic!( @@ -1051,9 +1500,9 @@ impl<'a> AtomView<'a> { if [State::EXP, State::LOG, State::SIN, State::COS, State::SQRT].contains(&name) { assert!(f.get_nargs() == 1); let arg = f.iter().next().unwrap(); - let arg_eval = arg.to_eval_tree_impl(coeff_map, const_map, params, args); + let arg_eval = arg.to_eval_tree_impl(coeff_map, const_map, params, args, funcs); - return EvalTree::BuiltinFun(f.get_symbol(), Box::new(arg_eval)); + return Expression::BuiltinFun(f.get_symbol(), Box::new(arg_eval)); } let symb = InlineVar::new(f.get_symbol()); @@ -1062,7 +1511,7 @@ impl<'a> AtomView<'a> { }; match fun { - ConstOrExpr::Const(t) => EvalTree::Const(t.clone()), + ConstOrExpr::Const(t) => Expression::Const(t.clone()), ConstOrExpr::Expr(name, arg_spec, e) => { if f.get_nargs() != arg_spec.len() { panic!( @@ -1075,50 +1524,67 @@ impl<'a> AtomView<'a> { let eval_args = f .iter() - .map(|arg| arg.to_eval_tree_impl(coeff_map, const_map, params, args)) + .map(|arg| { + arg.to_eval_tree_impl(coeff_map, const_map, params, args, funcs) + }) .collect(); - let res = e.to_eval_tree_impl(coeff_map, const_map, params, arg_spec); - EvalTree::Eval( - vec![T::default(); arg_spec.len()], - *name, - arg_spec.clone(), - eval_args, - Box::new(res), - ) + if let Some(pos) = funcs.iter().position(|f| f.0 == *name) { + Expression::Eval(pos, eval_args) + } else { + let r = + e.to_eval_tree_impl(coeff_map, const_map, params, arg_spec, funcs); + funcs.push(( + *name, + arg_spec.clone(), + ExpressionWithSubexpressions { + tree: r.clone(), + subexpressions: vec![], + }, + )); + Expression::Eval(funcs.len() - 1, eval_args) + } } } } AtomView::Pow(p) => { let (b, e) = p.get_base_exp(); - let b_eval = b.to_eval_tree_impl(coeff_map, const_map, params, args); + let b_eval = b.to_eval_tree_impl(coeff_map, const_map, params, args, funcs); if let AtomView::Num(n) = e { if let CoefficientView::Natural(num, den) = n.get_coeff_view() { if den == 1 { - return EvalTree::Pow(Box::new((b_eval, num))); + if num > 1 { + return Expression::Mul(vec![b_eval.clone(); num as usize]); + } + return Expression::Pow(Box::new((b_eval, num))); } } } - let e_eval = e.to_eval_tree_impl(coeff_map, const_map, params, args); - EvalTree::Powf(Box::new((b_eval, e_eval))) + let e_eval = e.to_eval_tree_impl(coeff_map, const_map, params, args, funcs); + Expression::Powf(Box::new((b_eval, e_eval))) } AtomView::Mul(m) => { let mut muls = vec![]; for arg in m.iter() { - muls.push(arg.to_eval_tree_impl(coeff_map, const_map, params, args)); + let a = arg.to_eval_tree_impl(coeff_map, const_map, params, args, funcs); + if let Expression::Mul(m) = a { + muls.extend(m); + } else { + muls.push(a); + } } - EvalTree::Mul(muls) + Expression::Mul(muls) } AtomView::Add(a) => { let mut adds = vec![]; for arg in a.iter() { - adds.push(arg.to_eval_tree_impl(coeff_map, const_map, params, args)); + adds.push(arg.to_eval_tree_impl(coeff_map, const_map, params, args, funcs)); } - EvalTree::Add(adds) + Expression::Add(adds) } } }