From 248650dcd71dae5343501e0e93f20a2bb04c5d95 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Wed, 26 Jun 2024 15:30:30 +0200 Subject: [PATCH] Add nested expression evaluator --- examples/nested_evaluation.rs | 47 ++++ src/atom.rs | 91 +++++++- src/evaluate.rs | 401 +++++++++++++++++++++++++++++++++- 3 files changed, 528 insertions(+), 11 deletions(-) create mode 100644 examples/nested_evaluation.rs diff --git a/examples/nested_evaluation.rs b/examples/nested_evaluation.rs new file mode 100644 index 0000000..970c783 --- /dev/null +++ b/examples/nested_evaluation.rs @@ -0,0 +1,47 @@ +use ahash::HashMap; +use symbolica::{atom::Atom, evaluate::ConstOrExpr, state::State}; + +fn main() { + let e = Atom::parse("x + cos(x) + f(g(x+1),h(x*2)) + p(1)").unwrap(); + let f = Atom::parse("y^2 + z^2").unwrap(); // f(y,z) = y^2+z^2 + let g = Atom::parse("i(y+7)").unwrap(); // g(y) = i(y+7) + let h = Atom::parse("y + 3").unwrap(); // h(y) = y+3 + let i = Atom::parse("y * 2").unwrap(); // i(y) = y*2 + let k = Atom::parse("x+8").unwrap(); // p(1) = x + 8 + + let mut const_map = HashMap::default(); + + let p1 = Atom::parse("p(1)").unwrap(); + let f_s = Atom::new_var(State::get_symbol("f")); + let g_s = Atom::new_var(State::get_symbol("g")); + let h_s = Atom::new_var(State::get_symbol("h")); + let i_s = Atom::new_var(State::get_symbol("i")); + + const_map.insert(p1.into(), ConstOrExpr::Expr(vec![], k.as_view())); + + const_map.insert( + f_s.into(), + ConstOrExpr::Expr( + vec![State::get_symbol("y"), State::get_symbol("z")], + f.as_view(), + ), + ); + const_map.insert( + g_s.into(), + ConstOrExpr::Expr(vec![State::get_symbol("y")], g.as_view()), + ); + const_map.insert( + h_s.into(), + ConstOrExpr::Expr(vec![State::get_symbol("y")], h.as_view()), + ); + const_map.insert( + i_s.into(), + ConstOrExpr::Expr(vec![State::get_symbol("y")], i.as_view()), + ); + + let params = vec![Atom::parse("x").unwrap()]; + + let mut evaluator = e.as_view().evaluator(|r| r.into(), &const_map, ¶ms); + + println!("{}", evaluator.evaluate(&[5.])); +} diff --git a/src/atom.rs b/src/atom.rs index 703935d..279c239 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -29,6 +29,16 @@ pub struct Symbol { is_linear: bool, } +impl std::fmt::Debug for Symbol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{}", self.id))?; + for _ in 0..self.wildcard_level { + f.write_str("_")?; + } + Ok(()) + } +} + impl Symbol { /// Create a new variable symbol. This constructor should be used with care as there are no checks /// about the validity of the identifier. @@ -81,16 +91,6 @@ impl Symbol { } } -impl std::fmt::Debug for Symbol { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!("{}", self.id))?; - for _ in 0..self.wildcard_level { - f.write_str("_")?; - } - Ok(()) - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum AtomType { Num, @@ -192,6 +192,77 @@ impl<'a> From> for AtomView<'a> { } } +/// A copy-on-write structure for `Atom` and `AtomView`. +pub enum AtomOrView<'a> { + Atom(Atom), + View(AtomView<'a>), +} + +impl<'a> PartialEq for AtomOrView<'a> { + #[inline] + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (AtomOrView::Atom(a), AtomOrView::Atom(b)) => a == b, + (AtomOrView::View(a), AtomOrView::View(b)) => a == b, + _ => self.as_view() == other.as_view(), + } + } +} + +impl Eq for AtomOrView<'_> {} + +impl Hash for AtomOrView<'_> { + #[inline] + fn hash(&self, state: &mut H) { + match self { + AtomOrView::Atom(a) => a.as_view().hash(state), + AtomOrView::View(a) => a.hash(state), + } + } +} + +impl<'a> From for AtomOrView<'a> { + fn from(a: Atom) -> AtomOrView<'a> { + AtomOrView::Atom(a) + } +} + +impl<'a> From> for AtomOrView<'a> { + fn from(a: AtomView<'a>) -> AtomOrView<'a> { + AtomOrView::View(a) + } +} + +impl<'a> From<&AtomView<'a>> for AtomOrView<'a> { + fn from(a: &AtomView<'a>) -> AtomOrView<'a> { + AtomOrView::View(*a) + } +} + +impl<'a> AtomOrView<'a> { + pub fn as_view(&'a self) -> AtomView<'a> { + match self { + AtomOrView::Atom(a) => a.as_view(), + AtomOrView::View(a) => *a, + } + } + + pub fn as_mut(&mut self) -> &mut Atom { + match self { + AtomOrView::Atom(a) => a, + AtomOrView::View(a) => { + let mut oa = Atom::default(); + oa.set_from_view(a); + *self = AtomOrView::Atom(oa); + match self { + AtomOrView::Atom(a) => a, + _ => unreachable!(), + } + } + } + } +} + /// A trait for any type that can be converted into an `AtomView`. /// To be used for functions that accept any argument that can be /// converted to an `AtomView`. diff --git a/src/evaluate.rs b/src/evaluate.rs index 65947cb..4f90abc 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -1,7 +1,7 @@ use ahash::HashMap; use crate::{ - atom::{Atom, AtomView, Symbol}, + atom::{representation::InlineVar, Atom, AtomOrView, AtomView, Symbol}, coefficient::CoefficientView, domains::{float::Real, rational::Rational}, state::State, @@ -29,6 +29,11 @@ impl EvaluationFn { } } +pub enum ConstOrExpr<'a, T> { + Const(T), + Expr(Vec, AtomView<'a>), +} + impl Atom { /// Evaluate an expression using a constant map and a function map. /// The constant map can map any literal expression to a value, for example @@ -47,7 +52,401 @@ impl Atom { } } +#[derive(Debug)] +pub enum EvalTree { + Const(T), + Parameter(usize), + Eval(Vec, Vec>, Box>), // first argument is a buffer for the evaluated arguments + Add(Vec>), + Mul(Vec>), + Pow(Box<(EvalTree, i64)>), + Powf(Box<(EvalTree, EvalTree)>), + ReadArg(usize), + BuiltinFun(Symbol, Box>), +} + +pub struct ExpressionEvaluator { + stack: Vec, + instructions: Vec, + result_index: usize, +} + +impl ExpressionEvaluator { + pub fn evaluate(&mut self, params: &[T]) -> T { + for (t, p) in self.stack.iter_mut().zip(params) { + *t = p.clone(); + } + + for i in &self.instructions { + match i { + Instr::Add(r, v) => { + self.stack[*r] = self.stack[v[0]].clone(); + for x in &v[1..] { + let e = self.stack[*x].clone(); + self.stack[*r] += e; + } + } + Instr::Mul(r, v) => { + self.stack[*r] = self.stack[v[0]].clone(); + for x in &v[1..] { + let e = self.stack[*x].clone(); + self.stack[*r] *= e; + } + } + Instr::Pow(r, b, e) => { + if *e >= 0 { + self.stack[*r] = self.stack[*b].pow(*e as u64); + } else { + self.stack[*r] = self.stack[*b].pow(e.unsigned_abs()).inv(); + } + } + Instr::Powf(r, b, e) => { + self.stack[*r] = self.stack[*b].powf(&self.stack[*e]); + } + Instr::BuiltinFun(r, s, arg) => match *s { + State::EXP => self.stack[*r] = self.stack[*arg].exp(), + State::LOG => self.stack[*r] = self.stack[*arg].log(), + State::SIN => self.stack[*r] = self.stack[*arg].sin(), + State::COS => self.stack[*r] = self.stack[*arg].cos(), + State::SQRT => self.stack[*r] = self.stack[*arg].sqrt(), + _ => unreachable!(), + }, + Instr::Copy(d, s) => { + for (o, i) in s.iter().enumerate() { + self.stack[*d + o] = self.stack[*i].clone(); + } + } + } + } + + self.stack[self.result_index].clone() + } +} + +enum Instr { + Add(usize, Vec), + Mul(usize, Vec), + Pow(usize, usize, i64), + Powf(usize, usize, usize), + BuiltinFun(usize, Symbol, usize), + Copy(usize, Vec), // copy arguments into an adjacent array +} + +impl EvalTree { + /// Create a linear version of the tree that can be evaluated more efficiently. + pub fn linearize(&self, param_len: usize) -> ExpressionEvaluator { + let mut stack = vec![T::default(); param_len]; + let mut instructions = vec![]; + let result_index = self.linearize_impl(&mut stack, &mut instructions, 0); + ExpressionEvaluator { + stack, + instructions, + result_index, + } + } + + // Yields the stack index that contains the output. + fn linearize_impl( + &self, + stack: &mut Vec, + instr: &mut Vec, + arg_start: usize, + ) -> usize { + match self { + EvalTree::Const(t) => { + stack.push(t.clone()); // TODO: do once and recycle + stack.len() - 1 + } + EvalTree::Parameter(i) => *i, + EvalTree::Eval(_, args, f) => { + let dest_pos = stack.len(); + for _ in args { + stack.push(T::default()); + } + + let a: Vec<_> = args + .iter() + .map(|x| x.linearize_impl(stack, instr, arg_start)) + .collect(); + + instr.push(Instr::Copy(dest_pos, a)); + f.linearize_impl(stack, instr, dest_pos) + } + EvalTree::Add(a) => { + stack.push(T::default()); + let res = stack.len() - 1; + + let add = Instr::Add( + res, + a.iter() + .map(|x| x.linearize_impl(stack, instr, arg_start)) + .collect(), + ); + instr.push(add); + + res + } + EvalTree::Mul(m) => { + stack.push(T::default()); + let res = stack.len() - 1; + + let mul = Instr::Mul( + res, + m.iter() + .map(|x| x.linearize_impl(stack, instr, arg_start)) + .collect(), + ); + instr.push(mul); + + res + } + EvalTree::Pow(p) => { + stack.push(T::default()); + let res = stack.len() - 1; + let b = p.0.linearize_impl(stack, instr, arg_start); + + instr.push(Instr::Pow(res, b, p.1)); + res + } + EvalTree::Powf(p) => { + stack.push(T::default()); + let res = stack.len() - 1; + let b = p.0.linearize_impl(stack, instr, arg_start); + let e = p.1.linearize_impl(stack, instr, arg_start); + + instr.push(Instr::Powf(res, b, e)); + res + } + EvalTree::ReadArg(a) => arg_start + *a, + EvalTree::BuiltinFun(s, v) => { + stack.push(T::default()); + let arg = v.linearize_impl(stack, instr, arg_start); + let c = Instr::BuiltinFun(stack.len() - 1, *s, arg); + instr.push(c); + stack.len() - 1 + } + } + } +} + +impl EvalTree { + /// Evaluate the evaluation tree. Consider converting to a linear form for repeated evaluation. + pub fn evaluate(&mut self, params: &[T]) -> T { + self.evaluate_impl(params, &[]) + } + + fn evaluate_impl(&mut self, params: &[T], args: &[T]) -> T { + match self { + EvalTree::Const(c) => c.clone(), + EvalTree::Parameter(p) => params[*p].clone(), + EvalTree::Eval(arg_buf, e_args, f) => { + for (b, a) in arg_buf.iter_mut().zip(e_args.iter_mut()) { + *b = a.evaluate_impl(params, args); + } + + f.evaluate_impl(params, &arg_buf) + } + EvalTree::Add(a) => { + let mut r = a[0].evaluate_impl(params, args); + for arg in &mut a[1..] { + r += arg.evaluate_impl(params, args); + } + r + } + EvalTree::Mul(m) => { + let mut r = m[0].evaluate_impl(params, args); + for arg in &mut m[1..] { + r *= arg.evaluate_impl(params, args); + } + r + } + EvalTree::Pow(p) => { + let (b, e) = &mut **p; + let b_eval = b.evaluate_impl(params, args); + + if *e >= 0 { + b_eval.pow(*e as u64) + } else { + b_eval.pow(e.unsigned_abs()).inv() + } + } + EvalTree::Powf(p) => { + let (b, e) = &mut **p; + let b_eval = b.evaluate_impl(params, args); + let e_eval = e.evaluate_impl(params, args); + b_eval.powf(&e_eval) + } + EvalTree::ReadArg(i) => args[*i].clone(), + EvalTree::BuiltinFun(s, a) => { + let arg = a.evaluate_impl(params, args); + match *s { + State::EXP => arg.exp(), + State::LOG => arg.log(), + State::SIN => arg.sin(), + State::COS => arg.cos(), + State::SQRT => arg.sqrt(), + _ => unreachable!(), + } + } + } + } +} + impl<'a> AtomView<'a> { + /// Convert nested expressions to a from suitable for evaluation. + pub fn evaluator T + Copy>( + &self, + coeff_map: F, + const_map: &HashMap>, + params: &[Atom], + ) -> ExpressionEvaluator { + let tree = self.to_eval_tree(coeff_map, const_map, params); + tree.linearize(params.len()) + } + + /// Convert nested expressions to a tree. + pub fn to_eval_tree T + Copy>( + &self, + coeff_map: F, + const_map: &HashMap>, + params: &[Atom], + ) -> EvalTree { + self.to_eval_tree_impl(coeff_map, const_map, params, &[]) + } + + fn to_eval_tree_impl T + Copy>( + &self, + coeff_map: F, + const_map: &HashMap>, + params: &[Atom], + args: &[Symbol], + ) -> EvalTree { + if let Some(p) = params.iter().position(|a| a.as_view() == *self) { + return EvalTree::Parameter(p); + } + + if let Some(c) = const_map.get(&self.into()) { + return match c { + ConstOrExpr::Const(c) => EvalTree::Const(c.clone()), + ConstOrExpr::Expr(args, v) => { + if !args.is_empty() { + panic!( + "Function {} called with wrong number of arguments: 0 vs {}", + self, + args.len() + ); + } + + let r = v.to_eval_tree_impl(coeff_map, const_map, params, args); + EvalTree::Eval(vec![], vec![], Box::new(r)) + } + }; + } + + match self { + AtomView::Num(n) => match n.get_coeff_view() { + CoefficientView::Natural(n, d) => { + EvalTree::Const(coeff_map(&Rational::Natural(n, d))) + } + CoefficientView::Large(l) => { + EvalTree::Const(coeff_map(&Rational::Large(l.to_rat()))) + } + CoefficientView::Float(f) => { + // TODO: converting back to rational is slow + EvalTree::Const(coeff_map(&f.to_float().to_rational())) + } + CoefficientView::FiniteField(_, _) => { + unimplemented!("Finite field not yet supported for evaluation") + } + CoefficientView::RationalPolynomial(_) => { + unimplemented!( + "Rational polynomial coefficient not yet supported for evaluation" + ) + } + }, + AtomView::Var(v) => { + let name = v.get_symbol(); + + if let Some(p) = args.iter().position(|s| *s == name) { + return EvalTree::ReadArg(p); + } + + panic!( + "Variable {} not in constant map", + State::get_name(v.get_symbol()) + ); + } + AtomView::Fun(f) => { + let name = f.get_symbol(); + if [State::EXP, State::LOG, State::SIN, State::COS, State::SQRT].contains(&name) { + assert!(f.get_nargs() == 1); + let arg = f.iter().next().unwrap(); + let arg_eval = arg.to_eval_tree_impl(coeff_map, const_map, params, args); + + return EvalTree::BuiltinFun(f.get_symbol(), Box::new(arg_eval)); + } + + let symb = InlineVar::new(f.get_symbol()); + let Some(fun) = const_map.get(&symb.as_view().into()) else { + panic!("Undefined function {}", State::get_name(f.get_symbol())); + }; + + match fun { + ConstOrExpr::Const(t) => EvalTree::Const(t.clone()), + ConstOrExpr::Expr(arg_spec, e) => { + if f.get_nargs() != arg_spec.len() { + panic!( + "Function {} called with wrong number of arguments: {} vs {}", + f.get_symbol(), + f.get_nargs(), + arg_spec.len() + ); + } + + let eval_args = f + .iter() + .map(|arg| arg.to_eval_tree_impl(coeff_map, const_map, params, args)) + .collect(); + let res = e.to_eval_tree_impl(coeff_map, const_map, params, arg_spec); + + EvalTree::Eval(vec![T::default(); arg_spec.len()], eval_args, Box::new(res)) + } + } + } + AtomView::Pow(p) => { + let (b, e) = p.get_base_exp(); + let b_eval = b.to_eval_tree_impl(coeff_map, const_map, params, args); + + if let AtomView::Num(n) = e { + if let CoefficientView::Natural(num, den) = n.get_coeff_view() { + if den == 1 { + return EvalTree::Pow(Box::new((b_eval, num))); + } + } + } + + let e_eval = e.to_eval_tree_impl(coeff_map, const_map, params, args); + EvalTree::Powf(Box::new((b_eval, e_eval))) + } + AtomView::Mul(m) => { + let mut muls = vec![]; + for arg in m.iter() { + muls.push(arg.to_eval_tree_impl(coeff_map, const_map, params, args)); + } + + EvalTree::Mul(muls) + } + AtomView::Add(a) => { + let mut adds = vec![]; + for arg in a.iter() { + adds.push(arg.to_eval_tree_impl(coeff_map, const_map, params, args)); + } + + EvalTree::Add(adds) + } + } + } + /// Evaluate an expression using a constant map and a function map. /// The constant map can map any literal expression to a value, for example /// a variable or a function with fixed arguments.