From cbc87b4b8046310d8f4559925332949bcf5a2815 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Sat, 31 Aug 2024 09:42:53 +0200 Subject: [PATCH] Add method to merge evaluators --- src/evaluate.rs | 169 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) diff --git a/src/evaluate.rs b/src/evaluate.rs index 2977320..e3a4d49 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -976,6 +976,175 @@ impl ExpressionEvaluator { } } +impl ExpressionEvaluator { + /// Merge evaluator `other` into `self`. The parameters must be the same. + pub fn merge(&mut self, mut other: Self, cpe_rounds: Option) -> Result<(), String> { + if self.param_count != other.param_count { + return Err("Parameter count is different".to_owned()); + } + + let mut constants = HashMap::default(); + + for (i, c) in self.stack[self.param_count..self.reserved_indices] + .iter() + .enumerate() + { + constants.insert(c.clone(), i); + } + + let old_len = self.stack.len() - self.reserved_indices; + + self.stack.truncate(self.reserved_indices); + + for c in &other.stack[self.param_count..other.reserved_indices] { + if constants.get(c).is_none() { + let i = constants.len(); + constants.insert(c.clone(), i); + self.stack.push(c.clone()); + } + } + + let new_reserved_indices = self.stack.len(); + let mut delta = new_reserved_indices - self.reserved_indices; + + // shift stack indices + if delta > 0 { + for i in &mut self.instructions { + match i { + Instr::Add(r, a) | Instr::Mul(r, a) => { + *r += delta; + for aa in a { + if *aa >= self.reserved_indices { + *aa += delta; + } + } + } + Instr::Pow(r, b, _) | Instr::BuiltinFun(r, _, b) => { + *r += delta; + if *b >= self.reserved_indices { + *b += delta; + } + } + Instr::Powf(r, b, e) => { + *r += delta; + if *b >= self.reserved_indices { + *b += delta; + } + if *e >= self.reserved_indices { + *e += delta; + } + } + } + } + + for x in &mut self.result_indices { + *x += delta; + } + } + + delta = old_len + new_reserved_indices - other.reserved_indices; + for i in &mut other.instructions { + match i { + Instr::Add(r, a) | Instr::Mul(r, a) => { + *r += delta; + for aa in a { + if *aa >= other.reserved_indices { + *aa += delta; + } else if *aa >= other.param_count { + *aa = self.param_count + constants[&other.stack[*aa]]; + } + } + } + Instr::Pow(r, b, _) | Instr::BuiltinFun(r, _, b) => { + *r += delta; + if *b >= other.reserved_indices { + *b += delta; + } else if *b >= other.param_count { + *b = self.param_count + constants[&other.stack[*b]]; + } + } + Instr::Powf(r, b, e) => { + *r += delta; + if *b >= other.reserved_indices { + *b += delta; + } else if *b >= other.param_count { + *b = self.param_count + constants[&other.stack[*b]]; + } + if *e >= other.reserved_indices { + *e += delta; + } else if *e >= other.param_count { + *e = self.param_count + constants[&other.stack[*e]]; + } + } + } + } + + for x in &mut other.result_indices { + if *x >= other.reserved_indices { + *x += delta; + } else if *x >= other.param_count { + *x = self.param_count + constants[&other.stack[*x]]; + } + } + + self.instructions.extend(other.instructions.drain(..)); + self.result_indices.extend(other.result_indices.drain(..)); + self.reserved_indices = new_reserved_indices; + + // undo the stack optimization + let mut unfold = HashMap::default(); + for (index, i) in &mut self.instructions.iter_mut().enumerate() { + match i { + Instr::Add(r, a) | Instr::Mul(r, a) => { + for aa in a { + if *aa >= self.reserved_indices { + *aa = unfold[aa]; + } + } + + unfold.insert(*r, index + self.reserved_indices); + *r = index + self.reserved_indices; + } + Instr::Pow(r, b, _) | Instr::BuiltinFun(r, _, b) => { + if *b >= self.reserved_indices { + *b = unfold[b]; + } + unfold.insert(*r, index + self.reserved_indices); + *r = index + self.reserved_indices; + } + Instr::Powf(r, b, e) => { + if *b >= self.reserved_indices { + *b = unfold[b]; + } + if *e >= self.reserved_indices { + *e = unfold[e]; + } + unfold.insert(*r, index + self.reserved_indices); + *r = index + self.reserved_indices; + } + } + } + + for i in &mut self.result_indices { + *i = unfold[i]; + } + + for _ in 0..self.instructions.len() { + self.stack.push(T::default()); + } + + for _ in 0..cpe_rounds.unwrap_or(usize::MAX) { + if self.remove_common_pairs() == 0 { + break; + } + } + + self.optimize_stack(); + + Ok(()) + } +} + impl ExpressionEvaluator { pub fn optimize_stack(&mut self) { let mut last_use: Vec = vec![0; self.stack.len()];