diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 3b86ded4a87..96ac1e8a6ac 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -19,7 +19,7 @@ //! //! This is the only pass which removes duplicated pure [`Instruction`]s however and so is needed when //! different blocks are merged, i.e. after the [`flatten_cfg`][super::flatten_cfg] pass. -use std::collections::HashSet; +use std::collections::{HashSet, VecDeque}; use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; @@ -28,6 +28,7 @@ use crate::ssa::{ ir::{ basic_block::BasicBlockId, dfg::{DataFlowGraph, InsertInstructionResult}, + dom::DominatorTree, function::Function, instruction::{Instruction, InstructionId}, types::Type, @@ -67,10 +68,10 @@ impl Function { /// The structure of this pass is simple: /// Go through each block and re-insert all instructions. pub(crate) fn constant_fold(&mut self, use_constraint_info: bool) { - let mut context = Context { use_constraint_info, ..Default::default() }; - context.block_queue.push(self.entry_block()); + let mut context = Context::new(self, use_constraint_info); + context.block_queue.push_back(self.entry_block()); - while let Some(block) = context.block_queue.pop() { + while let Some(block) = context.block_queue.pop_front() { if context.visited_blocks.contains(&block) { continue; } @@ -81,34 +82,57 @@ impl Function { } } -#[derive(Default)] struct Context { use_constraint_info: bool, /// Maps pre-folded ValueIds to the new ValueIds obtained by re-inserting the instruction. visited_blocks: HashSet, - block_queue: Vec, + block_queue: VecDeque, + + /// Contains sets of values which are constrained to be equivalent to each other. + /// + /// The mapping's structure is `side_effects_enabled_var => (constrained_value => simplified_value)`. + /// + /// We partition the maps of constrained values according to the side-effects flag at the point + /// at which the values are constrained. This prevents constraints which are only sometimes enforced + /// being used to modify the rest of the program. + constraint_simplification_mappings: HashMap>, + + // Cache of instructions without any side-effects along with their outputs. + cached_instruction_results: InstructionResultCache, + + dom: DominatorTree, } /// HashMap from (Instruction, side_effects_enabled_var) to the results of the instruction. /// Stored as a two-level map to avoid cloning Instructions during the `.get` call. -type InstructionResultCache = HashMap, Vec>>; +/// +/// In addition to each result, the original BasicBlockId is stored as well. This allows us +/// to deduplicate instructions across blocks as long as the new block dominates the original. +type InstructionResultCache = HashMap, ResultCache>>; + +/// Records the results of all duplicate [`Instruction`]s along with the blocks in which they sit. +/// +/// For more information see [`InstructionResultCache`]. +#[derive(Default)] +struct ResultCache { + results: Vec<(BasicBlockId, Vec)>, +} impl Context { + fn new(function: &Function, use_constraint_info: bool) -> Self { + Self { + use_constraint_info, + visited_blocks: Default::default(), + block_queue: Default::default(), + constraint_simplification_mappings: Default::default(), + cached_instruction_results: Default::default(), + dom: DominatorTree::with_function(function), + } + } + fn fold_constants_in_block(&mut self, function: &mut Function, block: BasicBlockId) { let instructions = function.dfg[block].take_instructions(); - // Cache of instructions without any side-effects along with their outputs. - let mut cached_instruction_results = HashMap::default(); - - // Contains sets of values which are constrained to be equivalent to each other. - // - // The mapping's structure is `side_effects_enabled_var => (constrained_value => simplified_value)`. - // - // We partition the maps of constrained values according to the side-effects flag at the point - // at which the values are constrained. This prevents constraints which are only sometimes enforced - // being used to modify the rest of the program. - let mut constraint_simplification_mappings: HashMap> = - HashMap::default(); let mut side_effects_enabled_var = function.dfg.make_constant(FieldElement::one(), Type::bool()); @@ -117,8 +141,6 @@ impl Context { &mut function.dfg, block, instruction_id, - &mut cached_instruction_results, - &mut constraint_simplification_mappings, &mut side_effects_enabled_var, ); } @@ -126,22 +148,19 @@ impl Context { } fn fold_constants_into_instruction( - &self, + &mut self, dfg: &mut DataFlowGraph, block: BasicBlockId, id: InstructionId, - instruction_result_cache: &mut InstructionResultCache, - constraint_simplification_mappings: &mut HashMap>, side_effects_enabled_var: &mut ValueId, ) { - let constraint_simplification_mapping = - constraint_simplification_mappings.entry(*side_effects_enabled_var).or_default(); + let constraint_simplification_mapping = self.get_constraint_map(*side_effects_enabled_var); let instruction = Self::resolve_instruction(id, dfg, constraint_simplification_mapping); let old_results = dfg.instruction_results(id).to_vec(); // If a copy of this instruction exists earlier in the block, then reuse the previous results. if let Some(cached_results) = - Self::get_cached(dfg, instruction_result_cache, &instruction, *side_effects_enabled_var) + self.get_cached(dfg, &instruction, *side_effects_enabled_var, block) { Self::replace_result_ids(dfg, &old_results, cached_results); return; @@ -156,9 +175,8 @@ impl Context { instruction.clone(), new_results, dfg, - instruction_result_cache, - constraint_simplification_mapping, *side_effects_enabled_var, + block, ); // If we just inserted an `Instruction::EnableSideEffectsIf`, we need to update `side_effects_enabled_var` @@ -229,13 +247,12 @@ impl Context { } fn cache_instruction( - &self, + &mut self, instruction: Instruction, instruction_results: Vec, dfg: &DataFlowGraph, - instruction_result_cache: &mut InstructionResultCache, - constraint_simplification_mapping: &mut HashMap, side_effects_enabled_var: ValueId, + block: BasicBlockId, ) { if self.use_constraint_info { // If the instruction was a constraint, then create a link between the two `ValueId`s @@ -248,18 +265,18 @@ impl Context { // Prefer replacing with constants where possible. (Value::NumericConstant { .. }, _) => { - constraint_simplification_mapping.insert(rhs, lhs); + self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs); } (_, Value::NumericConstant { .. }) => { - constraint_simplification_mapping.insert(lhs, rhs); + self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs); } // Otherwise prefer block parameters over instruction results. // This is as block parameters are more likely to be a single witness rather than a full expression. (Value::Param { .. }, Value::Instruction { .. }) => { - constraint_simplification_mapping.insert(rhs, lhs); + self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs); } (Value::Instruction { .. }, Value::Param { .. }) => { - constraint_simplification_mapping.insert(lhs, rhs); + self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs); } (_, _) => (), } @@ -273,13 +290,22 @@ impl Context { self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg); let predicate = use_predicate.then_some(side_effects_enabled_var); - instruction_result_cache + self.cached_instruction_results .entry(instruction) .or_default() - .insert(predicate, instruction_results); + .entry(predicate) + .or_default() + .cache(block, instruction_results); } } + fn get_constraint_map( + &mut self, + side_effects_enabled_var: ValueId, + ) -> &mut HashMap { + self.constraint_simplification_mappings.entry(side_effects_enabled_var).or_default() + } + /// Replaces a set of [`ValueId`]s inside the [`DataFlowGraph`] with another. fn replace_result_ids( dfg: &mut DataFlowGraph, @@ -292,22 +318,40 @@ impl Context { } fn get_cached<'a>( + &'a mut self, dfg: &DataFlowGraph, - instruction_result_cache: &'a mut InstructionResultCache, instruction: &Instruction, side_effects_enabled_var: ValueId, - ) -> Option<&'a Vec> { - let results_for_instruction = instruction_result_cache.get(instruction); + block: BasicBlockId, + ) -> Option<&'a [ValueId]> { + let results_for_instruction = self.cached_instruction_results.get(instruction)?; - // See if there's a cached version with no predicate first - if let Some(results) = results_for_instruction.and_then(|map| map.get(&None)) { - return Some(results); - } + let predicate = self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg); + let predicate = predicate.then_some(side_effects_enabled_var); + + results_for_instruction.get(&predicate)?.get(block, &mut self.dom) + } +} - let predicate = - instruction.requires_acir_gen_predicate(dfg).then_some(side_effects_enabled_var); +impl ResultCache { + /// Records that an `Instruction` in block `block` produced the result values `results`. + fn cache(&mut self, block: BasicBlockId, results: Vec) { + self.results.push((block, results)); + } - results_for_instruction.and_then(|map| map.get(&predicate)) + /// Returns a set of [`ValueId`]s produced from a copy of this [`Instruction`] which sits + /// within a block which dominates `block`. + /// + /// We require that the cached instruction's block dominates `block` in order to avoid + /// cycles causing issues (e.g. two instructions being replaced with the results of each other + /// such that neither instruction exists anymore.) + fn get(&self, block: BasicBlockId, dom: &mut DominatorTree) -> Option<&[ValueId]> { + for (origin_block, results) in &self.results { + if dom.dominates(*origin_block, block) { + return Some(results); + } + } + None } } @@ -896,4 +940,48 @@ mod test { let ending_instruction_count = instructions.len(); assert_eq!(ending_instruction_count, 1); } + + #[test] + fn deduplicate_across_blocks() { + // fn main f0 { + // b0(v0: u1): + // v1 = not v0 + // jmp b1() + // b1(): + // v2 = not v0 + // return v2 + // } + let main_id = Id::test_new(0); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id); + let b1 = builder.insert_block(); + + let v0 = builder.add_parameter(Type::bool()); + let _v1 = builder.insert_not(v0); + builder.terminate_with_jmp(b1, Vec::new()); + + builder.switch_to_block(b1); + let v2 = builder.insert_not(v0); + builder.terminate_with_return(vec![v2]); + + let ssa = builder.finish(); + let main = ssa.main(); + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); + assert_eq!(main.dfg[b1].instructions().len(), 1); + + // Expected output: + // + // fn main f0 { + // b0(v0: u1): + // v1 = not v0 + // jmp b1() + // b1(): + // return v1 + // } + let ssa = ssa.fold_constants_using_constraints(); + let main = ssa.main(); + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); + assert_eq!(main.dfg[b1].instructions().len(), 0); + } }