Skip to content

Commit

Permalink
Improve CSE quality
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Jul 26, 2024
1 parent 743806d commit 78605e6
Showing 1 changed file with 118 additions and 18 deletions.
136 changes: 118 additions & 18 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1190,7 +1190,7 @@ impl<T: Clone + Default + PartialEq> EvalTree<T> {
if p.1 > 1 {
instr.push(Instr::Mul(res, vec![b; p.1 as usize]));
} else {
instr.push(Instr::Pow(res, b, p.1));
instr.push(Instr::Pow(res, b, p.1));
}
res
}
Expand Down Expand Up @@ -1521,48 +1521,144 @@ impl<T: Clone + Default + std::fmt::Debug + Eq + std::hash::Hash + Ord> SplitExp
}

for t in &mut self.tree {
t.replace_subexpression(&h);
t.replace_subexpression(&h, false);
}

let mut v: Vec<_> = h.into_iter().map(|(k, v)| (v, k)).collect();
let mut v: Vec<_> = h.clone().into_iter().map(|(k, v)| (v, k)).collect();

v.sort();

for (_, x) in v {
// replace subexpressions in subexpressions and
// sort them based on their dependencies
for (_, mut x) in v {
x.replace_subexpression(&h, true);
self.subexpressions.push(x);
}

let mut dep_tree = vec![];
for (i, s) in self.subexpressions.iter().enumerate() {
let mut deps = vec![];
s.get_dependent_subexpressions(&mut deps);
dep_tree.push((i, deps.clone()));
}

let mut rename = HashMap::default();
let mut new_subs = vec![];
let mut i = 0;
while !dep_tree.is_empty() {
if dep_tree[i].1.iter().all(|x| rename.contains_key(x)) {
rename.insert(dep_tree[i].0, new_subs.len());
new_subs.push(self.subexpressions[dep_tree[i].0].clone());
dep_tree.swap_remove(i);
if i == dep_tree.len() {
i = 0;
}
} else {
i = (i + 1) % dep_tree.len();
}
}

for x in &mut new_subs {
x.rename_subexpression(&rename);
}
for t in &mut self.tree {
t.rename_subexpression(&rename);
}

self.subexpressions = new_subs;
}
}

impl<T: Clone + Default + std::fmt::Debug + Eq + std::hash::Hash + Ord> Expression<T> {
fn replace_subexpression(&mut self, subexp: &HashMap<Expression<T>, usize>) {
if let Some(i) = subexp.get(&self) {
*self = Expression::SubExpression(*i);
return;
fn rename_subexpression(&mut self, subexp: &HashMap<usize, usize>) {
match self {
Expression::Const(_) | Expression::Parameter(_) | Expression::ReadArg(_) => {}
Expression::Eval(_, ae) => {
for arg in &mut *ae {
arg.rename_subexpression(subexp);
}
}
Expression::Add(a) | Expression::Mul(a) => {
for arg in a {
arg.rename_subexpression(subexp);
}
}
Expression::Pow(p) => {
p.0.rename_subexpression(subexp);
}
Expression::Powf(p) => {
p.0.rename_subexpression(subexp);
p.1.rename_subexpression(subexp);
}
Expression::BuiltinFun(_, a) => {
a.rename_subexpression(subexp);
}
Expression::SubExpression(i) => {
*self = Expression::SubExpression(*subexp.get(i).unwrap());
}
}
}

fn get_dependent_subexpressions(&self, dep: &mut Vec<usize>) {
match self {
Expression::Const(_) | Expression::Parameter(_) | Expression::ReadArg(_) => {}
Expression::Eval(_, ae) => {
for arg in &mut *ae {
arg.replace_subexpression(subexp);
for arg in ae {
arg.get_dependent_subexpressions(dep);
}
}
Expression::Add(a) | Expression::Mul(a) => {
for arg in a {
arg.replace_subexpression(subexp);
arg.get_dependent_subexpressions(dep);
}
}
Expression::Pow(p) => {
p.0.replace_subexpression(subexp);
p.0.get_dependent_subexpressions(dep);
}
Expression::Powf(p) => {
p.0.replace_subexpression(subexp);
p.1.replace_subexpression(subexp);
p.0.get_dependent_subexpressions(dep);
p.1.get_dependent_subexpressions(dep);
}
Expression::BuiltinFun(_, _) => {}
Expression::SubExpression(_) => {
unimplemented!("The expression should not already have subexpressions")
Expression::BuiltinFun(_, a) => {
a.get_dependent_subexpressions(dep);
}
Expression::SubExpression(i) => {
dep.push(*i);
}
}
}

fn replace_subexpression(&mut self, subexp: &HashMap<Expression<T>, usize>, skip_root: bool) {
if !skip_root {
if let Some(i) = subexp.get(&self) {
*self = Expression::SubExpression(*i);
return;
}
}

match self {
Expression::Const(_) | Expression::Parameter(_) | Expression::ReadArg(_) => {}
Expression::Eval(_, ae) => {
for arg in &mut *ae {
arg.replace_subexpression(subexp, false);
}
}
Expression::Add(a) | Expression::Mul(a) => {
for arg in &mut *a {
arg.replace_subexpression(subexp, false);
}

a.sort();
}
Expression::Pow(p) => {
p.0.replace_subexpression(subexp, false);
}
Expression::Powf(p) => {
p.0.replace_subexpression(subexp, false);
p.1.replace_subexpression(subexp, false);
}
Expression::BuiltinFun(_, _) => {}
Expression::SubExpression(_) => {}
}
}

Expand Down Expand Up @@ -2440,7 +2536,7 @@ impl<'a> AtomView<'a> {
})
}

fn to_eval_tree_impl<T: Clone + Default, F: Fn(&Rational) -> T + Copy>(
fn to_eval_tree_impl<T: Clone + Default + Ord, F: Fn(&Rational) -> T + Copy>(
&self,
coeff_map: F,
fn_map: &FunctionMap<'a, T>,
Expand Down Expand Up @@ -2594,6 +2690,8 @@ impl<'a> AtomView<'a> {
}
}

muls.sort();

Ok(Expression::Mul(muls))
}
AtomView::Add(a) => {
Expand All @@ -2602,6 +2700,8 @@ impl<'a> AtomView<'a> {
adds.push(arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs)?);
}

adds.sort();

Ok(Expression::Add(adds))
}
}
Expand Down

0 comments on commit 78605e6

Please sign in to comment.