Skip to content

Commit

Permalink
Horner scheme improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Jul 2, 2024
1 parent dc0d807 commit 654eb66
Showing 1 changed file with 86 additions and 23 deletions.
109 changes: 86 additions & 23 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,34 +422,79 @@ impl EvalTree<Rational> {
return;
};

// TODO: find power to extract, now we do just one
let mut max_pow: Option<i64> = None;
for x in &*a {
if let EvalTree::Mul(m) = x {
let mut pow_counter = 0;
for y in m {
if let EvalTree::Pow(p) = y {
if p.0 == scheme[0] {
pow_counter += p.1;
}
} else if y == &scheme[0] {
pow_counter += 1; // support x*x*x^3 in term
}
}

if pow_counter > 0 && (max_pow.is_none() || pow_counter < max_pow.unwrap()) {
max_pow = Some(pow_counter);
}
} else if x == &scheme[0] {
max_pow = Some(1);
}
}

// TODO: jump to next variable if the current variable only appears in one factor?
// this will improve the scheme but may hide common subexpressions?

let Some(max_pow) = max_pow else {
return self.apply_horner_scheme(&scheme[1..]);
};

let mut contains = vec![];
let mut rest = vec![];

for x in a {
let mut found = false;
if let EvalTree::Mul(m) = x {
for (p, y) in m.iter_mut().enumerate() {
let mut pow_counter = 0;

m.retain(|y| {
if let EvalTree::Pow(p) = y {
if p.0 == scheme[0] {
found = true;
if p.1 == 2 {
*y = p.0.clone(); // TODO: prevent clone
pow_counter += p.1;
false
} else {
p.1 -= 1;
}
true
}
} else if y == &scheme[0] {
found = true;
// remove from prod
m.remove(p);
if m.len() == 1 {
*x = m[0].clone();
}
break;
pow_counter += 1;
false
} else {
true
}
});

if pow_counter > max_pow {
if pow_counter > max_pow + 1 {
m.push(EvalTree::Pow(Box::new((
scheme[0].clone(),
pow_counter - max_pow,
))));
} else {
m.push(scheme[0].clone());
}

m.sort();
}

if m.is_empty() {
*x = EvalTree::Const(Rational::one());
} else if m.len() == 1 {
*x = m.pop().unwrap();
}

found = pow_counter > 0;
} else if x == &scheme[0] {
found = true;
*x = EvalTree::Const(Rational::one());
Expand All @@ -462,21 +507,39 @@ impl EvalTree<Rational> {
}
}

if contains.is_empty() {
*self = EvalTree::Add(rest);
self.apply_horner_scheme(&scheme[1..]);
let extracted = if max_pow == 1 {
scheme[0].clone()
} else {
EvalTree::Pow(Box::new((scheme[0].clone(), max_pow)))
};

let mut contains = if contains.len() == 1 {
contains.pop().unwrap()
} else {
let mut c = EvalTree::Mul(vec![EvalTree::Add(contains), scheme[0].clone()]);
c.apply_horner_scheme(&scheme[1..]);
EvalTree::Add(contains)
};

contains.apply_horner_scheme(&scheme); // keep trying with same variable

let mut v = vec![contains, extracted];
v.sort();
let c = EvalTree::Mul(v);

if rest.is_empty() {
*self = c;
} else {
let mut r = EvalTree::Add(rest);
let mut r = if rest.len() == 1 {
rest.pop().unwrap()
} else {
EvalTree::Add(rest)
};

r.apply_horner_scheme(&scheme[1..]);

*self = EvalTree::Add(vec![c, r]);
}
let mut v = vec![c, r];
v.sort();

*self = EvalTree::Add(v);
}
}

Expand Down Expand Up @@ -532,7 +595,7 @@ impl EvalTree<Rational> {

occurrence.retain(|_, v| *v > 1);
let mut order: Vec<_> = occurrence.into_iter().collect();
order.sort_by_key(|k| k.1);
order.sort_by_key(|k| std::cmp::Reverse(k.1)); // occurrence order
let scheme = order.into_iter().map(|(k, _)| k).collect::<Vec<_>>();

self.apply_horner_scheme(&scheme);
Expand Down

0 comments on commit 654eb66

Please sign in to comment.