diff --git a/examples/nested_evaluation.rs b/examples/nested_evaluation.rs index bc8aeeb..f0135eb 100644 --- a/examples/nested_evaluation.rs +++ b/examples/nested_evaluation.rs @@ -2,13 +2,14 @@ use std::{process::Command, time::Instant}; use ahash::HashMap; use symbolica::{ - atom::Atom, + atom::{Atom, AtomView}, evaluate::{ConstOrExpr, ExpressionEvaluator}, state::State, }; fn main() { - let e = Atom::parse("x + cos(x) + f(g(x+1),h(x*2)) + p(1)").unwrap(); + let e1 = Atom::parse("x + cos(x) + f(g(x+1),h(x*2)) + p(1)").unwrap(); + let e2 = Atom::parse("x + h(x*2) + cos(x)").unwrap(); let f = Atom::parse("y^2 + z^2*y^2").unwrap(); let g = Atom::parse("i(y+7)+x*i(y+7)*(y-1)").unwrap(); let h = Atom::parse("y*(1+x*(1+x^2)) + y^2*(1+x*(1+x^2))^2 + 3*(1+x^2)").unwrap(); @@ -63,7 +64,12 @@ fn main() { let params = vec![Atom::parse("x").unwrap()]; - let mut tree = e.as_view().to_eval_tree(|r| r.clone(), &const_map, ¶ms); + let mut tree = AtomView::to_eval_tree_multiple( + &[e1.as_view(), e2.as_view()], + |r| r.clone(), + &const_map, + ¶ms, + ); // optimize the tree using an occurrence-order Horner scheme println!("Op original {:?}", tree.count_operations()); @@ -73,9 +79,6 @@ fn main() { tree.common_subexpression_elimination(); println!("op CSSE {:?}", tree.count_operations()); - let cpp = tree.export_cpp(); - println!("{}", cpp); // print C++ code - tree.common_pair_elimination(); println!("op CPE {:?}", tree.count_operations()); @@ -97,17 +100,20 @@ fn main() { unsafe { let lib = libloading::Library::new("./libneval.so").unwrap(); - let func: libloading::Symbol f64> = - lib.get(b"eval_double").unwrap(); + let func: libloading::Symbol< + unsafe extern "C" fn(params: *const f64, out: *mut f64) -> f64, + > = lib.get(b"eval_double").unwrap(); let params = vec![5.]; - println!("Eval from C++: {}", func(params.as_ptr())); + let mut out = vec![0., 0.]; + func(params.as_ptr(), out.as_mut_ptr()); + println!("Eval from C++: {}, {}", out[0], out[1]); // benchmark let t = Instant::now(); for _ in 0..1000000 { - let _ = func(params.as_ptr()); + let _ = func(params.as_ptr(), out.as_mut_ptr()); } println!("C++ time {:#?}", t.elapsed()); }; @@ -115,13 +121,15 @@ fn main() { let t2 = tree.map_coeff::(&|r| r.into()); let mut evaluator: ExpressionEvaluator = t2.linearize(params.len()); - println!("Eval: {}", evaluator.evaluate(&[5.])); + let mut out = vec![0., 0.]; + evaluator.evaluate_multiple(&[5.], &mut out); + println!("Eval: {}, {}", out[0], out[1]); // benchmark let params = vec![5.]; let t = Instant::now(); for _ in 0..1000000 { - let _ = evaluator.evaluate(¶ms); + evaluator.evaluate_multiple(¶ms, &mut out); } println!("Eager time {:#?}", t.elapsed()); } diff --git a/src/evaluate.rs b/src/evaluate.rs index aab1fd8..2ee9916 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -56,7 +56,7 @@ impl Atom { } pub struct ExpressionWithSubexpressions { - pub tree: Expression, + pub tree: Vec>, pub subexpressions: Vec>, } @@ -83,11 +83,17 @@ pub struct ExpressionEvaluator { stack: Vec, reserved_indices: usize, instructions: Vec, - result_index: usize, + result_indices: Vec, } impl ExpressionEvaluator { pub fn evaluate(&mut self, params: &[T]) -> T { + let mut res = T::new_zero(); + self.evaluate_multiple(params, std::slice::from_mut(&mut res)); + res + } + + pub fn evaluate_multiple(&mut self, params: &[T], out: &mut [T]) { for (t, p) in self.stack.iter_mut().zip(params) { *t = p.clone(); } @@ -132,7 +138,9 @@ impl ExpressionEvaluator { } } - self.stack[self.result_index].clone() + for (o, i) in out.iter_mut().zip(&self.result_indices) { + *o = self.stack[*i].clone(); + } } } @@ -162,6 +170,11 @@ impl ExpressionEvaluator { last_use[i] = self.instructions.len(); } + // prevent the output slots from being overwritten + for i in &self.result_indices { + last_use[*i] = self.instructions.len(); + } + let mut rename_map: Vec<_> = (0..self.stack.len()).collect(); // identity map let mut max_reg = self.reserved_indices; @@ -213,7 +226,9 @@ impl ExpressionEvaluator { self.stack.truncate(max_reg + 1); - self.result_index = rename_map[self.result_index]; + for i in &mut self.result_indices { + *i = rename_map[*i]; + } } } @@ -229,7 +244,7 @@ enum Instr { impl ExpressionWithSubexpressions { pub fn map_coeff T2>(&self, f: &F) -> ExpressionWithSubexpressions { ExpressionWithSubexpressions { - tree: self.tree.map_coeff(f), + tree: self.tree.iter().map(|x| x.map_coeff(f)).collect(), subexpressions: self.subexpressions.iter().map(|x| x.map_coeff(f)).collect(), } } @@ -306,7 +321,12 @@ impl EvalTree { pub fn map_coeff T2>(&self, f: &F) -> EvalTree { EvalTree { expressions: ExpressionWithSubexpressions { - tree: self.expressions.tree.map_coeff(f), + tree: self + .expressions + .tree + .iter() + .map(|x| x.map_coeff(f)) + .collect(), subexpressions: self .expressions .subexpressions @@ -332,20 +352,26 @@ impl EvalTree { let mut sub_expr_pos = HashMap::default(); let mut instructions = vec![]; - let result_index = self.linearize_impl( - &self.expressions.tree, - &self.expressions.subexpressions, - &mut stack, - &mut instructions, - &mut sub_expr_pos, - &[], - ); + + let mut result_indices = vec![]; + + for t in &self.expressions.tree { + let result_index = self.linearize_impl( + &t, + &self.expressions.subexpressions, + &mut stack, + &mut instructions, + &mut sub_expr_pos, + &[], + ); + result_indices.push(result_index); + } let mut e = ExpressionEvaluator { stack, reserved_indices, instructions, - result_index, + result_indices, }; e.optimize_stack(); @@ -353,13 +379,19 @@ impl EvalTree { } fn strip_constants(&mut self, stack: &mut Vec, param_len: usize) { - self.expressions.tree.strip_constants(stack, param_len); + for t in &mut self.expressions.tree { + t.strip_constants(stack, param_len); + } + for e in &mut self.expressions.subexpressions { e.strip_constants(stack, param_len); } for (_, _, e) in &mut self.functions { - e.tree.strip_constants(stack, param_len); + for t in &mut e.tree { + t.strip_constants(stack, param_len); + } + for e in &mut e.subexpressions { e.strip_constants(stack, param_len); } @@ -393,7 +425,7 @@ impl EvalTree { let func = &self.functions[*id].2; self.linearize_impl( - &func.tree, + &func.tree[0], &func.subexpressions, stack, instr, @@ -480,13 +512,19 @@ impl EvalTree { impl EvalTree { pub fn horner_scheme(&mut self) { - self.expressions.tree.horner_scheme(); + for t in &mut self.expressions.tree { + t.horner_scheme(); + } + for e in &mut self.expressions.subexpressions { e.horner_scheme(); } for (_, _, e) in &mut self.functions { - e.tree.horner_scheme(); + for t in &mut e.tree { + t.horner_scheme(); + } + for e in &mut e.subexpressions { e.horner_scheme(); } @@ -705,17 +743,10 @@ impl Expression { impl EvalTree { pub fn common_subexpression_elimination(&mut self) { - assert!(self.expressions.subexpressions.is_empty()); // TODO: remove this limitation - - self.expressions.subexpressions.extend( - self.expressions - .tree - .extract_subexpressions(self.expressions.subexpressions.len()), - ); + self.expressions.common_subexpression_elimination(); for (_, _, e) in &mut self.functions { - e.subexpressions - .extend(e.tree.extract_subexpressions(e.subexpressions.len())); + e.common_subexpression_elimination(); } } @@ -740,23 +771,37 @@ impl EvalTree } } -impl Expression { - fn extract_subexpressions(&mut self, sub_expr_start: usize) -> Vec> { +impl + ExpressionWithSubexpressions +{ + pub fn common_subexpression_elimination(&mut self) { let mut h = HashMap::default(); - self.find_subexpression(&mut h); + + for t in &mut self.tree { + t.find_subexpression(&mut h); + } h.retain(|_, v| *v > 1); + + // make the second argument a unique index of the subexpression for (i, v) in h.values_mut().enumerate() { - *v = sub_expr_start + i; // make the second argument a unique index of the subexpression + *v = self.subexpressions.len() + i; } - self.replace_subexpression(&h); + for t in &mut self.tree { + t.replace_subexpression(&h); + } let mut v: Vec<_> = h.into_iter().map(|(k, v)| (v, k)).collect(); v.sort(); - v.into_iter().map(|(_, x)| x).collect() + + for (_, x) in v { + self.subexpressions.push(x); + } } +} +impl Expression { fn replace_subexpression(&mut self, subexp: &HashMap, usize>) { if let Some(i) = subexp.get(&self) { *self = Expression::SubExpression(*i); @@ -844,7 +889,9 @@ impl e.find_common_pairs(&mut pair_count); } - self.tree.find_common_pairs(&mut pair_count); + for t in &self.tree { + t.find_common_pairs(&mut pair_count); + } let mut v: Vec<_> = pair_count.into_iter().collect(); v.retain(|x| x.1 > 1); @@ -858,7 +905,9 @@ impl for ((is_add, l, r), _) in &v { let id = self.subexpressions.len(); - self.tree.replace_common_pair(*is_add, l, r, id); + for t in &mut self.tree { + t.replace_common_pair(*is_add, l, r, id); + } let mut first_replace = None; for (i, e) in &mut self.subexpressions.iter_mut().enumerate() { @@ -881,7 +930,9 @@ impl self.subexpressions[k].shift_subexpr(i, id); } - self.tree.shift_subexpr(i, id); + for t in &mut self.tree { + t.shift_subexpr(i, id); + } self.subexpressions.insert(i, pair); } else { @@ -897,7 +948,9 @@ impl e.rename_subexpr(i, n); } - self.tree.rename_subexpr(i, n); + for t in &mut self.tree { + t.rename_subexpr(i, n); + } } } @@ -916,8 +969,13 @@ impl mul += em; } - let (ea, em) = self.tree.count_operations(); - (add + ea, mul + em) + for e in &self.tree { + let (ea, em) = e.count_operations(); + add += ea; + mul += em; + } + + (add, mul) } } @@ -1182,13 +1240,10 @@ impl Expressi impl EvalTree { /// Evaluate the evaluation tree. Consider converting to a linear form for repeated evaluation. - pub fn evaluate(&mut self, params: &[T]) -> T { - self.evaluate_impl( - &self.expressions.tree, - &self.expressions.subexpressions, - params, - &[], - ) + pub fn evaluate(&mut self, params: &[T], out: &mut [T]) { + for (o, e) in out.iter_mut().zip(&self.expressions.tree) { + *o = self.evaluate_impl(&e, &self.expressions.subexpressions, params, &[]) + } } fn evaluate_impl( @@ -1208,7 +1263,7 @@ impl EvalTree { } let func = &self.functions[*f].2; - self.evaluate_impl(&func.tree, &func.subexpressions, params, &arg_buf) + self.evaluate_impl(&func.tree[0], &func.subexpressions, params, &arg_buf) } Expression::Add(a) => { let mut r = self.evaluate_impl(&a[0], subexpressions, params, args); @@ -1282,22 +1337,27 @@ impl EvalTree { res += &format!("\tT Z{}_ = {};\n", i, self.export_cpp_impl(s, arg_names)); } - let ret = self.export_cpp_impl(&body.tree, arg_names); + if body.tree.len() > 1 { + panic!("Tensor functions not supported yet"); + } + + let ret = self.export_cpp_impl(&body.tree[0], arg_names); res += &format!("\treturn {};\n}}\n", ret); } - res += &format!("\ntemplate\nT eval(T* params) {{\n"); + res += &format!("\ntemplate\nvoid eval(T* params, T* out) {{\n"); for (i, s) in self.expressions.subexpressions.iter().enumerate() { res += &format!("\tT Z{}_ = {};\n", i, self.export_cpp_impl(s, &[])); } - let ret = self.export_cpp_impl(&self.expressions.tree, &[]); - res += &format!("\treturn {};\n}}\n", ret); + for (i, e) in self.expressions.tree.iter().enumerate() { + res += &format!("\tout[{}] = {};\n", i, self.export_cpp_impl(&e, &[])); + } - res += "\nextern \"C\" {\n\tdouble eval_double(double* params) {\n\t\t return eval(params);\n\t}\n}\n"; + res += "\treturn;\n}\n"; - res += "\nint main() {\n\tstd::cout << eval(new double[]{5.0,6.0,7.0,8.0,9.0,10.0}) << std::endl;\n\treturn 0;\n}"; + res += "\nextern \"C\" {\n\tvoid eval_double(double* params, double* out) {\n\t\teval(params, out);\n\t\treturn;\n\t}\n}\n"; res } @@ -1407,13 +1467,29 @@ impl<'a> AtomView<'a> { coeff_map: F, const_map: &HashMap>, params: &[Atom], + ) -> EvalTree { + Self::to_eval_tree_multiple(std::slice::from_ref(self), coeff_map, const_map, params) + } + + /// Convert nested expressions to a tree. + pub fn to_eval_tree_multiple< + T: Clone + Default + std::fmt::Debug + Eq + std::hash::Hash + Ord, + F: Fn(&Rational) -> T + Copy, + >( + exprs: &[Self], + coeff_map: F, + const_map: &HashMap>, + params: &[Atom], ) -> EvalTree { let mut funcs = vec![]; - let t = self.to_eval_tree_impl(coeff_map, const_map, params, &[], &mut funcs); + let tree = exprs + .iter() + .map(|t| t.to_eval_tree_impl(coeff_map, const_map, params, &[], &mut funcs)) + .collect(); EvalTree { expressions: ExpressionWithSubexpressions { - tree: t, + tree, subexpressions: vec![], }, functions: funcs, @@ -1452,7 +1528,7 @@ impl<'a> AtomView<'a> { *name, args.clone(), ExpressionWithSubexpressions { - tree: r.clone(), + tree: vec![r.clone()], subexpressions: vec![], }, )); @@ -1538,7 +1614,7 @@ impl<'a> AtomView<'a> { *name, arg_spec.clone(), ExpressionWithSubexpressions { - tree: r.clone(), + tree: vec![r.clone()], subexpressions: vec![], }, ));