diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index a669cfb55a8..a6b47cda070 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -160,6 +160,9 @@ impl Function { } } +/// `ValueId` in an `EnableConstraintsIf` instruction. +type Predicate = ValueId; + struct Context<'a> { use_constraint_info: bool, brillig_info: Option>, @@ -174,7 +177,7 @@ struct Context<'a> { /// We partition the maps of constrained values according to the side-effects flag at the point /// at which the values are constrained. This prevents constraints which are only sometimes enforced /// being used to modify the rest of the program. - constraint_simplification_mappings: HashMap>, + constraint_simplification_mappings: ConstraitSimplificationCache, // Cache of instructions without any side-effects along with their outputs. cached_instruction_results: InstructionResultCache, @@ -188,12 +191,39 @@ pub(crate) struct BrilligInfo<'a> { brillig_functions: &'a BTreeMap, } +struct SimplificationCache { + simplified: ValueId, + blocks: HashSet, +} + +impl SimplificationCache { + fn new(simplified: ValueId) -> Self { + Self { simplified, blocks: Default::default() } + } + fn merge(&mut self, dfg: &DataFlowGraph, other: ValueId, block: BasicBlockId) { + if self.simplified == other { + self.blocks.insert(block); + } else { + match simplify(dfg, self.simplified, other) { + Some((complex, simple)) if self.simplified == complex => { + self.simplified = simple; + self.blocks.clear(); + self.blocks.insert(block); + } + _ => {} + } + } + } +} + +type ConstraitSimplificationCache = HashMap>; + /// HashMap from (Instruction, side_effects_enabled_var) to the results of the instruction. /// Stored as a two-level map to avoid cloning Instructions during the `.get` call. /// /// In addition to each result, the original BasicBlockId is stored as well. This allows us /// to deduplicate instructions across blocks as long as the new block dominates the original. -type InstructionResultCache = HashMap, ResultCache>>; +type InstructionResultCache = HashMap, ResultCache>>; /// Records the results of all duplicate [`Instruction`]s along with the blocks in which they sit. /// @@ -308,7 +338,7 @@ impl<'brillig> Context<'brillig> { fn resolve_instruction( instruction_id: InstructionId, dfg: &DataFlowGraph, - constraint_simplification_mapping: &HashMap, + constraint_simplification_mapping: &HashMap, ) -> Instruction { let instruction = dfg[instruction_id].clone(); @@ -319,12 +349,12 @@ impl<'brillig> Context<'brillig> { // constraints to the cache. fn resolve_cache( dfg: &DataFlowGraph, - cache: &HashMap, + cache: &HashMap, value_id: ValueId, ) -> ValueId { let resolved_id = dfg.resolve(value_id); match cache.get(&resolved_id) { - Some(cached_value) => resolve_cache(dfg, cache, *cached_value), + Some(cached_value) => resolve_cache(dfg, cache, cached_value.simplified), None => resolved_id, } } @@ -378,7 +408,10 @@ impl<'brillig> Context<'brillig> { if let Instruction::Constrain(lhs, rhs, _) = instruction { // These `ValueId`s should be fully resolved now. if let Some((complex, simple)) = simplify(dfg, lhs, rhs) { - self.get_constraint_map(side_effects_enabled_var).insert(complex, simple); + self.get_constraint_map(side_effects_enabled_var) + .entry(complex) + .or_insert_with(|| SimplificationCache::new(simple)) + .merge(dfg, simple, block); } } } @@ -402,7 +435,7 @@ impl<'brillig> Context<'brillig> { fn get_constraint_map( &mut self, side_effects_enabled_var: ValueId, - ) -> &mut HashMap { + ) -> &mut HashMap { self.constraint_simplification_mappings.entry(side_effects_enabled_var).or_default() }