Skip to content

Commit

Permalink
Store parameter count in EvalTree
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Jul 27, 2024
1 parent c0e693c commit c22578f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
2 changes: 1 addition & 1 deletion examples/nested_evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ fn main() {
println!("C++ time {:#?}", t.elapsed());

let t2 = tree.map_coeff::<f64, _>(&|r| r.into());
let mut evaluator: ExpressionEvaluator<f64> = t2.linearize(params.len());
let mut evaluator: ExpressionEvaluator<f64> = t2.linearize();

evaluator.evaluate_multiple(&params, &mut out);
println!("Eval: {}, {}", out[0], out[1]);
Expand Down
29 changes: 17 additions & 12 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ pub struct SplitExpression<T> {
pub struct EvalTree<T> {
functions: Vec<(String, Vec<Symbol>, SplitExpression<T>)>,
expressions: SplitExpression<T>,
param_count: usize,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
Expand Down Expand Up @@ -1383,7 +1384,7 @@ enum Instr {
BuiltinFun(usize, Symbol, usize),
}

impl<T: Clone + Default + PartialEq> SplitExpression<T> {
impl<T: Clone + PartialEq> SplitExpression<T> {
pub fn map_coeff<T2, F: Fn(&T) -> T2>(&self, f: &F) -> SplitExpression<T2> {
SplitExpression {
tree: self.tree.iter().map(|x| x.map_coeff(f)).collect(),
Expand Down Expand Up @@ -1428,7 +1429,7 @@ impl<T: Clone + Default + PartialEq> SplitExpression<T> {
}
}

impl<T: Clone + Default + PartialEq> Expression<T> {
impl<T: Clone + PartialEq> Expression<T> {
pub fn map_coeff<T2, F: Fn(&T) -> T2>(&self, f: &F) -> Expression<T2> {
match self {
Expression::Const(c) => Expression::Const(f(c)),
Expand Down Expand Up @@ -1495,7 +1496,7 @@ impl<T: Clone + Default + PartialEq> Expression<T> {
}
}

impl<T: Clone + Default + PartialEq> EvalTree<T> {
impl<T: Clone + PartialEq> EvalTree<T> {
pub fn map_coeff<T2, F: Fn(&T) -> T2>(&self, f: &F) -> EvalTree<T2> {
EvalTree {
expressions: SplitExpression {
Expand All @@ -1517,6 +1518,7 @@ impl<T: Clone + Default + PartialEq> EvalTree<T> {
.iter()
.map(|(s, a, e)| (s.clone(), a.clone(), e.map_coeff(f)))
.collect(),
param_count: self.param_count,
}
}

Expand All @@ -1527,13 +1529,15 @@ impl<T: Clone + Default + PartialEq> EvalTree<T> {

self.expressions.unnest(max_depth);
}
}

impl<T: Clone + Default + PartialEq> EvalTree<T> {
/// Create a linear version of the tree that can be evaluated more efficiently.
pub fn linearize(mut self, param_count: usize) -> ExpressionEvaluator<T> {
let mut stack = vec![T::default(); param_count];
pub fn linearize(mut self) -> ExpressionEvaluator<T> {
let mut stack = vec![T::default(); self.param_count];

// strip every constant and move them into the stack after the params
self.strip_constants(&mut stack, param_count);
self.strip_constants(&mut stack);
let reserved_indices = stack.len();

let mut sub_expr_pos = HashMap::default();
Expand All @@ -1555,7 +1559,7 @@ impl<T: Clone + Default + PartialEq> EvalTree<T> {

let mut e = ExpressionEvaluator {
stack,
param_count,
param_count: self.param_count,
reserved_indices,
instructions,
result_indices,
Expand All @@ -1565,22 +1569,22 @@ impl<T: Clone + Default + PartialEq> EvalTree<T> {
e
}

fn strip_constants(&mut self, stack: &mut Vec<T>, param_len: usize) {
fn strip_constants(&mut self, stack: &mut Vec<T>) {
for t in &mut self.expressions.tree {
t.strip_constants(stack, param_len);
t.strip_constants(stack, self.param_count);
}

for e in &mut self.expressions.subexpressions {
e.strip_constants(stack, param_len);
e.strip_constants(stack, self.param_count);
}

for (_, _, e) in &mut self.functions {
for t in &mut e.tree {
t.strip_constants(stack, param_len);
t.strip_constants(stack, self.param_count);
}

for e in &mut e.subexpressions {
e.strip_constants(stack, param_len);
e.strip_constants(stack, self.param_count);
}
}
}
Expand Down Expand Up @@ -3021,6 +3025,7 @@ impl<'a> AtomView<'a> {
subexpressions: vec![],
},
functions: funcs,
param_count: params.len(),
})
}

Expand Down

0 comments on commit c22578f

Please sign in to comment.