From c4e3ab38986db6eb359ac9e06368d4508db26011 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Fri, 29 Nov 2024 14:33:05 +0000 Subject: [PATCH 1/9] chore: apply sync fixes --- .aztec-sync-commit | 2 +- .../noirc_evaluator/src/ssa/ir/instruction.rs | 15 +- .../src/ssa/ir/instruction/call.rs | 34 +- .../src/ssa/ir/instruction/call/blackbox.rs | 4 +- .../src/ssa/opt/constant_folding.rs | 445 ++++---------- .../src/ssa/opt/flatten_cfg.rs | 556 ++++++++++-------- .../noirc_evaluator/src/ssa/opt/inlining.rs | 1 + .../noirc_evaluator/src/ssa/opt/mem2reg.rs | 332 +++-------- .../src/ssa/ssa_gen/context.rs | 11 +- .../noirc_evaluator/src/ssa/ssa_gen/mod.rs | 5 +- .../profiler/src/cli/gates_flamegraph_cmd.rs | 25 +- 11 files changed, 559 insertions(+), 871 deletions(-) diff --git a/.aztec-sync-commit b/.aztec-sync-commit index d97a936c081..477ebbca903 100644 --- a/.aztec-sync-commit +++ b/.aztec-sync-commit @@ -1 +1 @@ -1bfc15e08873a1f0f3743e259f418b70426b3f25 +0577c1a70e9746bd06f07d2813af1be39e01ca02 diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index 6737b335b7d..f606fffbf91 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -11,7 +11,7 @@ use fxhash::FxHasher64; use iter_extended::vecmap; use noirc_frontend::hir_def::types::Type as HirType; -use crate::ssa::{ir::function::RuntimeType, opt::flatten_cfg::value_merger::ValueMerger}; +use crate::ssa::opt::flatten_cfg::value_merger::ValueMerger; use super::{ basic_block::BasicBlockId, @@ -478,19 +478,8 @@ impl Instruction { | ArraySet { .. } | MakeArray { .. } => true, - // Store instructions must be removed by DIE in acir code, any load - // instructions should already be unused by that point. - // - // Note that this check assumes that it is being performed after the flattening - // pass and after the last mem2reg pass. This is currently the case for the DIE - // pass where this check is done, but does mean that we cannot perform mem2reg - // after the DIE pass. - Store { .. } => { - matches!(function.runtime(), RuntimeType::Acir(_)) - && function.reachable_blocks().len() == 1 - } - Constrain(..) + | Store { .. } | EnableSideEffectsIf { .. } | IncrementRc { .. } | DecrementRc { .. 
} diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index 4be37b3c626..67222d06ea8 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -45,17 +45,17 @@ pub(super) fn simplify_call( _ => return SimplifyResult::None, }; + let return_type = ctrl_typevars.and_then(|return_types| return_types.first().cloned()); + let constant_args: Option<Vec<FieldElement>> = arguments.iter().map(|value_id| dfg.get_numeric_constant(*value_id)).collect(); - match intrinsic { + let simplified_result = match intrinsic { Intrinsic::ToBits(endian) => { // TODO: simplify to a range constraint if `limb_count == 1` - if let (Some(constant_args), Some(return_type)) = - (constant_args, ctrl_typevars.map(|return_types| return_types.first().cloned())) - { + if let (Some(constant_args), Some(return_type)) = (constant_args, return_type.clone()) { let field = constant_args[0]; - let limb_count = if let Some(Type::Array(_, array_len)) = return_type { + let limb_count = if let Type::Array(_, array_len) = return_type { array_len as u32 } else { unreachable!("ICE: Intrinsic::ToRadix return type must be array") @@ -67,12 +67,10 @@ pub(super) fn simplify_call( } Intrinsic::ToRadix(endian) => { // TODO: simplify to a range constraint if `limb_count == 1` - if let (Some(constant_args), Some(return_type)) = - (constant_args, ctrl_typevars.map(|return_types| return_types.first().cloned())) - { + if let (Some(constant_args), Some(return_type)) = (constant_args, return_type.clone()) { let field = constant_args[0]; let radix = constant_args[1].to_u128() as u32; - let limb_count = if let Some(Type::Array(_, array_len)) = return_type { + let limb_count = if let Type::Array(_, array_len) = return_type { array_len as u32 } else { unreachable!("ICE: Intrinsic::ToRadix return type must be array") @@ -330,7 +328,7 @@ pub(super) fn simplify_call( } Intrinsic::FromField => { let incoming_type = Type::field(); - let target_type = ctrl_typevars.unwrap().remove(0); + let target_type = return_type.clone().unwrap(); let truncate = Instruction::Truncate { value: arguments[0], @@ -352,8 +350,8 @@ pub(super) fn simplify_call( Intrinsic::AsWitness => SimplifyResult::None, Intrinsic::IsUnconstrained => SimplifyResult::None, Intrinsic::DerivePedersenGenerators => { - if let Some(Type::Array(_, len)) = ctrl_typevars.unwrap().first() { - simplify_derive_generators(dfg, arguments, *len as u32, block, call_stack) + if let Some(Type::Array(_, len)) = return_type.clone() { + simplify_derive_generators(dfg, arguments, len as u32, block, call_stack) } else { unreachable!("Derive Pedersen Generators must return an array"); } @@ -370,7 +368,19 @@ pub(super) fn simplify_call( } Intrinsic::ArrayRefCount => SimplifyResult::None, Intrinsic::SliceRefCount => SimplifyResult::None, + }; + + if let (Some(expected_types), SimplifyResult::SimplifiedTo(result)) = + (return_type, &simplified_result) { + assert_eq!( dfg.type_of_value(*result), expected_types, "Simplification should not alter return type" ); } + + simplified_result } /// Slices have a tuple structure (slice length, slice contents) to enable logic diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs index 4f2a31e2fb0..301b75e0bd4 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
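The assertion added to `simplify_call` above guards a subtle invariant: a simplification may replace a call, but never with a value of a different type. A minimal standalone sketch of the same guard pattern (all names here are illustrative stand-ins, not the compiler's actual API):

#[derive(Debug, Clone, PartialEq)]
enum Ty { Field, Bool }

enum Simplified { None, SimplifiedTo(Ty) }

// Check a candidate simplification against the type the caller expects;
// panicking here surfaces type-changing rewrites at the point they are made.
fn checked(expected: &Ty, result: Simplified) -> Simplified {
    if let Simplified::SimplifiedTo(actual) = &result {
        assert_eq!(actual, expected, "Simplification should not alter return type");
    }
    result
}

fn main() {
    let kept = checked(&Ty::Field, Simplified::SimplifiedTo(Ty::Field));
    assert!(matches!(kept, Simplified::SimplifiedTo(Ty::Field)));
    let none = checked(&Ty::Bool, Simplified::None); // no replacement, nothing to check
    assert!(matches!(none, Simplified::None));
}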
@@ -48,7 +48,7 @@ pub(super) fn simplify_ec_add( let result_x = dfg.make_constant(result_x, Type::field()); let result_y = dfg.make_constant(result_y, Type::field()); - let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool()); + let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field()); let typ = Type::Array(Arc::new(vec![Type::field()]), 3); @@ -107,7 +107,7 @@ pub(super) fn simplify_msm( let result_x = dfg.make_constant(result_x, Type::field()); let result_y = dfg.make_constant(result_y, Type::field()); - let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool()); + let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field()); let elements = im::vector![result_x, result_y, result_is_infinity]; let typ = Type::Array(Arc::new(vec![Type::field()]), 3); diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 41c84c935b1..ceda0c6272f 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -149,8 +149,7 @@ impl Function { use_constraint_info: bool, brillig_info: Option<BrilligInfo>, ) { - let mut context = Context::new(use_constraint_info, brillig_info); - let mut dom = DominatorTree::with_function(self); + let mut context = Context::new(self, use_constraint_info, brillig_info); context.block_queue.push_back(self.entry_block()); while let Some(block) = context.block_queue.pop_front() { @@ -159,7 +158,7 @@ impl Function { } context.visited_blocks.insert(block); - context.fold_constants_in_block(&mut self.dfg, &mut dom, block); + context.fold_constants_in_block(self, block); } } } @@ -173,15 +172,22 @@ struct Context<'a> { /// Contains sets of values which are constrained to be equivalent to each other. /// - /// The mapping's structure is `side_effects_enabled_var => (constrained_value => simplified_value)`. + /// The mapping's structure is `side_effects_enabled_var => (constrained_value => [(block, simplified_value)])`. /// /// We partition the maps of constrained values according to the side-effects flag at the point /// at which the values are constrained. This prevents constraints which are only sometimes enforced /// being used to modify the rest of the program. - constraint_simplification_mappings: ConstraintSimplificationCache, + /// + /// We also keep track of how a value was simplified to other values per block. That is, + /// the same ValueId could have been simplified to one value in one block and to another value + /// in another block. + constraint_simplification_mappings: HashMap<ValueId, HashMap<ValueId, Vec<(BasicBlockId, ValueId)>>>, // Cache of instructions without any side-effects along with their outputs. cached_instruction_results: InstructionResultCache, + + dom: DominatorTree, } #[derive(Copy, Clone)] @@ -190,56 +196,9 @@ pub(crate) struct BrilligInfo<'a> { brillig_functions: &'a BTreeMap<FunctionId, Function>, } -/// Records a simplified equivalents of an [`Instruction`] in the blocks -/// where the constraint that advised the simplification has been encountered. -/// -/// For more information see [`ConstraintSimplificationCache`]. -#[derive(Default)] -struct SimplificationCache { - /// Simplified expressions where we found them. - /// - /// It will always have at least one value because `add` is called - /// after the default is constructed. - simplifications: HashMap<BasicBlockId, ValueId>, -} - -impl SimplificationCache { - /// Called with a newly encountered simplification.
- fn add(&mut self, dfg: &DataFlowGraph, simple: ValueId, block: BasicBlockId) { - self.simplifications - .entry(block) - .and_modify(|existing| { - // `SimplificationCache` may already hold a simplification in this block - // so we check whether `simple` is a better simplification than the current one. - if let Some((_, simpler)) = simplify(dfg, *existing, simple) { - *existing = simpler; - }; - }) - .or_insert(simple); - } - - /// Try to find a simplification in a visible block. - fn get(&self, block: BasicBlockId, dom: &DominatorTree) -> Option<ValueId> { - // Deterministically walk up the dominator chain until we encounter a block that contains a simplification. - dom.find_map_dominator(block, |b| self.simplifications.get(&b).cloned()) - } -} - -/// HashMap from `(side_effects_enabled_var, Instruction)` to a simplified expression that it can -/// be replaced with based on constraints that testify to their equivalence, stored together -/// with the set of blocks at which this constraint has been observed. -/// -/// Only blocks dominated by one in the cache should have access to this information, otherwise -/// we create a sort of time paradox where we replace an instruction with a constant we believe -/// it _should_ equal to, without ever actually producing and asserting the value. -type ConstraintSimplificationCache = HashMap<ValueId, HashMap<ValueId, SimplificationCache>>; - /// HashMap from `(Instruction, side_effects_enabled_var)` to the results of the instruction. /// Stored as a two-level map to avoid cloning Instructions during the `.get` call. /// -/// The `side_effects_enabled_var` is optional because we only use them when `Instruction::requires_acir_gen_predicate` -/// is true _and_ the constraint information is also taken into account. -/// /// In addition to each result, the original BasicBlockId is stored as well. This allows us /// to deduplicate instructions across blocks as long as the new block dominates the original. type InstructionResultCache = HashMap<Instruction, HashMap<Option<ValueId>, ResultCache>>; @@ -249,11 +208,15 @@ type InstructionResultCache = HashMap<Instruction, HashMap<Option<ValueId>, Resu /// For more information see [`InstructionResultCache`]. #[derive(Default)] struct ResultCache { - result: Option<(BasicBlockId, Vec<ValueId>)>, + results: Vec<(BasicBlockId, Vec<ValueId>)>, } impl<'brillig> Context<'brillig> { - fn new(use_constraint_info: bool, brillig_info: Option<BrilligInfo<'brillig>>) -> Self { + fn new( + function: &Function, + use_constraint_info: bool, + brillig_info: Option<BrilligInfo<'brillig>>, + ) -> Self { Self { use_constraint_info, brillig_info, @@ -261,65 +224,52 @@ impl<'brillig> Context<'brillig> { block_queue: Default::default(), constraint_simplification_mappings: Default::default(), cached_instruction_results: Default::default(), + dom: DominatorTree::with_function(function), } } - fn fold_constants_in_block( - &mut self, - dfg: &mut DataFlowGraph, - dom: &mut DominatorTree, - block: BasicBlockId, - ) { - let instructions = dfg[block].take_instructions(); + fn fold_constants_in_block(&mut self, function: &mut Function, block: BasicBlockId) { + let instructions = function.dfg[block].take_instructions(); - // Default side effect condition variable with an enabled state.
- let mut side_effects_enabled_var = dfg.make_constant(FieldElement::one(), Type::bool()); + let mut side_effects_enabled_var = + function.dfg.make_constant(FieldElement::one(), Type::bool()); for instruction_id in instructions { self.fold_constants_into_instruction( - dfg, - dom, + &mut function.dfg, block, instruction_id, &mut side_effects_enabled_var, ); } - self.block_queue.extend(dfg[block].successors()); + self.block_queue.extend(function.dfg[block].successors()); } fn fold_constants_into_instruction( &mut self, dfg: &mut DataFlowGraph, - dom: &mut DominatorTree, - mut block: BasicBlockId, + block: BasicBlockId, id: InstructionId, side_effects_enabled_var: &mut ValueId, ) { - let constraint_simplification_mapping = self.get_constraint_map(*side_effects_enabled_var); - - let instruction = - Self::resolve_instruction(id, block, dfg, dom, constraint_simplification_mapping); - + let constraint_simplification_mapping = + self.constraint_simplification_mappings.get(side_effects_enabled_var); + let instruction = Self::resolve_instruction( + id, + block, + dfg, + &mut self.dom, + constraint_simplification_mapping, + ); let old_results = dfg.instruction_results(id).to_vec(); // If a copy of this instruction exists earlier in the block, then reuse the previous results. - if let Some(cache_result) = - self.get_cached(dfg, dom, &instruction, *side_effects_enabled_var, block) + if let Some(cached_results) = + self.get_cached(dfg, &instruction, *side_effects_enabled_var, block) { - match cache_result { - CacheResult::Cached(cached) => { - Self::replace_result_ids(dfg, &old_results, cached); - return; - } - CacheResult::NeedToHoistToCommonBlock(dominator) => { - // Just change the block to insert in the common dominator instead. - // This will only move the current instance of the instruction right now. - // When constant folding is run a second time later on, it'll catch - // that the previous instance can be deduplicated to this instance. - block = dominator; - } - } - }; + Self::replace_result_ids(dfg, &old_results, cached_results); + return; + } let new_results = // First try to inline a call to a brillig function with all constant arguments. @@ -364,7 +314,7 @@ impl<'brillig> Context<'brillig> { block: BasicBlockId, dfg: &DataFlowGraph, dom: &mut DominatorTree, - constraint_simplification_mapping: &HashMap<ValueId, SimplificationCache>, + constraint_simplification_mapping: Option<&HashMap<ValueId, Vec<(BasicBlockId, ValueId)>>>, ) -> Instruction { let instruction = dfg[instruction_id].clone(); // This allows us to reach a stable final `ValueId` for each instruction input as we add more // constraints to the cache.
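// For example, if the cache maps v0 -> v1 and v1 -> v2, resolving v0 walks the chain and
// yields v2, so instructions written against v0 and v2 hash and compare as equal.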
fn resolve_cache( - block: BasicBlockId, dfg: &DataFlowGraph, dom: &mut DominatorTree, - cache: &HashMap<ValueId, SimplificationCache>, + cache: Option<&HashMap<ValueId, Vec<(BasicBlockId, ValueId)>>>, value_id: ValueId, + block: BasicBlockId, ) -> ValueId { let resolved_id = dfg.resolve(value_id); - match cache.get(&resolved_id) { - Some(simplification_cache) => { - if let Some(simplified) = simplification_cache.get(block, dom) { - resolve_cache(block, dfg, dom, cache, simplified) - } else { - resolved_id - } + let Some(cached_values) = cache.and_then(|cache| cache.get(&resolved_id)) else { + return resolved_id; + }; + + for (cached_block, cached_value) in cached_values { + // We can only use the simplified value if it was simplified in a block that dominates the current one + if dom.dominates(*cached_block, block) { + return resolve_cache(dfg, dom, cache, *cached_value, block); } - None => resolved_id, } + + resolved_id } // Resolve any inputs to ensure that we're comparing like-for-like instructions. instruction.map_values(|value_id| { - resolve_cache(block, dfg, dom, constraint_simplification_mapping, value_id) + resolve_cache(dfg, dom, constraint_simplification_mapping, value_id, block) }) } @@ -446,7 +398,7 @@ impl<'brillig> Context<'brillig> { self.get_constraint_map(side_effects_enabled_var) .entry(complex) .or_default() - .add(dfg, simple, block); + .push((block, simple)); } } } @@ -468,12 +420,10 @@ impl<'brillig> Context<'brillig> { } } - /// Get the simplification mapping from complex to simpler instructions, - /// which all depend on the same side effect condition variable. fn get_constraint_map( &mut self, side_effects_enabled_var: ValueId, - ) -> &mut HashMap<ValueId, SimplificationCache> { + ) -> &mut HashMap<ValueId, Vec<(BasicBlockId, ValueId)>> { self.constraint_simplification_mappings.entry(side_effects_enabled_var).or_default() } @@ -488,20 +438,19 @@ impl<'brillig> Context<'brillig> { } } - /// Get a cached result if it can be used in this context. - fn get_cached( - &self, + fn get_cached<'a>( + &'a mut self, dfg: &DataFlowGraph, - dom: &mut DominatorTree, instruction: &Instruction, side_effects_enabled_var: ValueId, block: BasicBlockId, - ) -> Option<CacheResult> { + ) -> Option<&'a [ValueId]> { let results_for_instruction = self.cached_instruction_results.get(instruction)?; + let predicate = self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg); let predicate = predicate.then_some(side_effects_enabled_var); - results_for_instruction.get(&predicate)?.get(block, dom, instruction.has_side_effects(dfg)) + results_for_instruction.get(&predicate)?.get(block, &mut self.dom) } /// Checks if the given instruction is a call to a brillig function with all constant arguments. @@ -668,9 +617,7 @@ impl<'brillig> Context<'brillig> { impl ResultCache { /// Records that an `Instruction` in block `block` produced the result values `results`. fn cache(&mut self, block: BasicBlockId, results: Vec<ValueId>) { - if self.result.is_none() { - self.result = Some((block, results)); - } + self.results.push((block, results)); } /// Returns a set of [`ValueId`]s produced from a copy of this [`Instruction`] which sits /// We require that the cached instruction's block dominates `block` in order to avoid /// cycles causing issues (e.g. two instructions being replaced with the results of each other /// such that neither instruction exists anymore.)
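/// For example, a result cached in block b1 may be reused by an identical instruction in a
/// block b2 only when b1 dominates b2, i.e. every path from the entry to b2 passes through b1.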
- fn get( - &self, - block: BasicBlockId, - dom: &mut DominatorTree, - has_side_effects: bool, - ) -> Option<CacheResult> { - self.result.as_ref().and_then(|(origin_block, results)| { + fn get(&self, block: BasicBlockId, dom: &mut DominatorTree) -> Option<&[ValueId]> { + for (origin_block, results) in &self.results { if dom.dominates(*origin_block, block) { - Some(CacheResult::Cached(results)) - } else if !has_side_effects { - // Insert a copy of this instruction in the common dominator - let dominator = dom.common_dominator(*origin_block, block); - Some(CacheResult::NeedToHoistToCommonBlock(dominator)) - } else { - None + return Some(results); } - }) + } + None } } @@ -1003,22 +940,32 @@ mod test { // Regression for #4600 #[test] fn array_get_regression() { + // fn main f0 { + // b0(v0: u1, v1: u64): + // enable_side_effects_if v0 + // v2 = make_array [Field 0, Field 1] + // v3 = array_get v2, index v1 + // v4 = not v0 + // enable_side_effects_if v4 + // v5 = array_get v2, index v1 + // } + // // We want to make sure after constant folding both array_gets remain since they are // under different enable_side_effects_if contexts and thus one may be disabled while // the other is not. If one is removed, it is possible e.g. v4 is replaced with v2 which // is disabled (only gets from index 0) and thus returns the wrong result. let src = " - acir(inline) fn main f0 { - b0(v0: u1, v1: u64): - enable_side_effects v0 - v4 = make_array [Field 0, Field 1] : [Field; 2] - v5 = array_get v4, index v1 -> Field - v6 = not v0 - enable_side_effects v6 - v7 = array_get v4, index v1 -> Field - return - } - "; + acir(inline) fn main f0 { + b0(v0: u1, v1: u64): + enable_side_effects v0 + v4 = make_array [Field 0, Field 1] : [Field; 2] + v5 = array_get v4, index v1 -> Field + v6 = not v0 + enable_side_effects v6 + v7 = array_get v4, index v1 -> Field + return + } + "; let ssa = Ssa::from_str(src).unwrap(); // Expected output is unchanged @@ -1085,6 +1032,7 @@ mod test { // v5 = call keccakf1600(v1) // v6 = call keccakf1600(v2) // } + // // Here we're checking a situation where two identical arrays are being initialized twice and being assigned separate `ValueId`s. // This would result in otherwise identical instructions not being deduplicated.
let main_id = Id::test_new(0); @@ -1135,106 +1083,6 @@ mod test { assert_eq!(ending_instruction_count, 2); } - #[test] - fn deduplicate_across_blocks() { - // fn main f0 { - // b0(v0: u1): - // v1 = not v0 - // jmp b1() - // b1(): - // v2 = not v0 - // return v2 - // } - let main_id = Id::test_new(0); - - // Compiling main - let mut builder = FunctionBuilder::new("main".into(), main_id); - let b1 = builder.insert_block(); - - let v0 = builder.add_parameter(Type::bool()); - let _v1 = builder.insert_not(v0); - builder.terminate_with_jmp(b1, Vec::new()); - - builder.switch_to_block(b1); - let v2 = builder.insert_not(v0); - builder.terminate_with_return(vec![v2]); - - let ssa = builder.finish(); - let main = ssa.main(); - assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); - assert_eq!(main.dfg[b1].instructions().len(), 1); - - // Expected output: - // - // fn main f0 { - // b0(v0: u1): - // v1 = not v0 - // jmp b1() - // b1(): - // return v1 - // } - let ssa = ssa.fold_constants_using_constraints(); - let main = ssa.main(); - assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); - assert_eq!(main.dfg[b1].instructions().len(), 0); - } - - #[test] - fn deduplicate_across_non_dominated_blocks() { - let src = " - brillig(inline) fn main f0 { - b0(v0: u32): - v2 = lt u32 1000, v0 - jmpif v2 then: b1, else: b2 - b1(): - v4 = add v0, u32 1 - v5 = lt v0, v4 - constrain v5 == u1 1 - jmp b2() - b2(): - v7 = lt u32 1000, v0 - jmpif v7 then: b3, else: b4 - b3(): - v8 = add v0, u32 1 - v9 = lt v0, v8 - constrain v9 == u1 1 - jmp b4() - b4(): - return - } - "; - let ssa = Ssa::from_str(src).unwrap(); - - // v4 has been hoisted, although: - // - v5 has not yet been removed since it was encountered earlier in the program - // - v8 hasn't been recognized as a duplicate of v6 yet since they still reference v4 and - // v5 respectively - let expected = " - brillig(inline) fn main f0 { - b0(v0: u32): - v2 = lt u32 1000, v0 - v4 = add v0, u32 1 - jmpif v2 then: b1, else: b2 - b1(): - v5 = add v0, u32 1 - v6 = lt v0, v5 - constrain v6 == u1 1 - jmp b2() - b2(): - jmpif v2 then: b3, else: b4 - b3(): - v8 = lt v0, v4 - constrain v8 == u1 1 - jmp b4() - b4(): - return - } - "; - - let ssa = ssa.fold_constants_using_constraints(); - assert_normalized_ssa_equals(ssa, expected); - } - #[test] fn inlines_brillig_call_without_arguments() { let src = " @@ -1412,87 +1260,46 @@ mod test { } #[test] - fn does_not_use_cached_constrain_in_block_that_is_not_dominated() { - let src = " - brillig(inline) fn main f0 { - b0(v0: Field, v1: Field): - v3 = eq v0, Field 0 - jmpif v3 then: b1, else: b2 - b1(): - v5 = eq v1, Field 1 - constrain v1 == Field 1 - jmp b2() - b2(): - v6 = eq v1, Field 0 - constrain v1 == Field 0 - return - } - "; - let ssa = Ssa::from_str(src).unwrap(); - let ssa = ssa.fold_constants_using_constraints(); - assert_normalized_ssa_equals(ssa, src); - } + fn deduplicate_across_blocks() { + // fn main f0 { + // b0(v0: u1): + // v1 = not v0 + // jmp b1() + // b1(): + // v2 = not v0 + // return v2 + // } + let main_id = Id::test_new(0); - #[test] - fn does_not_hoist_constrain_to_common_ancestor() { - let src = " - brillig(inline) fn main f0 { - b0(v0: Field, v1: Field): - v3 = eq v0, Field 0 - jmpif v3 then: b1, else: b2 - b1(): - constrain v1 == Field 1 - jmp b2() - b2(): - jmpif v0 then: b3, else: b4 - b3(): - constrain v1 == Field 1 // This was incorrectly hoisted to b0 but this condition is not valid when going b0 -> b2 -> b4 - jmp b4() - b4(): - return - } - "; - let ssa = 
Ssa::from_str(src).unwrap(); - let ssa = ssa.fold_constants_using_constraints(); - assert_normalized_ssa_equals(ssa, src); - } + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id); + let b1 = builder.insert_block(); - #[test] - fn deduplicates_side_effecting_intrinsics() { - let src = " - // After EnableSideEffectsIf removal: - acir(inline) fn main f0 { - b0(v0: Field, v1: Field, v2: u1): - v4 = call is_unconstrained() -> u1 - v7 = call to_be_radix(v0, u32 256) -> [u8; 1] // `a.to_be_radix(256)`; - inc_rc v7 - v8 = call to_be_radix(v0, u32 256) -> [u8; 1] // duplicate load of `a` - inc_rc v8 - v9 = cast v2 as Field // `if c { a.to_be_radix(256) }` - v10 = mul v0, v9 // attaching `c` to `a` - v11 = call to_be_radix(v10, u32 256) -> [u8; 1] // calling `to_radix(c * a)` - inc_rc v11 - enable_side_effects v2 // side effect var for `c` shifted down by removal - return - } - "; - let ssa = Ssa::from_str(src).unwrap(); - let expected = " - acir(inline) fn main f0 { - b0(v0: Field, v1: Field, v2: u1): - v4 = call is_unconstrained() -> u1 - v7 = call to_be_radix(v0, u32 256) -> [u8; 1] - inc_rc v7 - inc_rc v7 - v8 = cast v2 as Field - v9 = mul v0, v8 - v10 = call to_be_radix(v9, u32 256) -> [u8; 1] - inc_rc v10 - enable_side_effects v2 - return - } - "; + let v0 = builder.add_parameter(Type::bool()); + let _v1 = builder.insert_not(v0); + builder.terminate_with_jmp(b1, Vec::new()); + + builder.switch_to_block(b1); + let v2 = builder.insert_not(v0); + builder.terminate_with_return(vec![v2]); + + let ssa = builder.finish(); + let main = ssa.main(); + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); + assert_eq!(main.dfg[b1].instructions().len(), 1); + + // Expected output: + // + // fn main f0 { + // b0(v0: u1): + // v1 = not v0 + // jmp b1() + // b1(): + // return v1 + // } let ssa = ssa.fold_constants_using_constraints(); - assert_normalized_ssa_equals(ssa, expected); + let main = ssa.main(); + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); + assert_eq!(main.dfg[b1].instructions().len(), 0); } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index c8dd0e3c5a3..5d114672a55 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -131,7 +131,8 @@ //! v11 = mul v4, Field 12 //! v12 = add v10, v11 //! store v12 at v5 (new store) -use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use fxhash::FxHashMap as HashMap; +use std::collections::{BTreeMap, HashSet}; use acvm::{acir::AcirField, acir::BlackBoxFunc, FieldElement}; use iter_extended::vecmap; @@ -185,6 +186,18 @@ struct Context<'f> { /// Maps start of branch -> end of branch branch_ends: HashMap<BasicBlockId, BasicBlockId>, + /// Maps an address to the old and new value of the element at that address + /// These only hold stores for one block at a time and are cleared + /// between inlining of branches. + store_values: HashMap<ValueId, Store>, + + /// Stores all allocations local to the current branch. + /// Since these allocations are local to the current branch (i.e. only defined within one branch of + /// an if expression), they should not be merged with their previous value or stored value in + /// the other branch since there is no such value. The ValueId here is that which is returned + /// by the allocate instruction. + local_allocations: HashSet<ValueId>, + /// A stack of each jmpif condition that was taken to reach a particular point in the program.
/// When two branches are merged back into one, this constitutes a join point, and is analogous /// to the rest of the program after an if statement. When such a join point / end block is /// encountered, the top of the stack is popped /// When processing a block, we pop this stack to get its arguments /// and at the end we push the arguments for its successor arguments_stack: Vec<Vec<ValueId>>, +} - /// Stores all allocations local to the current branch. - /// - /// Since these branches are local to the current branch (i.e. only defined within one branch of - /// an if expression), they should not be merged with their previous value or stored value in - /// the other branch since there is no such value. - /// - /// The `ValueId` here is that which is returned by the allocate instruction. - local_allocations: HashSet<ValueId>, +#[derive(Clone)] +pub(crate) struct Store { + old_value: ValueId, + new_value: ValueId, + call_stack: CallStack, } #[derive(Clone)] @@ -220,6 +231,8 @@ struct ConditionalBranch { old_condition: ValueId, // The condition of the branch condition: ValueId, + // The store values accumulated when processing the branch + store_values: HashMap<ValueId, Store>, // The allocations accumulated when processing the branch local_allocations: HashSet<ValueId>, } @@ -250,11 +263,12 @@ fn flatten_function_cfg(function: &mut Function, no_predicates: &HashMap<FunctionId, bool> Context<'f> { let old_condition = *condition; let then_condition = self.inserter.resolve(old_condition); + let old_stores = std::mem::take(&mut self.store_values); let old_allocations = std::mem::take(&mut self.local_allocations); let branch = ConditionalBranch { old_condition, condition: self.link_condition(then_condition), - last_block: *then_destination, + store_values: old_stores, local_allocations: old_allocations, + last_block: *then_destination, }; let cond_context = ConditionalContext { condition: then_condition, @@ -457,12 +473,19 @@ impl<'f> Context<'f> { ); let else_condition = self.link_condition(else_condition); + // Make sure the else branch sees the previous values of each store + // rather than any values created in the 'then' branch.
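+ // For example, given `if c { *r = 1 } else { *r = 2 }`, the else branch must start from
+ // the value held by *r before the jmpif, not the 1 stored while inlining the then branch.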
+ let old_stores = std::mem::take(&mut cond_context.then_branch.store_values); + cond_context.then_branch.store_values = std::mem::take(&mut self.store_values); + self.undo_stores_in_then_branch(&cond_context.then_branch.store_values); + let old_allocations = std::mem::take(&mut self.local_allocations); let else_branch = ConditionalBranch { old_condition: cond_context.then_branch.old_condition, condition: else_condition, - last_block: *block, + store_values: old_stores, local_allocations: old_allocations, + last_block: *block, }; cond_context.then_branch.local_allocations.clear(); cond_context.else_branch = Some(else_branch); @@ -486,8 +509,10 @@ impl<'f> Context<'f> { } let mut else_branch = cond_context.else_branch.unwrap(); + let stores_in_branch = std::mem::replace(&mut self.store_values, else_branch.store_values); self.local_allocations = std::mem::take(&mut else_branch.local_allocations); else_branch.last_block = *block; + else_branch.store_values = stores_in_branch; cond_context.else_branch = Some(else_branch); // We must remember to reset whether side effects are enabled when both branches @@ -555,6 +580,8 @@ impl<'f> Context<'f> { .first() }); + let call_stack = cond_context.call_stack; + self.merge_stores(cond_context.then_branch, cond_context.else_branch, call_stack); self.arguments_stack.pop(); self.arguments_stack.pop(); self.arguments_stack.push(args); @@ -609,29 +636,116 @@ impl<'f> Context<'f> { self.insert_instruction_with_typevars(enable_side_effects, None, call_stack); } + /// Merge any store instructions found in each branch. + /// + /// This function relies on the 'then' branch being merged before the 'else' branch of a jmpif + /// instruction. If this ordering is changed, the ordering that store values are merged within + /// this function also needs to be changed to reflect that. 
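+ /// For example, if the then branch stored v1 at address a and the else branch stored v2 at
+ /// a, the merged program stores `if then_condition { v1 } else { v2 }` at a, built below as
+ /// an `IfElse` instruction followed by a single `Store`.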
+ fn merge_stores( + &mut self, + then_branch: ConditionalBranch, + else_branch: Option<ConditionalBranch>, + call_stack: CallStack, + ) { + // Address -> (then_value, else_value, value_before_the_if) + let mut new_map = BTreeMap::new(); + + for (address, store) in then_branch.store_values { + new_map.insert(address, (store.new_value, store.old_value, store.old_value)); + } + + if else_branch.is_some() { + for (address, store) in else_branch.clone().unwrap().store_values { + if let Some(entry) = new_map.get_mut(&address) { + entry.1 = store.new_value; + } else { + new_map.insert(address, (store.old_value, store.new_value, store.old_value)); + } + } + } + + let then_condition = then_branch.condition; + let block = self.inserter.function.entry_block(); + + // Merging must occur in a separate loop as we cannot borrow `self` as mutable while `value_merger` does + let mut new_values = HashMap::default(); + for (address, (then_case, else_case, _)) in &new_map { + let instruction = Instruction::IfElse { + then_condition, + then_value: *then_case, + else_value: *else_case, + }; + let dfg = &mut self.inserter.function.dfg; + let value = dfg + .insert_instruction_and_results(instruction, block, None, call_stack.clone()) + .first(); + + new_values.insert(address, value); + } + + // Replace stores with new merged values + for (address, (_, _, old_value)) in &new_map { + let value = new_values[address]; + let address = *address; + self.insert_instruction_with_typevars( + Instruction::Store { address, value }, + None, + call_stack.clone(), + ); + + if let Some(store) = self.store_values.get_mut(&address) { + store.new_value = value; + } else { + self.store_values.insert( + address, + Store { + old_value: *old_value, + new_value: value, + call_stack: call_stack.clone(), + }, + ); + } + } + } + + fn remember_store(&mut self, address: ValueId, new_value: ValueId, call_stack: CallStack) { + if !self.local_allocations.contains(&address) { + if let Some(store_value) = self.store_values.get_mut(&address) { + store_value.new_value = new_value; + } else { + let load = Instruction::Load { address }; + + let load_type = Some(vec![self.inserter.function.dfg.type_of_value(new_value)]); + let old_value = self + .insert_instruction_with_typevars(load.clone(), load_type, call_stack.clone()) + .first(); + + self.store_values.insert(address, Store { old_value, new_value, call_stack }); + } + } + } + /// Push the given instruction to the end of the entry block of the current function. /// /// Note that each ValueId of the instruction will be mapped via self.inserter.resolve. /// As a result, the instruction that will be pushed will actually be a new instruction /// with a different InstructionId from the original. The results of the given instruction /// will also be mapped to the results of the new instruction. - /// - /// `previous_allocate_result` should only be set to the result of an allocate instruction - /// if that instruction was the instruction immediately previous to this one - if there are - /// any instructions in between it should be None.
- fn push_instruction(&mut self, id: InstructionId) { + fn push_instruction(&mut self, id: InstructionId) -> Vec<ValueId> { let (instruction, call_stack) = self.inserter.map_instruction(id); let instruction = self.handle_instruction_side_effects(instruction, call_stack.clone()); + let is_allocate = matches!(instruction, Instruction::Allocate); - let instruction_is_allocate = matches!(&instruction, Instruction::Allocate); let entry = self.inserter.function.entry_block(); let results = self.inserter.push_instruction_value(instruction, id, entry, call_stack); // Remember an allocate was created local to this branch so that we do not try to merge store // values across branches for it later. - if instruction_is_allocate { + if is_allocate { self.local_allocations.insert(results.first()); } + + results.results().into_owned() } /// If we are currently in a branch, we need to modify constrain instructions @@ -668,32 +782,8 @@ impl<'f> Context<'f> { Instruction::Constrain(lhs, rhs, message) } Instruction::Store { address, value } => { - // If this instruction immediately follows an allocate, and stores to that - // address there is no previous value to load and we don't need a merge anyway. - if self.local_allocations.contains(&address) { - Instruction::Store { address, value } - } else { - // Instead of storing `value`, store `if condition { value } else { previous_value }` - let typ = self.inserter.function.dfg.type_of_value(value); - let load = Instruction::Load { address }; - let previous_value = self - .insert_instruction_with_typevars( - load, - Some(vec![typ]), - call_stack.clone(), - ) - .first(); - - let instruction = Instruction::IfElse { - then_condition: condition, - then_value: value, - - else_value: previous_value, - }; - - let updated_value = self.insert_instruction(instruction, call_stack); - Instruction::Store { address, value: updated_value } - } + self.remember_store(address, value, call_stack); + Instruction::Store { address, value } } Instruction::RangeCheck { value, max_bit_size, assert_message } => { // Replace value with `value * predicate` to zero out value when predicate is inactive. @@ -815,11 +905,23 @@ impl<'f> Context<'f> { call_stack, ) } + + fn undo_stores_in_then_branch(&mut self, store_values: &HashMap<ValueId, Store>) { + for (address, store) in store_values { + let address = *address; + let value = store.old_value; + let instruction = Instruction::Store { address, value }; + // Considering the location of undoing a store to be the same as the original store.
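+ // E.g. a then-branch `store v1 at a` is undone by re-emitting `store old at a`, where
+ // `old` is the value remember_store loaded from `a` before the branch overwrote it.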
+ self.insert_instruction_with_typevars(instruction, None, store.call_stack.clone()); + } + } } #[cfg(test)] mod test { - use acvm::acir::AcirField; + use std::sync::Arc; + + use acvm::{acir::AcirField, FieldElement}; use crate::ssa::{ function_builder::FunctionBuilder, @@ -921,13 +1023,15 @@ mod test { b0(v0: u1, v1: &mut Field): enable_side_effects v0 v2 = load v1 -> Field - v3 = cast v0 as Field - v5 = sub Field 5, v2 - v6 = mul v3, v5 - v7 = add v2, v6 - store v7 at v1 - v8 = not v0 + store Field 5 at v1 + v4 = not v0 + store v2 at v1 enable_side_effects u1 1 + v6 = cast v0 as Field + v7 = sub Field 5, v2 + v8 = mul v6, v7 + v9 = add v2, v8 + store v9 at v1 return } "; @@ -958,20 +1062,17 @@ mod test { b0(v0: u1, v1: &mut Field): enable_side_effects v0 v2 = load v1 -> Field - v3 = cast v0 as Field - v5 = sub Field 5, v2 - v6 = mul v3, v5 - v7 = add v2, v6 - store v7 at v1 - v8 = not v0 - enable_side_effects v8 - v9 = load v1 -> Field - v10 = cast v8 as Field - v12 = sub Field 6, v9 - v13 = mul v10, v12 - v14 = add v9, v13 - store v14 at v1 + store Field 5 at v1 + v4 = not v0 + store v2 at v1 + enable_side_effects v4 + v5 = load v1 -> Field + store Field 6 at v1 enable_side_effects u1 1 + v8 = cast v0 as Field + v10 = mul v8, Field -1 + v11 = add Field 6, v10 + store v11 at v1 return } "; @@ -1014,123 +1115,84 @@ mod test { // b7 b8 // ↘ ↙ // b9 - let main_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("main".into(), main_id); - - let b1 = builder.insert_block(); - let b2 = builder.insert_block(); - let b3 = builder.insert_block(); - let b4 = builder.insert_block(); - let b5 = builder.insert_block(); - let b6 = builder.insert_block(); - let b7 = builder.insert_block(); - let b8 = builder.insert_block(); - let b9 = builder.insert_block(); - - let c1 = builder.add_parameter(Type::bool()); - let c4 = builder.add_parameter(Type::bool()); - - let r1 = builder.insert_allocate(Type::field()); - - let store_value = |builder: &mut FunctionBuilder, value: u128| { - let value = builder.field_constant(value); - builder.insert_store(r1, value); - }; - - let test_function = Id::test_new(1); - - let call_test_function = |builder: &mut FunctionBuilder, block: u128| { - let block = builder.field_constant(block); - let load = builder.insert_load(r1, Type::field()); - builder.insert_call(test_function, vec![block, load], Vec::new()); - }; - - let switch_store_and_test_function = - |builder: &mut FunctionBuilder, block, block_number: u128| { - builder.switch_to_block(block); - store_value(builder, block_number); - call_test_function(builder, block_number); - }; - - let switch_and_test_function = - |builder: &mut FunctionBuilder, block, block_number: u128| { - builder.switch_to_block(block); - call_test_function(builder, block_number); - }; - - store_value(&mut builder, 0); - call_test_function(&mut builder, 0); - builder.terminate_with_jmp(b1, vec![]); - - switch_store_and_test_function(&mut builder, b1, 1); - builder.terminate_with_jmpif(c1, b2, b3); - - switch_store_and_test_function(&mut builder, b2, 2); - builder.terminate_with_jmp(b4, vec![]); - - switch_store_and_test_function(&mut builder, b3, 3); - builder.terminate_with_jmp(b8, vec![]); - - switch_and_test_function(&mut builder, b4, 4); - builder.terminate_with_jmpif(c4, b5, b6); - - switch_store_and_test_function(&mut builder, b5, 5); - builder.terminate_with_jmp(b7, vec![]); - - switch_store_and_test_function(&mut builder, b6, 6); - builder.terminate_with_jmp(b7, vec![]); - - switch_and_test_function(&mut builder, b7, 7); - 
builder.terminate_with_jmp(b9, vec![]); - - switch_and_test_function(&mut builder, b8, 8); - builder.terminate_with_jmp(b9, vec![]); - - switch_and_test_function(&mut builder, b9, 9); - let load = builder.insert_load(r1, Type::field()); - builder.terminate_with_return(vec![load]); + let src = " + acir(inline) fn main f0 { + b0(v0: u1, v1: u1): + v2 = allocate -> &mut Field + store Field 0 at v2 + v4 = load v2 -> Field + // call v1(Field 0, v4) + jmp b1() + b1(): + store Field 1 at v2 + v6 = load v2 -> Field + // call v1(Field 1, v6) + jmpif v0 then: b2, else: b3 + b2(): + store Field 2 at v2 + v8 = load v2 -> Field + // call v1(Field 2, v8) + jmp b4() + b4(): + v12 = load v2 -> Field + // call v1(Field 4, v12) + jmpif v1 then: b5, else: b6 + b5(): + store Field 5 at v2 + v14 = load v2 -> Field + // call v1(Field 5, v14) + jmp b7() + b7(): + v18 = load v2 -> Field + // call v1(Field 7, v18) + jmp b9() + b9(): + v22 = load v2 -> Field + // call v1(Field 9, v22) + v23 = load v2 -> Field + return v23 + b6(): + store Field 6 at v2 + v16 = load v2 -> Field + // call v1(Field 6, v16) + jmp b7() + b3(): + store Field 3 at v2 + v10 = load v2 -> Field + // call v1(Field 3, v10) + jmp b8() + b8(): + v20 = load v2 -> Field + // call v1(Field 8, v20) + jmp b9() + } + "; - let ssa = builder.finish().flatten_cfg().mem2reg(); + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.flatten_cfg().mem2reg(); // Expected results after mem2reg removes the allocation and each load and store: - // - // fn main f0 { - // b0(v0: u1, v1: u1): - // call test_function(Field 0, Field 0) - // call test_function(Field 1, Field 1) - // enable_side_effects v0 - // call test_function(Field 2, Field 2) - // call test_function(Field 4, Field 2) - // v29 = and v0, v1 - // enable_side_effects v29 - // call test_function(Field 5, Field 5) - // v32 = not v1 - // v33 = and v0, v32 - // enable_side_effects v33 - // call test_function(Field 6, Field 6) - // enable_side_effects v0 - // v36 = mul v1, Field 5 - // v37 = mul v32, Field 2 - // v38 = add v36, v37 - // v39 = mul v1, Field 5 - // v40 = mul v32, Field 6 - // v41 = add v39, v40 - // call test_function(Field 7, v42) - // v43 = not v0 - // enable_side_effects v43 - // store Field 3 at v2 - // call test_function(Field 3, Field 3) - // call test_function(Field 8, Field 3) - // enable_side_effects Field 1 - // v47 = mul v0, v41 - // v48 = mul v43, Field 1 - // v49 = add v47, v48 - // v50 = mul v0, v44 - // v51 = mul v43, Field 3 - // v52 = add v50, v51 - // call test_function(Field 9, v53) - // return v54 - // } + let expected = " + acir(inline) fn main f0 { + b0(v0: u1, v1: u1): + v2 = allocate -> &mut Field + enable_side_effects v0 + v3 = mul v0, v1 + enable_side_effects v3 + v4 = not v1 + v5 = mul v0, v4 + enable_side_effects v0 + v6 = cast v3 as Field + v8 = mul v6, Field -1 + v10 = add Field 6, v8 + v11 = not v0 + enable_side_effects u1 1 + v13 = cast v0 as Field + v15 = sub v10, Field 3 + v16 = mul v13, v15 + v17 = add Field 3, v16 + return v17 + }"; let main = ssa.main(); let ret = match main.dfg[main.entry_block()].terminator() { @@ -1139,7 +1201,12 @@ mod test { }; let merged_values = get_all_constants_reachable_from_instruction(&main.dfg, ret); - assert_eq!(merged_values, vec![1, 3, 5, 6]); + assert_eq!( + merged_values, + vec![FieldElement::from(3u128), FieldElement::from(6u128), -FieldElement::from(1u128)] + ); + + assert_normalized_ssa_equals(ssa, expected); } #[test] @@ -1220,7 +1287,7 @@ mod test { fn get_all_constants_reachable_from_instruction( dfg: 
&DataFlowGraph, value: ValueId, - ) -> Vec<u128> { + ) -> Vec<FieldElement> { match dfg[value] { Value::Instruction { instruction, .. } => { let mut values = vec![]; @@ -1238,7 +1305,7 @@ mod test { values.dedup(); values } - Value::NumericConstant { constant, .. } => vec![constant.to_u128()], + Value::NumericConstant { constant, .. } => vec![constant], _ => Vec::new(), } } @@ -1277,74 +1344,63 @@ mod test { fn should_not_merge_incorrectly_to_false() { // Regression test for #1792 // Tests that it does not simplify a true constraint to an always-false constraint - - let src = " - acir(inline) fn main f0 { - b0(v0: [u8; 2]): - v2 = array_get v0, index u8 0 -> u8 - v3 = cast v2 as u32 - v4 = truncate v3 to 1 bits, max_bit_size: 32 - v5 = cast v4 as u1 - v6 = allocate -> &mut Field - store u8 0 at v6 - jmpif v5 then: b2, else: b1 - b2(): - v7 = cast v2 as Field - v9 = add v7, Field 1 - v10 = cast v9 as u8 - store v10 at v6 - jmp b3() - b3(): - constrain v5 == u1 1 - return - b1(): - store u8 0 at v6 - jmp b3() - } - "; - - let ssa = Ssa::from_str(src).unwrap(); - - let expected = " - acir(inline) fn main f0 { - b0(v0: [u8; 2]): - v2 = array_get v0, index u8 0 -> u8 - v3 = cast v2 as u32 - v4 = truncate v3 to 1 bits, max_bit_size: 32 - v5 = cast v4 as u1 - v6 = allocate -> &mut Field - store u8 0 at v6 - enable_side_effects v5 - v7 = cast v2 as Field - v9 = add v7, Field 1 - v10 = cast v9 as u8 - v11 = load v6 -> u8 - v12 = cast v4 as Field - v13 = cast v11 as Field - v14 = sub v9, v13 - v15 = mul v12, v14 - v16 = add v13, v15 - v17 = cast v16 as u8 - store v17 at v6 - v18 = not v5 - enable_side_effects v18 - v19 = load v6 -> u8 - v20 = cast v18 as Field - v21 = cast v19 as Field - v23 = sub Field 0, v21 - v24 = mul v20, v23 - v25 = add v21, v24 - v26 = cast v25 as u8 - store v26 at v6 - enable_side_effects u1 1 - constrain v5 == u1 1 - return - } - "; - + // acir(inline) fn main f1 { + // b0(v0: [u8; 2]): + // v5 = array_get v0, index u8 0 + // v6 = cast v5 as u32 + // v8 = truncate v6 to 1 bits, max_bit_size: 32 + // v9 = cast v8 as u1 + // v10 = allocate + // store u8 0 at v10 + // jmpif v9 then: b2, else: b3 + // b2(): + // v12 = cast v5 as Field + // v13 = add v12, Field 1 + // store v13 at v10 + // jmp b4() + // b4(): + // constrain v9 == u1 1 + // return + // b3(): + // store u8 0 at v10 + // jmp b4() + // } + let main_id = Id::test_new(1); + let mut builder = FunctionBuilder::new("main".into(), main_id); + builder.insert_block(); // b0 + let b1 = builder.insert_block(); + let b2 = builder.insert_block(); + let b3 = builder.insert_block(); + let element_type = Arc::new(vec![Type::unsigned(8)]); + let array_type = Type::Array(element_type.clone(), 2); + let array = builder.add_parameter(array_type); + let zero = builder.numeric_constant(0_u128, Type::unsigned(8)); + let v5 = builder.insert_array_get(array, zero, Type::unsigned(8)); + let v6 = builder.insert_cast(v5, Type::unsigned(32)); + let i_two = builder.numeric_constant(2_u128, Type::unsigned(32)); + let v8 = builder.insert_binary(v6, BinaryOp::Mod, i_two); + let v9 = builder.insert_cast(v8, Type::bool()); + let v10 = builder.insert_allocate(Type::field()); + builder.insert_store(v10, zero); + builder.terminate_with_jmpif(v9, b1, b2); + builder.switch_to_block(b1); + let one = builder.field_constant(1_u128); + let v5b = builder.insert_cast(v5, Type::field()); + let v13: Id<Value> = builder.insert_binary(v5b, BinaryOp::Add, one); + let v14 = builder.insert_cast(v13, Type::unsigned(8)); + builder.insert_store(v10, v14); + builder.terminate_with_jmp(b3, vec![]);
builder.switch_to_block(b2); + builder.insert_store(v10, zero); + builder.terminate_with_jmp(b3, vec![]); + builder.switch_to_block(b3); + let v_true = builder.numeric_constant(true, Type::bool()); + let v12 = builder.insert_binary(v9, BinaryOp::Eq, v_true); + builder.insert_constrain(v12, v_true, None); + builder.terminate_with_return(vec![]); + let ssa = builder.finish(); let flattened_ssa = ssa.flatten_cfg(); let main = flattened_ssa.main(); - // Now assert that there is not an always-false constraint after flattening: let mut constrain_count = 0; for instruction in main.dfg[main.entry_block()].instructions() { @@ -1358,8 +1414,6 @@ mod test { } } assert_eq!(constrain_count, 1); - - assert_normalized_ssa_equals(flattened_ssa, expected); } #[test] @@ -1495,7 +1549,7 @@ mod test { b2(): return b1(): - jmp b2() + jmp b2() } "; let merged_ssa = Ssa::from_str(src).unwrap(); diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index f91487fd73e..6cf7070e65e 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -1089,6 +1089,7 @@ mod test { } #[test] + #[ignore] #[should_panic( expected = "Attempted to recur more than 1000 times during inlining function 'main': acir(inline) fn main f0 {" )] diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index 53a31ae57c1..0690dbbf204 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -18,7 +18,6 @@ //! - A reference with 0 aliases means we were unable to find which reference this reference //! refers to. If such a reference is stored to, we must conservatively invalidate every //! reference in the current block. -//! - We also track the last load instruction to each address per block. //! //! From there, to figure out the value of each reference at the end of block, iterate each instruction: //! - On `Instruction::Allocate`: @@ -29,13 +28,6 @@ //! - Furthermore, if the result of the load is a reference, mark the result as an alias //! of the reference it dereferences to (if known). //! - If which reference it dereferences to is not known, this load result has no aliases. -//! - We also track the last instance of a load instruction to each address in a block. -//! If we see that the last load instruction was from the same address as the current load instruction, -//! we move to replace the result of the current load with the result of the previous load. -//! This removal requires a couple conditions: -//! - No store occurs to that address before the next load, -//! - The address is not used as an argument to a call -//! This optimization helps us remove repeated loads for which there are not known values. //! - On `Instruction::Store { address, value }`: //! - If the address of the store is known: //! - If the address has exactly 1 alias: @@ -48,13 +40,11 @@ //! - Conservatively mark every alias in the block to `Unknown`. //! - Additionally, if there were no Loads to any alias of the address between this Store and //! the previous Store to the same address, the previous store can be removed. -//! - Remove the instance of the last load instruction to the address and its aliases //! - On `Instruction::Call { arguments }`: //! - If any argument of the call is a reference, set the value of each alias of that //! reference to `Unknown` //! 
- Any builtin functions that may return aliases if their input also contains a //! reference should be tracked. Examples: `slice_push_back`, `slice_insert`, `slice_remove`, etc. -//! - Remove the instance of the last load instruction for any reference arguments and their aliases //! //! On a terminator instruction: //! - If the terminator is a `Jmp`: @@ -284,9 +274,6 @@ impl<'f> PerFunctionContext<'f> { if let Some(first_predecessor) = predecessors.next() { let mut first = self.blocks.get(&first_predecessor).cloned().unwrap_or_default(); first.last_stores.clear(); - // Last loads are tracked per block. During unification we are creating a new block from the current one, - // so we must clear the last loads of the current block before we return the new block. - first.last_loads.clear(); // Note that we have to start folding with the first block as the accumulator. // If we started with an empty block, an empty block union'd with any other block @@ -423,28 +410,6 @@ impl<'f> PerFunctionContext<'f> { self.last_loads.insert(address, (instruction, block_id)); } - - // Check whether the block has a repeat load from the same address (w/ no calls or stores in between the loads). - // If we do have a repeat load, we can remove the current load and map its result to the previous load's result. - if let Some(last_load) = references.last_loads.get(&address) { - let Instruction::Load { address: previous_address } = - &self.inserter.function.dfg[*last_load] - else { - panic!("Expected a Load instruction here"); - }; - let result = self.inserter.function.dfg.instruction_results(instruction)[0]; - let previous_result = - self.inserter.function.dfg.instruction_results(*last_load)[0]; - if *previous_address == address { - self.inserter.map_value(result, previous_result); - self.instructions_to_remove.insert(instruction); - } - } - // We want to set the load for every load even if the address has a known value - // and the previous load instruction was removed. - // We are safe to still remove a repeat load in this case as we are mapping from the current load's - // result to the previous load, which if it was removed should already have a mapping to the known value. - references.set_last_load(address, instruction); } Instruction::Store { address, value } => { let address = self.inserter.function.dfg.resolve(*address); @@ -470,8 +435,6 @@ impl<'f> PerFunctionContext<'f> { } references.set_known_value(address, value); - // If we see a store to an address, the last load to that address needs to remain. - references.keep_last_load_for(address, self.inserter.function); references.last_stores.insert(address, instruction); } Instruction::Allocate => { @@ -579,9 +542,6 @@ impl<'f> PerFunctionContext<'f> { let value = self.inserter.function.dfg.resolve(*value); references.set_unknown(value); references.mark_value_used(value, self.inserter.function); - - // If a reference is an argument to a call, the last load to that address and its aliases needs to remain. - references.keep_last_load_for(value, self.inserter.function); } } } @@ -612,12 +572,6 @@ impl<'f> PerFunctionContext<'f> { let destination_parameters = self.inserter.function.dfg[*destination].parameters(); assert_eq!(destination_parameters.len(), arguments.len()); - // If we have multiple parameters that alias that same argument value, - // then those parameters also alias each other. - // We save parameters with repeat arguments to later mark those - // parameters as aliasing one another. 
- let mut arg_set: HashMap<ValueId, BTreeSet<ValueId>> = HashMap::default(); - // Add an alias for each reference parameter for (parameter, argument) in destination_parameters.iter().zip(arguments) { if self.inserter.function.dfg.value_is_reference(*parameter) { let argument = self.inserter.function.dfg.resolve(*argument); if let Some(expression) = references.expressions.get(&argument) { if let Some(aliases) = references.aliases.get_mut(expression) { // The argument reference is possibly aliased by this block parameter aliases.insert(*parameter); - - // Check if we have seen the same argument - let seen_parameters = arg_set.entry(argument).or_default(); - // Add the current parameter to the parameters we have seen for this argument. - // The previous parameters and the current one alias one another. - seen_parameters.insert(*parameter); } } } } - - // Set the aliases of the parameters - for (_, aliased_params) in arg_set { - for param in aliased_params.iter() { - self.set_aliases( - references, - *param, - AliasSet::known_multiple(aliased_params.clone()), - ); - } - } } TerminatorInstruction::Return { return_values, .. } => { // Removing all `last_stores` for each returned reference is more important here @@ -675,8 +612,6 @@ mod tests { map::Id, types::Type, }, - opt::assert_normalized_ssa_equals, - Ssa, }; #[test] @@ -887,53 +822,88 @@ mod tests { // is later stored in a successor block #[test] fn load_aliases_in_predecessor_block() { - let src = " - acir(inline) fn main f0 { - b0(): - v0 = allocate -> &mut Field - store Field 0 at v0 - v2 = allocate -> &mut &mut Field - store v0 at v2 - v3 = load v2 -> &mut Field - v4 = load v2 -> &mut Field - jmp b1() - b1(): - store Field 1 at v3 - store Field 2 at v4 - v7 = load v3 -> Field - v8 = eq v7, Field 2 - return - } - "; + // fn main { + // b0(): + // v0 = allocate + // store Field 0 at v0 + // v2 = allocate + // store v0 at v2 + // v3 = load v2 + // v4 = load v2 + // jmp b1() + // b1(): + // store Field 1 at v3 + // store Field 2 at v4 + // v7 = load v3 + // v8 = eq v7, Field 2 + // return + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id); + + let v0 = builder.insert_allocate(Type::field()); + + let zero = builder.field_constant(0u128); + builder.insert_store(v0, zero); + + let v2 = builder.insert_allocate(Type::Reference(Arc::new(Type::field()))); + builder.insert_store(v2, v0); + + let v3 = builder.insert_load(v2, Type::field()); + let v4 = builder.insert_load(v2, Type::field()); + let b1 = builder.insert_block(); + builder.terminate_with_jmp(b1, vec![]); + + builder.switch_to_block(b1); + + let one = builder.field_constant(1u128); + builder.insert_store(v3, one); + + let two = builder.field_constant(2u128); + builder.insert_store(v4, two); + + let v8 = builder.insert_load(v3, Type::field()); + let _ = builder.insert_binary(v8, BinaryOp::Eq, two); + + builder.terminate_with_return(vec![]); + + let ssa = builder.finish(); + assert_eq!(ssa.main().reachable_blocks().len(), 2); - let mut ssa = Ssa::from_str(src).unwrap(); - let main = ssa.main_mut(); + // Expected result: + // acir fn main f0 { + // b0(): + // v9 = allocate + // store Field 0 at v9 + // v10 = allocate + // jmp b1() + // b1(): + // return + // } + let ssa = ssa.mem2reg(); + println!("{}", ssa); - let instructions = main.dfg[main.entry_block()].instructions(); - assert_eq!(instructions.len(), 6); // The final return is not counted + let main = ssa.main(); + assert_eq!(main.reachable_blocks().len(), 2); // All loads should be removed + assert_eq!(count_loads(main.entry_block(), &main.dfg), 0); + assert_eq!(count_loads(b1,
&main.dfg), 0); + // The first store is not removed as it is used as a nested reference in another store. - // We would need to track whether the store where `v0` is the store value gets removed to know whether + // We would need to track whether the store where `v9` is the store value gets removed to know whether // to remove it. + assert_eq!(count_stores(main.entry_block(), &main.dfg), 1); // The first store in b1 is removed since there is another store to the same reference // in the same block, and the store is not needed before the later store. // The rest of the stores are also removed as no loads are done within any blocks // to the stored values. - let expected = " - acir(inline) fn main f0 { - b0(): - v0 = allocate -> &mut Field - store Field 0 at v0 - v2 = allocate -> &mut &mut Field - jmp b1() - b1(): - return - } - "; + assert_eq!(count_stores(b1, &main.dfg), 0); - let ssa = ssa.mem2reg(); - assert_normalized_ssa_equals(ssa, expected); + let b1_instructions = main.dfg[b1].instructions(); + + // We expect the last eq to be optimized out + assert_eq!(b1_instructions.len(), 0); } #[test] @@ -963,7 +933,7 @@ mod tests { // v10 = eq v9, Field 2 // constrain v9 == Field 2 // v11 = load v2 - // v12 = load v11 + // v12 = load v10 // v13 = eq v12, Field 2 // constrain v11 == Field 2 // return @@ -1022,7 +992,7 @@ mod tests { let main = ssa.main(); assert_eq!(main.reachable_blocks().len(), 4); - // The stores from the original SSA should remain + // The store from the original SSA should remain assert_eq!(count_stores(main.entry_block(), &main.dfg), 2); assert_eq!(count_stores(b2, &main.dfg), 1); @@ -1069,160 +1039,4 @@ mod tests { let main = ssa.main(); assert_eq!(count_loads(main.entry_block(), &main.dfg), 1); } - - #[test] - fn remove_repeat_loads() { - // This tests starts with two loads from the same unknown load. - // Specifically you should look for `load v2` in `b3`. - // We should be able to remove the second repeated load. - let src = " - acir(inline) fn main f0 { - b0(): - v0 = allocate -> &mut Field - store Field 0 at v0 - v2 = allocate -> &mut &mut Field - store v0 at v2 - jmp b1(Field 0) - b1(v3: Field): - v4 = eq v3, Field 0 - jmpif v4 then: b2, else: b3 - b2(): - v5 = load v2 -> &mut Field - store Field 2 at v5 - v8 = add v3, Field 1 - jmp b1(v8) - b3(): - v9 = load v0 -> Field - v10 = eq v9, Field 2 - constrain v9 == Field 2 - v11 = load v2 -> &mut Field - v12 = load v2 -> &mut Field - v13 = load v12 -> Field - v14 = eq v13, Field 2 - constrain v13 == Field 2 - return - } - "; - - let ssa = Ssa::from_str(src).unwrap(); - - // The repeated load from v3 should be removed - // b3 should only have three loads now rather than four previously - // - // All stores are expected to remain. - let expected = " - acir(inline) fn main f0 { - b0(): - v1 = allocate -> &mut Field - store Field 0 at v1 - v3 = allocate -> &mut &mut Field - store v1 at v3 - jmp b1(Field 0) - b1(v0: Field): - v4 = eq v0, Field 0 - jmpif v4 then: b3, else: b2 - b3(): - v11 = load v3 -> &mut Field - store Field 2 at v11 - v13 = add v0, Field 1 - jmp b1(v13) - b2(): - v5 = load v1 -> Field - v7 = eq v5, Field 2 - constrain v5 == Field 2 - v8 = load v3 -> &mut Field - v9 = load v8 -> Field - v10 = eq v9, Field 2 - constrain v9 == Field 2 - return - } - "; - - let ssa = ssa.mem2reg(); - assert_normalized_ssa_equals(ssa, expected); - } - - #[test] - fn keep_repeat_loads_passed_to_a_call() { - // The test is the exact same as `remove_repeat_loads` above except with the call - // to `f1` between the repeated loads. 
- let src = " - acir(inline) fn main f0 { - b0(): - v1 = allocate -> &mut Field - store Field 0 at v1 - v3 = allocate -> &mut &mut Field - store v1 at v3 - jmp b1(Field 0) - b1(v0: Field): - v4 = eq v0, Field 0 - jmpif v4 then: b3, else: b2 - b3(): - v13 = load v3 -> &mut Field - store Field 2 at v13 - v15 = add v0, Field 1 - jmp b1(v15) - b2(): - v5 = load v1 -> Field - v7 = eq v5, Field 2 - constrain v5 == Field 2 - v8 = load v3 -> &mut Field - call f1(v3) - v10 = load v3 -> &mut Field - v11 = load v10 -> Field - v12 = eq v11, Field 2 - constrain v11 == Field 2 - return - } - acir(inline) fn foo f1 { - b0(v0: &mut Field): - return - } - "; - - let ssa = Ssa::from_str(src).unwrap(); - - let ssa = ssa.mem2reg(); - // We expect the program to be unchanged - assert_normalized_ssa_equals(ssa, src); - } - - #[test] - fn keep_repeat_loads_with_alias_store() { - // v7, v8, and v9 alias one another. We want to make sure that a repeat load to v7 with a store - // to its aliases in between the repeat loads does not remove those loads. - let src = " - acir(inline) fn main f0 { - b0(v0: u1): - jmpif v0 then: b2, else: b1 - b2(): - v6 = allocate -> &mut Field - store Field 0 at v6 - jmp b3(v6, v6, v6) - b3(v1: &mut Field, v2: &mut Field, v3: &mut Field): - v8 = load v1 -> Field - store Field 2 at v2 - v10 = load v1 -> Field - store Field 1 at v3 - v11 = load v1 -> Field - store Field 3 at v3 - v13 = load v1 -> Field - constrain v8 == Field 0 - constrain v10 == Field 2 - constrain v11 == Field 1 - constrain v13 == Field 3 - return - b1(): - v4 = allocate -> &mut Field - store Field 1 at v4 - jmp b3(v4, v4, v4) - } - "; - - let ssa = Ssa::from_str(src).unwrap(); - - let ssa = ssa.mem2reg(); - // We expect the program to be unchanged - assert_normalized_ssa_equals(ssa, src); - } } diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index ddc3365b551..0c6041029da 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -172,7 +172,6 @@ impl<'a> FunctionContext<'a> { /// Always returns a Value::Mutable wrapping the allocate instruction. pub(super) fn new_mutable_variable(&mut self, value_to_store: ValueId) -> Value { let element_type = self.builder.current_function.dfg.type_of_value(value_to_store); - self.builder.increment_array_reference_count(value_to_store); let alloc = self.builder.insert_allocate(element_type); self.builder.insert_store(alloc, value_to_store); let typ = self.builder.type_of_value(value_to_store); @@ -736,6 +735,7 @@ impl<'a> FunctionContext<'a> { // Reference counting in brillig relies on us incrementing reference // counts when arrays/slices are constructed or indexed. // Thus, if we dereference an lvalue which happens to be array/slice we should increment its reference counter. + self.builder.increment_array_reference_count(reference); self.builder.insert_load(reference, element_type).into() }) } @@ -916,10 +916,7 @@ impl<'a> FunctionContext<'a> { let parameters = self.builder.current_function.dfg.block_parameters(entry).to_vec(); for parameter in parameters { - // Avoid reference counts for immutable arrays that aren't behind references. 
-            if self.builder.current_function.dfg.value_is_reference(parameter) {
-                self.builder.increment_array_reference_count(parameter);
-            }
+            self.builder.increment_array_reference_count(parameter);
         }

         entry
@@ -936,9 +933,7 @@ impl<'a> FunctionContext<'a> {
             dropped_parameters.retain(|parameter| !terminator_args.contains(parameter));

             for parameter in dropped_parameters {
-                if self.builder.current_function.dfg.value_is_reference(parameter) {
-                    self.builder.decrement_array_reference_count(parameter);
-                }
+                self.builder.decrement_array_reference_count(parameter);
             }
         }

diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
index d28236bd360..c50f0a7f45c 100644
--- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
+++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
@@ -665,11 +665,12 @@ impl<'a> FunctionContext<'a> {
         values = values.map(|value| {
             let value = value.eval(self);

+            // Make sure to increment array reference counts on each let binding
+            self.builder.increment_array_reference_count(value);
+
             Tree::Leaf(if let_expr.mutable {
                 self.new_mutable_variable(value)
             } else {
-                // `new_mutable_variable` already increments rcs internally
-                self.builder.increment_array_reference_count(value);
                 value::Value::Normal(value)
             })
         });
diff --git a/tooling/profiler/src/cli/gates_flamegraph_cmd.rs b/tooling/profiler/src/cli/gates_flamegraph_cmd.rs
index c3ae29de058..e68a8cd5bd2 100644
--- a/tooling/profiler/src/cli/gates_flamegraph_cmd.rs
+++ b/tooling/profiler/src/cli/gates_flamegraph_cmd.rs
@@ -31,6 +31,10 @@ pub(crate) struct GatesFlamegraphCommand {
     /// The output folder for the flamegraph svg files
     #[clap(long, short)]
     output: String,
+
+    /// The output name for the flamegraph svg files
+    #[clap(long, short = 'f')]
+    output_filename: Option<String>,
 }

 pub(crate) fn run(args: GatesFlamegraphCommand) -> eyre::Result<()> {
@@ -43,6 +47,7 @@ pub(crate) fn run(args: GatesFlamegraphCommand) -> eyre::Result<()> {
         },
         &InfernoFlamegraphGenerator { count_name: "gates".to_string() },
         &PathBuf::from(args.output),
+        args.output_filename,
     )
 }

@@ -51,6 +56,7 @@ fn run_with_provider(
     gates_provider: &Provider,
     flamegraph_generator: &Generator,
     output_path: &Path,
+    output_filename: Option<String>,
 ) -> eyre::Result<()> {
     let mut program =
         read_program_from_file(artifact_path).context("Error reading program from file")?;
@@ -91,13 +97,18 @@ fn run_with_provider(
         })
         .collect();

+        let output_filename = if let Some(output_filename) = &output_filename {
+            format!("{}::{}::gates.svg", output_filename, func_name)
+        } else {
+            format!("{}::gates.svg", func_name)
+        };
         flamegraph_generator.generate_flamegraph(
             samples,
             &debug_artifact.debug_symbols[func_idx],
             &debug_artifact,
             artifact_path.to_str().unwrap(),
             &func_name,
-            &Path::new(&output_path).join(Path::new(&format!("{}_gates.svg", &func_name))),
+            &Path::new(&output_path).join(Path::new(&output_filename)),
         )?;
     }

@@ -189,11 +200,17 @@ mod tests {
         };
         let flamegraph_generator = TestFlamegraphGenerator::default();

-        super::run_with_provider(&artifact_path, &provider, &flamegraph_generator, temp_dir.path())
-            .expect("should run without errors");
+        super::run_with_provider(
+            &artifact_path,
+            &provider,
+            &flamegraph_generator,
+            temp_dir.path(),
+            Some(String::from("test_filename")),
+        )
+        .expect("should run without errors");

         // Check that the output file was written to
-        let output_file = temp_dir.path().join("main_gates.svg");
+        let output_file = temp_dir.path().join("test_filename::main::gates.svg");
         assert!(output_file.exists());
     }
 }
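A note for readers following the mem2reg hunk above: the removed `arg_set` bookkeeping grouped a jump's destination parameters by the argument they receive, so parameters fed by the same reference argument were recorded as aliases of one another. Below is a minimal sketch of that grouping in plain Rust; `ValueId` is simplified to an integer and `parameter_alias_groups` is a hypothetical free function, whereas the real pass works on the function's dataflow graph and feeds each group to `AliasSet::known_multiple`.

use std::collections::{BTreeSet, HashMap};

type ValueId = u32;

/// Group a jump's destination parameters by the argument they receive:
/// parameters that receive the same reference argument alias one another.
fn parameter_alias_groups(pairs: &[(ValueId, ValueId)]) -> Vec<BTreeSet<ValueId>> {
    let mut arg_set: HashMap<ValueId, BTreeSet<ValueId>> = HashMap::new();
    for (parameter, argument) in pairs {
        arg_set.entry(*argument).or_default().insert(*parameter);
    }
    // A group of one adds nothing beyond the direct parameter -> argument alias.
    arg_set.into_values().filter(|params| params.len() > 1).collect()
}

fn main() {
    // Mirrors `jmp b3(v4, v4, v4)` from the removed
    // `keep_repeat_loads_with_alias_store` test: parameters v1, v2 and v3
    // all receive v4, so all three alias one another.
    let groups = parameter_alias_groups(&[(1, 4), (2, 4), (3, 4)]);
    assert_eq!(groups, vec![BTreeSet::from([1, 2, 3])]);
}

With this bookkeeping gone, only the direct parameter-to-argument alias survives, which is what the retained `aliases.insert(*parameter)` line records.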
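The deleted `remove_repeat_loads` and `keep_repeat_loads_*` tests exercised the last-load cache this patch takes out: within a block, a second load from an address may reuse the first load's result as long as no store to that address intervenes. A toy model of that rule, under hypothetical `MemOp` and `repeat_load_replacements` names, which ignores the aliasing and call handling the real pass must also respect:

use std::collections::HashMap;

type ValueId = u32;

enum MemOp {
    Load { address: ValueId, result: ValueId },
    Store { address: ValueId },
}

/// Map each repeated load's result to the result of the previous load
/// from the same address, invalidating the cache on intervening stores.
fn repeat_load_replacements(ops: &[MemOp]) -> HashMap<ValueId, ValueId> {
    let mut last_load: HashMap<ValueId, ValueId> = HashMap::new();
    let mut replacements = HashMap::new();
    for op in ops {
        match op {
            MemOp::Load { address, result } => {
                if let Some(previous) = last_load.get(address) {
                    // Same address, no store in between: reuse the result.
                    replacements.insert(*result, *previous);
                } else {
                    last_load.insert(*address, *result);
                }
            }
            // A store makes the cached load for that address stale.
            MemOp::Store { address } => {
                last_load.remove(address);
            }
        }
    }
    replacements
}

fn main() {
    use MemOp::*;
    // `v11 = load v2; v12 = load v2` lets v12 reuse v11, while the store
    // clears the cache, so the third load keeps its own result.
    let ops = [
        Load { address: 2, result: 11 },
        Load { address: 2, result: 12 },
        Store { address: 2 },
        Load { address: 2, result: 13 },
    ];
    assert_eq!(repeat_load_replacements(&ops), HashMap::from([(12, 11)]));
}

The `keep_repeat_loads_passed_to_a_call` test covered exactly the case this sketch omits: a call taking the reference between the two loads must keep the second load alive.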
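On the ssa_gen side, the reference-count increment moves out of `new_mutable_variable` and onto every let binding, and the `value_is_reference` guards around the parameter increments and decrements are dropped. As a rough analogy only (the compiler emits explicit inc/dec instructions rather than relying on RAII), `Rc::strong_count` can stand in for the array reference count to show why the pairing stays balanced:

use std::rc::Rc;

// Stand-in for an SSA array value whose reference count the compiler
// manipulates with explicit instructions.
type ArrayValue = Rc<Vec<u64>>;

// One increment per let binding, mutable or not; `new_mutable_variable`
// itself no longer performs an extra increment.
fn bind_let(value: &ArrayValue) -> ArrayValue {
    Rc::clone(value) // models the increment on the binding
}

fn main() {
    let array: ArrayValue = Rc::new(vec![1, 2, 3]);
    let binding = bind_let(&array);
    assert_eq!(Rc::strong_count(&array), 2);
    // Dropping the binding models the now-unconditional
    // decrement_array_reference_count on parameters dropped at exit.
    drop(binding);
    assert_eq!(Rc::strong_count(&array), 1);
}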
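Finally, the profiler command gains an optional `-f`/`--output-filename` prefix, and the per-function SVG name changes from `{function}_gates.svg` to `::`-separated segments. The selection logic reduces to the following, written here as a hypothetical standalone helper mirroring the `if let` in the hunk above:

/// Build the SVG filename for one function's gates flamegraph: an
/// optional prefix is joined to the function name with `::`.
fn flamegraph_filename(output_filename: Option<&str>, func_name: &str) -> String {
    match output_filename {
        Some(prefix) => format!("{prefix}::{func_name}::gates.svg"),
        None => format!("{func_name}::gates.svg"),
    }
}

fn main() {
    // Matches the updated assertion in the test above.
    assert_eq!(
        flamegraph_filename(Some("test_filename"), "main"),
        "test_filename::main::gates.svg"
    );
    assert_eq!(flamegraph_filename(None, "main"), "main::gates.svg");
}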
- let mut arg_set: HashMap> = HashMap::default(); - // Add an alias for each reference parameter for (parameter, argument) in destination_parameters.iter().zip(arguments) { if self.inserter.function.dfg.value_is_reference(*parameter) { @@ -627,27 +581,10 @@ impl<'f> PerFunctionContext<'f> { if let Some(aliases) = references.aliases.get_mut(expression) { // The argument reference is possibly aliased by this block parameter aliases.insert(*parameter); - - // Check if we have seen the same argument - let seen_parameters = arg_set.entry(argument).or_default(); - // Add the current parameter to the parameters we have seen for this argument. - // The previous parameters and the current one alias one another. - seen_parameters.insert(*parameter); } } } } - - // Set the aliases of the parameters - for (_, aliased_params) in arg_set { - for param in aliased_params.iter() { - self.set_aliases( - references, - *param, - AliasSet::known_multiple(aliased_params.clone()), - ); - } - } } TerminatorInstruction::Return { return_values, .. } => { // Removing all `last_stores` for each returned reference is more important here @@ -675,8 +612,6 @@ mod tests { map::Id, types::Type, }, - opt::assert_normalized_ssa_equals, - Ssa, }; #[test] @@ -887,53 +822,88 @@ mod tests { // is later stored in a successor block #[test] fn load_aliases_in_predecessor_block() { - let src = " - acir(inline) fn main f0 { - b0(): - v0 = allocate -> &mut Field - store Field 0 at v0 - v2 = allocate -> &mut &mut Field - store v0 at v2 - v3 = load v2 -> &mut Field - v4 = load v2 -> &mut Field - jmp b1() - b1(): - store Field 1 at v3 - store Field 2 at v4 - v7 = load v3 -> Field - v8 = eq v7, Field 2 - return - } - "; + // fn main { + // b0(): + // v0 = allocate + // store Field 0 at v0 + // v2 = allocate + // store v0 at v2 + // v3 = load v2 + // v4 = load v2 + // jmp b1() + // b1(): + // store Field 1 at v3 + // store Field 2 at v4 + // v7 = load v3 + // v8 = eq v7, Field 2 + // return + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id); + + let v0 = builder.insert_allocate(Type::field()); + + let zero = builder.field_constant(0u128); + builder.insert_store(v0, zero); + + let v2 = builder.insert_allocate(Type::Reference(Arc::new(Type::field()))); + builder.insert_store(v2, v0); + + let v3 = builder.insert_load(v2, Type::field()); + let v4 = builder.insert_load(v2, Type::field()); + let b1 = builder.insert_block(); + builder.terminate_with_jmp(b1, vec![]); + + builder.switch_to_block(b1); + + let one = builder.field_constant(1u128); + builder.insert_store(v3, one); + + let two = builder.field_constant(2u128); + builder.insert_store(v4, two); + + let v8 = builder.insert_load(v3, Type::field()); + let _ = builder.insert_binary(v8, BinaryOp::Eq, two); + + builder.terminate_with_return(vec![]); + + let ssa = builder.finish(); + assert_eq!(ssa.main().reachable_blocks().len(), 2); - let mut ssa = Ssa::from_str(src).unwrap(); - let main = ssa.main_mut(); + // Expected result: + // acir fn main f0 { + // b0(): + // v9 = allocate + // store Field 0 at v9 + // v10 = allocate + // jmp b1() + // b1(): + // return + // } + let ssa = ssa.mem2reg(); + println!("{}", ssa); - let instructions = main.dfg[main.entry_block()].instructions(); - assert_eq!(instructions.len(), 6); // The final return is not counted + let main = ssa.main(); + assert_eq!(main.reachable_blocks().len(), 2); // All loads should be removed + assert_eq!(count_loads(main.entry_block(), &main.dfg), 0); + assert_eq!(count_loads(b1, 
&main.dfg), 0); + // The first store is not removed as it is used as a nested reference in another store. - // We would need to track whether the store where `v0` is the store value gets removed to know whether + // We would need to track whether the store where `v9` is the store value gets removed to know whether // to remove it. + assert_eq!(count_stores(main.entry_block(), &main.dfg), 1); // The first store in b1 is removed since there is another store to the same reference // in the same block, and the store is not needed before the later store. // The rest of the stores are also removed as no loads are done within any blocks // to the stored values. - let expected = " - acir(inline) fn main f0 { - b0(): - v0 = allocate -> &mut Field - store Field 0 at v0 - v2 = allocate -> &mut &mut Field - jmp b1() - b1(): - return - } - "; + assert_eq!(count_stores(b1, &main.dfg), 0); - let ssa = ssa.mem2reg(); - assert_normalized_ssa_equals(ssa, expected); + let b1_instructions = main.dfg[b1].instructions(); + + // We expect the last eq to be optimized out + assert_eq!(b1_instructions.len(), 0); } #[test] @@ -963,7 +933,7 @@ mod tests { // v10 = eq v9, Field 2 // constrain v9 == Field 2 // v11 = load v2 - // v12 = load v11 + // v12 = load v10 // v13 = eq v12, Field 2 // constrain v11 == Field 2 // return @@ -1022,7 +992,7 @@ mod tests { let main = ssa.main(); assert_eq!(main.reachable_blocks().len(), 4); - // The stores from the original SSA should remain + // The store from the original SSA should remain assert_eq!(count_stores(main.entry_block(), &main.dfg), 2); assert_eq!(count_stores(b2, &main.dfg), 1); @@ -1069,160 +1039,4 @@ mod tests { let main = ssa.main(); assert_eq!(count_loads(main.entry_block(), &main.dfg), 1); } - - #[test] - fn remove_repeat_loads() { - // This tests starts with two loads from the same unknown load. - // Specifically you should look for `load v2` in `b3`. - // We should be able to remove the second repeated load. - let src = " - acir(inline) fn main f0 { - b0(): - v0 = allocate -> &mut Field - store Field 0 at v0 - v2 = allocate -> &mut &mut Field - store v0 at v2 - jmp b1(Field 0) - b1(v3: Field): - v4 = eq v3, Field 0 - jmpif v4 then: b2, else: b3 - b2(): - v5 = load v2 -> &mut Field - store Field 2 at v5 - v8 = add v3, Field 1 - jmp b1(v8) - b3(): - v9 = load v0 -> Field - v10 = eq v9, Field 2 - constrain v9 == Field 2 - v11 = load v2 -> &mut Field - v12 = load v2 -> &mut Field - v13 = load v12 -> Field - v14 = eq v13, Field 2 - constrain v13 == Field 2 - return - } - "; - - let ssa = Ssa::from_str(src).unwrap(); - - // The repeated load from v3 should be removed - // b3 should only have three loads now rather than four previously - // - // All stores are expected to remain. - let expected = " - acir(inline) fn main f0 { - b0(): - v1 = allocate -> &mut Field - store Field 0 at v1 - v3 = allocate -> &mut &mut Field - store v1 at v3 - jmp b1(Field 0) - b1(v0: Field): - v4 = eq v0, Field 0 - jmpif v4 then: b3, else: b2 - b3(): - v11 = load v3 -> &mut Field - store Field 2 at v11 - v13 = add v0, Field 1 - jmp b1(v13) - b2(): - v5 = load v1 -> Field - v7 = eq v5, Field 2 - constrain v5 == Field 2 - v8 = load v3 -> &mut Field - v9 = load v8 -> Field - v10 = eq v9, Field 2 - constrain v9 == Field 2 - return - } - "; - - let ssa = ssa.mem2reg(); - assert_normalized_ssa_equals(ssa, expected); - } - - #[test] - fn keep_repeat_loads_passed_to_a_call() { - // The test is the exact same as `remove_repeat_loads` above except with the call - // to `f1` between the repeated loads. 
- let src = " - acir(inline) fn main f0 { - b0(): - v1 = allocate -> &mut Field - store Field 0 at v1 - v3 = allocate -> &mut &mut Field - store v1 at v3 - jmp b1(Field 0) - b1(v0: Field): - v4 = eq v0, Field 0 - jmpif v4 then: b3, else: b2 - b3(): - v13 = load v3 -> &mut Field - store Field 2 at v13 - v15 = add v0, Field 1 - jmp b1(v15) - b2(): - v5 = load v1 -> Field - v7 = eq v5, Field 2 - constrain v5 == Field 2 - v8 = load v3 -> &mut Field - call f1(v3) - v10 = load v3 -> &mut Field - v11 = load v10 -> Field - v12 = eq v11, Field 2 - constrain v11 == Field 2 - return - } - acir(inline) fn foo f1 { - b0(v0: &mut Field): - return - } - "; - - let ssa = Ssa::from_str(src).unwrap(); - - let ssa = ssa.mem2reg(); - // We expect the program to be unchanged - assert_normalized_ssa_equals(ssa, src); - } - - #[test] - fn keep_repeat_loads_with_alias_store() { - // v7, v8, and v9 alias one another. We want to make sure that a repeat load to v7 with a store - // to its aliases in between the repeat loads does not remove those loads. - let src = " - acir(inline) fn main f0 { - b0(v0: u1): - jmpif v0 then: b2, else: b1 - b2(): - v6 = allocate -> &mut Field - store Field 0 at v6 - jmp b3(v6, v6, v6) - b3(v1: &mut Field, v2: &mut Field, v3: &mut Field): - v8 = load v1 -> Field - store Field 2 at v2 - v10 = load v1 -> Field - store Field 1 at v3 - v11 = load v1 -> Field - store Field 3 at v3 - v13 = load v1 -> Field - constrain v8 == Field 0 - constrain v10 == Field 2 - constrain v11 == Field 1 - constrain v13 == Field 3 - return - b1(): - v4 = allocate -> &mut Field - store Field 1 at v4 - jmp b3(v4, v4, v4) - } - "; - - let ssa = Ssa::from_str(src).unwrap(); - - let ssa = ssa.mem2reg(); - // We expect the program to be unchanged - assert_normalized_ssa_equals(ssa, src); - } } diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index ddc3365b551..0c6041029da 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -172,7 +172,6 @@ impl<'a> FunctionContext<'a> { /// Always returns a Value::Mutable wrapping the allocate instruction. pub(super) fn new_mutable_variable(&mut self, value_to_store: ValueId) -> Value { let element_type = self.builder.current_function.dfg.type_of_value(value_to_store); - self.builder.increment_array_reference_count(value_to_store); let alloc = self.builder.insert_allocate(element_type); self.builder.insert_store(alloc, value_to_store); let typ = self.builder.type_of_value(value_to_store); @@ -736,6 +735,7 @@ impl<'a> FunctionContext<'a> { // Reference counting in brillig relies on us incrementing reference // counts when arrays/slices are constructed or indexed. // Thus, if we dereference an lvalue which happens to be array/slice we should increment its reference counter. + self.builder.increment_array_reference_count(reference); self.builder.insert_load(reference, element_type).into() }) } @@ -916,10 +916,7 @@ impl<'a> FunctionContext<'a> { let parameters = self.builder.current_function.dfg.block_parameters(entry).to_vec(); for parameter in parameters { - // Avoid reference counts for immutable arrays that aren't behind references. 
- if self.builder.current_function.dfg.value_is_reference(parameter) {
- self.builder.increment_array_reference_count(parameter);
- }
+ self.builder.increment_array_reference_count(parameter);
}
 
entry
@@ -936,9 +933,7 @@ impl<'a> FunctionContext<'a> {
dropped_parameters.retain(|parameter| !terminator_args.contains(parameter));
 
for parameter in dropped_parameters {
- if self.builder.current_function.dfg.value_is_reference(parameter) {
- self.builder.decrement_array_reference_count(parameter);
- }
+ self.builder.decrement_array_reference_count(parameter);
}
}
 
diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
index d28236bd360..c50f0a7f45c 100644
--- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
+++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
@@ -665,11 +665,12 @@ impl<'a> FunctionContext<'a> {
 
values = values.map(|value| {
let value = value.eval(self);
+ // Make sure to increment array reference counts on each let binding
+ self.builder.increment_array_reference_count(value);
+
Tree::Leaf(if let_expr.mutable {
self.new_mutable_variable(value)
} else {
- // `new_mutable_variable` already increments rcs internally
- self.builder.increment_array_reference_count(value);
value::Value::Normal(value)
})
});
diff --git a/tooling/profiler/src/cli/gates_flamegraph_cmd.rs b/tooling/profiler/src/cli/gates_flamegraph_cmd.rs
index c3ae29de058..e68a8cd5bd2 100644
--- a/tooling/profiler/src/cli/gates_flamegraph_cmd.rs
+++ b/tooling/profiler/src/cli/gates_flamegraph_cmd.rs
@@ -31,6 +31,10 @@ pub(crate) struct GatesFlamegraphCommand {
/// The output folder for the flamegraph svg files
#[clap(long, short)]
output: String,
+
+ /// The output name for the flamegraph svg files
+ #[clap(long, short = 'f')]
+ output_filename: Option<String>,
}
 
pub(crate) fn run(args: GatesFlamegraphCommand) -> eyre::Result<()> {
@@ -43,6 +47,7 @@ pub(crate) fn run(args: GatesFlamegraphCommand) -> eyre::Result<()> {
},
&InfernoFlamegraphGenerator { count_name: "gates".to_string() },
&PathBuf::from(args.output),
+ args.output_filename,
)
}
 
@@ -51,6 +56,7 @@ fn run_with_provider<Provider: GatesProvider, Generator: FlamegraphGenerator>(
gates_provider: &Provider,
flamegraph_generator: &Generator,
output_path: &Path,
+ output_filename: Option<String>,
) -> eyre::Result<()> {
let mut program =
read_program_from_file(artifact_path).context("Error reading program from file")?;
@@ -91,13 +97,18 @@ fn run_with_provider<Provider: GatesProvider, Generator: FlamegraphGenerator>(
})
.collect();
 
+ let output_filename = if let Some(output_filename) = &output_filename {
+ format!("{}::{}::gates.svg", output_filename, func_name)
+ } else {
+ format!("{}::gates.svg", func_name)
+ };
flamegraph_generator.generate_flamegraph(
samples,
&debug_artifact.debug_symbols[func_idx],
&debug_artifact,
artifact_path.to_str().unwrap(),
&func_name,
- &Path::new(&output_path).join(Path::new(&format!("{}_gates.svg", &func_name))),
+ &Path::new(&output_path).join(Path::new(&output_filename)),
)?;
}
 
@@ -189,11 +200,17 @@ mod tests {
};
let flamegraph_generator = TestFlamegraphGenerator::default();
 
- super::run_with_provider(&artifact_path, &provider, &flamegraph_generator, temp_dir.path())
- .expect("should run without errors");
+ super::run_with_provider(
+ &artifact_path,
+ &provider,
+ &flamegraph_generator,
+ temp_dir.path(),
+ Some(String::from("test_filename")),
+ )
+ .expect("should run without errors");
 
// Check that the output file was written to
- let output_file = temp_dir.path().join("main_gates.svg");
+ let output_file = temp_dir.path().join("test_filename::main::gates.svg");
assert!(output_file.exists());
}
}
 
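For readers skimming the profiler change above, a minimal sketch of the naming scheme it introduces; the helper name `flamegraph_file_name` is ours for illustration and is not part of the patch, but the format strings mirror the two `format!` calls added to `run_with_provider`:

// Illustration only: mirrors the `{prefix}::{func}::gates.svg` naming added above.
fn flamegraph_file_name(prefix: Option<&str>, func_name: &str) -> String {
    match prefix {
        // e.g. `--output-filename test_filename` and function `main` give "test_filename::main::gates.svg"
        Some(prefix) => format!("{}::{}::gates.svg", prefix, func_name),
        // without `--output-filename` the file is simply "main::gates.svg"
        None => format!("{}::gates.svg", func_name),
    }
}
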
From 1aba967e48a71d10e639a74db2411b92ccff78f9 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Fri, 29 Nov 2024 14:35:49 +0000 Subject: [PATCH 3/9] Update compiler/noirc_evaluator/src/ssa/opt/inlining.rs --- compiler/noirc_evaluator/src/ssa/opt/inlining.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 6cf7070e65e..f91487fd73e 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -1089,7 +1089,6 @@ mod test { } #[test] - #[ignore] #[should_panic( expected = "Attempted to recur more than 1000 times during inlining function 'main': acir(inline) fn main f0 {" )] From b86ff6ac325e1e932126b246a2e1608305781bc4 Mon Sep 17 00:00:00 2001 From: TomAFrench Date: Fri, 29 Nov 2024 14:47:54 +0000 Subject: [PATCH 4/9] . --- .../noirc_evaluator/src/ssa/ir/instruction.rs | 15 ++- .../src/ssa/opt/constant_folding.rs | 116 ++++++++++++++---- 2 files changed, 108 insertions(+), 23 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index f606fffbf91..6737b335b7d 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -11,7 +11,7 @@ use fxhash::FxHasher64; use iter_extended::vecmap; use noirc_frontend::hir_def::types::Type as HirType; -use crate::ssa::opt::flatten_cfg::value_merger::ValueMerger; +use crate::ssa::{ir::function::RuntimeType, opt::flatten_cfg::value_merger::ValueMerger}; use super::{ basic_block::BasicBlockId, @@ -478,8 +478,19 @@ impl Instruction { | ArraySet { .. } | MakeArray { .. } => true, + // Store instructions must be removed by DIE in acir code, any load + // instructions should already be unused by that point. + // + // Note that this check assumes that it is being performed after the flattening + // pass and after the last mem2reg pass. This is currently the case for the DIE + // pass where this check is done, but does mean that we cannot perform mem2reg + // after the DIE pass. + Store { .. } => { + matches!(function.runtime(), RuntimeType::Acir(_)) + && function.reachable_blocks().len() == 1 + } + Constrain(..) - | Store { .. } | EnableSideEffectsIf { .. } | IncrementRc { .. } | DecrementRc { .. } diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index ceda0c6272f..76c726e679a 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -940,32 +940,21 @@ mod test { // Regression for #4600 #[test] fn array_get_regression() { - // fn main f0 { - // b0(v0: u1, v1: u64): - // enable_side_effects_if v0 - // v2 = make_array [Field 0, Field 1] - // v3 = array_get v2, index v1 - // v4 = not v0 - // enable_side_effects_if v4 - // v5 = array_get v2, index v1 - // } - // // We want to make sure after constant folding both array_gets remain since they are // under different enable_side_effects_if contexts and thus one may be disabled while // the other is not. If one is removed, it is possible e.g. v4 is replaced with v2 which // is disabled (only gets from index 0) and thus returns the wrong result. 
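// As an illustrative trace (our example, not spelled out in the original comment): with
// v0 = u1 0 and v1 = u64 1, v5 executes under a false predicate and so reads index 0
// (Field 0), while v7 executes under v6 = true and must read index 1 (Field 1), so
// folding v7 into v5 would produce the wrong result.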
let src = " - acir(inline) fn main f0 { - b0(v0: u1, v1: u64): - enable_side_effects v0 - v4 = make_array [Field 0, Field 1] : [Field; 2] - v5 = array_get v4, index v1 -> Field - v6 = not v0 - enable_side_effects v6 - v7 = array_get v4, index v1 -> Field - return - } - "; + acir(inline) fn main f0 { + b0(v0: u1, v1: u64): + enable_side_effects v0 + v4 = make_array [Field 0, Field 1] : [Field; 2] + v5 = array_get v4, index v1 -> Field + v6 = not v0 + enable_side_effects v6 + v7 = array_get v4, index v1 -> Field + return + }"; let ssa = Ssa::from_str(src).unwrap(); // Expected output is unchanged @@ -1302,4 +1291,89 @@ mod test { assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); assert_eq!(main.dfg[b1].instructions().len(), 0); } + + #[test] + fn does_not_use_cached_constrain_in_block_that_is_not_dominated() { + let src = " + brillig(inline) fn main f0 { + b0(v0: Field, v1: Field): + v3 = eq v0, Field 0 + jmpif v3 then: b1, else: b2 + b1(): + v5 = eq v1, Field 1 + constrain v1 == Field 1 + jmp b2() + b2(): + v6 = eq v1, Field 0 + constrain v1 == Field 0 + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.fold_constants_using_constraints(); + assert_normalized_ssa_equals(ssa, src); + } + + #[test] + fn does_not_hoist_constrain_to_common_ancestor() { + let src = " + brillig(inline) fn main f0 { + b0(v0: Field, v1: Field): + v3 = eq v0, Field 0 + jmpif v3 then: b1, else: b2 + b1(): + constrain v1 == Field 1 + jmp b2() + b2(): + jmpif v0 then: b3, else: b4 + b3(): + constrain v1 == Field 1 // This was incorrectly hoisted to b0 but this condition is not valid when going b0 -> b2 -> b4 + jmp b4() + b4(): + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.fold_constants_using_constraints(); + assert_normalized_ssa_equals(ssa, src); + } + + #[test] + fn deduplicates_side_effecting_intrinsics() { + let src = " + // After EnableSideEffectsIf removal: + acir(inline) fn main f0 { + b0(v0: Field, v1: Field, v2: u1): + v4 = call is_unconstrained() -> u1 + v7 = call to_be_radix(v0, u32 256) -> [u8; 1] // `a.to_be_radix(256)`; + inc_rc v7 + v8 = call to_be_radix(v0, u32 256) -> [u8; 1] // duplicate load of `a` + inc_rc v8 + v9 = cast v2 as Field // `if c { a.to_be_radix(256) }` + v10 = mul v0, v9 // attaching `c` to `a` + v11 = call to_be_radix(v10, u32 256) -> [u8; 1] // calling `to_radix(c * a)` + inc_rc v11 + enable_side_effects v2 // side effect var for `c` shifted down by removal + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let expected = " + acir(inline) fn main f0 { + b0(v0: Field, v1: Field, v2: u1): + v4 = call is_unconstrained() -> u1 + v7 = call to_be_radix(v0, u32 256) -> [u8; 1] + inc_rc v7 + inc_rc v7 + v8 = cast v2 as Field + v9 = mul v0, v8 + v10 = call to_be_radix(v9, u32 256) -> [u8; 1] + inc_rc v10 + enable_side_effects v2 + return + } + "; + let ssa = ssa.fold_constants_using_constraints(); + assert_normalized_ssa_equals(ssa, expected); + } } From e51b3075a925e032b96152f2ad8be00d4a71c972 Mon Sep 17 00:00:00 2001 From: TomAFrench Date: Fri, 29 Nov 2024 14:50:13 +0000 Subject: [PATCH 5/9] . 
--- .../noirc_evaluator/src/ssa/ssa_gen/context.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index 0c6041029da..e39eed79021 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -172,6 +172,7 @@ impl<'a> FunctionContext<'a> { /// Always returns a Value::Mutable wrapping the allocate instruction. pub(super) fn new_mutable_variable(&mut self, value_to_store: ValueId) -> Value { let element_type = self.builder.current_function.dfg.type_of_value(value_to_store); + self.builder.increment_array_reference_count(value_to_store); let alloc = self.builder.insert_allocate(element_type); self.builder.insert_store(alloc, value_to_store); let typ = self.builder.type_of_value(value_to_store); @@ -732,10 +733,6 @@ impl<'a> FunctionContext<'a> { let element_types = Self::convert_type(element_type); values.map_both(element_types, |value, element_type| { let reference = value.eval_reference(); - // Reference counting in brillig relies on us incrementing reference - // counts when arrays/slices are constructed or indexed. - // Thus, if we dereference an lvalue which happens to be array/slice we should increment its reference counter. - self.builder.increment_array_reference_count(reference); self.builder.insert_load(reference, element_type).into() }) } @@ -916,7 +913,10 @@ impl<'a> FunctionContext<'a> { let parameters = self.builder.current_function.dfg.block_parameters(entry).to_vec(); for parameter in parameters { - self.builder.increment_array_reference_count(parameter); + // Avoid reference counts for immutable arrays that aren't behind references. + if self.builder.current_function.dfg.value_is_reference(parameter) { + self.builder.increment_array_reference_count(parameter); + } } entry @@ -933,7 +933,9 @@ impl<'a> FunctionContext<'a> { dropped_parameters.retain(|parameter| !terminator_args.contains(parameter)); for parameter in dropped_parameters { - self.builder.decrement_array_reference_count(parameter); + if self.builder.current_function.dfg.value_is_reference(parameter) { + self.builder.decrement_array_reference_count(parameter); + } } } From 78f2ca46703732e036ff7cd4f79cd633462e4a02 Mon Sep 17 00:00:00 2001 From: TomAFrench Date: Fri, 29 Nov 2024 14:50:52 +0000 Subject: [PATCH 6/9] . --- .../noirc_evaluator/src/ssa/opt/mem2reg.rs | 334 ++++++++++++++---- 1 file changed, 260 insertions(+), 74 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index 0690dbbf204..f9fb5712892 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -18,6 +18,7 @@ //! - A reference with 0 aliases means we were unable to find which reference this reference //! refers to. If such a reference is stored to, we must conservatively invalidate every //! reference in the current block. +//! - We also track the last load instruction to each address per block. //! //! From there, to figure out the value of each reference at the end of block, iterate each instruction: //! - On `Instruction::Allocate`: @@ -28,6 +29,13 @@ //! - Furthermore, if the result of the load is a reference, mark the result as an alias //! of the reference it dereferences to (if known). //! - If which reference it dereferences to is not known, this load result has no aliases. +//! 
- We also track the last instance of a load instruction to each address in a block. +//! If we see that the last load instruction was from the same address as the current load instruction, +//! we move to replace the result of the current load with the result of the previous load. +//! This removal requires a couple conditions: +//! - No store occurs to that address before the next load, +//! - The address is not used as an argument to a call +//! This optimization helps us remove repeated loads for which there are not known values. //! - On `Instruction::Store { address, value }`: //! - If the address of the store is known: //! - If the address has exactly 1 alias: @@ -40,11 +48,13 @@ //! - Conservatively mark every alias in the block to `Unknown`. //! - Additionally, if there were no Loads to any alias of the address between this Store and //! the previous Store to the same address, the previous store can be removed. +//! - Remove the instance of the last load instruction to the address and its aliases //! - On `Instruction::Call { arguments }`: //! - If any argument of the call is a reference, set the value of each alias of that //! reference to `Unknown` //! - Any builtin functions that may return aliases if their input also contains a //! reference should be tracked. Examples: `slice_push_back`, `slice_insert`, `slice_remove`, etc. +//! - Remove the instance of the last load instruction for any reference arguments and their aliases //! //! On a terminator instruction: //! - If the terminator is a `Jmp`: @@ -274,6 +284,9 @@ impl<'f> PerFunctionContext<'f> { if let Some(first_predecessor) = predecessors.next() { let mut first = self.blocks.get(&first_predecessor).cloned().unwrap_or_default(); first.last_stores.clear(); + // Last loads are tracked per block. During unification we are creating a new block from the current one, + // so we must clear the last loads of the current block before we return the new block. + first.last_loads.clear(); // Note that we have to start folding with the first block as the accumulator. // If we started with an empty block, an empty block union'd with any other block @@ -410,6 +423,28 @@ impl<'f> PerFunctionContext<'f> { self.last_loads.insert(address, (instruction, block_id)); } + + // Check whether the block has a repeat load from the same address (w/ no calls or stores in between the loads). + // If we do have a repeat load, we can remove the current load and map its result to the previous load's result. + if let Some(last_load) = references.last_loads.get(&address) { + let Instruction::Load { address: previous_address } = + &self.inserter.function.dfg[*last_load] + else { + panic!("Expected a Load instruction here"); + }; + let result = self.inserter.function.dfg.instruction_results(instruction)[0]; + let previous_result = + self.inserter.function.dfg.instruction_results(*last_load)[0]; + if *previous_address == address { + self.inserter.map_value(result, previous_result); + self.instructions_to_remove.insert(instruction); + } + } + // We want to set the load for every load even if the address has a known value + // and the previous load instruction was removed. + // We are safe to still remove a repeat load in this case as we are mapping from the current load's + // result to the previous load, which if it was removed should already have a mapping to the known value. 
+ references.set_last_load(address, instruction); } Instruction::Store { address, value } => { let address = self.inserter.function.dfg.resolve(*address); @@ -435,6 +470,8 @@ impl<'f> PerFunctionContext<'f> { } references.set_known_value(address, value); + // If we see a store to an address, the last load to that address needs to remain. + references.keep_last_load_for(address, self.inserter.function); references.last_stores.insert(address, instruction); } Instruction::Allocate => { @@ -542,6 +579,9 @@ impl<'f> PerFunctionContext<'f> { let value = self.inserter.function.dfg.resolve(*value); references.set_unknown(value); references.mark_value_used(value, self.inserter.function); + + // If a reference is an argument to a call, the last load to that address and its aliases needs to remain. + references.keep_last_load_for(value, self.inserter.function); } } } @@ -572,6 +612,12 @@ impl<'f> PerFunctionContext<'f> { let destination_parameters = self.inserter.function.dfg[*destination].parameters(); assert_eq!(destination_parameters.len(), arguments.len()); + // If we have multiple parameters that alias that same argument value, + // then those parameters also alias each other. + // We save parameters with repeat arguments to later mark those + // parameters as aliasing one another. + let mut arg_set: HashMap> = HashMap::default(); + // Add an alias for each reference parameter for (parameter, argument) in destination_parameters.iter().zip(arguments) { if self.inserter.function.dfg.value_is_reference(*parameter) { @@ -581,10 +627,27 @@ impl<'f> PerFunctionContext<'f> { if let Some(aliases) = references.aliases.get_mut(expression) { // The argument reference is possibly aliased by this block parameter aliases.insert(*parameter); + + // Check if we have seen the same argument + let seen_parameters = arg_set.entry(argument).or_default(); + // Add the current parameter to the parameters we have seen for this argument. + // The previous parameters and the current one alias one another. + seen_parameters.insert(*parameter); } } } } + + // Set the aliases of the parameters + for (_, aliased_params) in arg_set { + for param in aliased_params.iter() { + self.set_aliases( + references, + *param, + AliasSet::known_multiple(aliased_params.clone()), + ); + } + } } TerminatorInstruction::Return { return_values, .. 
} => { // Removing all `last_stores` for each returned reference is more important here @@ -612,6 +675,8 @@ mod tests { map::Id, types::Type, }, + opt::assert_normalized_ssa_equals, + Ssa, }; #[test] @@ -822,88 +887,53 @@ mod tests { // is later stored in a successor block #[test] fn load_aliases_in_predecessor_block() { - // fn main { - // b0(): - // v0 = allocate - // store Field 0 at v0 - // v2 = allocate - // store v0 at v2 - // v3 = load v2 - // v4 = load v2 - // jmp b1() - // b1(): - // store Field 1 at v3 - // store Field 2 at v4 - // v7 = load v3 - // v8 = eq v7, Field 2 - // return - // } - let main_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("main".into(), main_id); - - let v0 = builder.insert_allocate(Type::field()); - - let zero = builder.field_constant(0u128); - builder.insert_store(v0, zero); - - let v2 = builder.insert_allocate(Type::Reference(Arc::new(Type::field()))); - builder.insert_store(v2, v0); - - let v3 = builder.insert_load(v2, Type::field()); - let v4 = builder.insert_load(v2, Type::field()); - let b1 = builder.insert_block(); - builder.terminate_with_jmp(b1, vec![]); - - builder.switch_to_block(b1); - - let one = builder.field_constant(1u128); - builder.insert_store(v3, one); - - let two = builder.field_constant(2u128); - builder.insert_store(v4, two); - - let v8 = builder.insert_load(v3, Type::field()); - let _ = builder.insert_binary(v8, BinaryOp::Eq, two); - - builder.terminate_with_return(vec![]); - - let ssa = builder.finish(); - assert_eq!(ssa.main().reachable_blocks().len(), 2); + let src = " + acir(inline) fn main f0 { + b0(): + v0 = allocate -> &mut Field + store Field 0 at v0 + v2 = allocate -> &mut &mut Field + store v0 at v2 + v3 = load v2 -> &mut Field + v4 = load v2 -> &mut Field + jmp b1() + b1(): + store Field 1 at v3 + store Field 2 at v4 + v7 = load v3 -> Field + v8 = eq v7, Field 2 + return + } + "; - // Expected result: - // acir fn main f0 { - // b0(): - // v9 = allocate - // store Field 0 at v9 - // v10 = allocate - // jmp b1() - // b1(): - // return - // } - let ssa = ssa.mem2reg(); - println!("{}", ssa); + let mut ssa = Ssa::from_str(src).unwrap(); + let main = ssa.main_mut(); - let main = ssa.main(); - assert_eq!(main.reachable_blocks().len(), 2); + let instructions = main.dfg[main.entry_block()].instructions(); + assert_eq!(instructions.len(), 6); // The final return is not counted // All loads should be removed - assert_eq!(count_loads(main.entry_block(), &main.dfg), 0); - assert_eq!(count_loads(b1, &main.dfg), 0); - // The first store is not removed as it is used as a nested reference in another store. - // We would need to track whether the store where `v9` is the store value gets removed to know whether + // We would need to track whether the store where `v0` is the store value gets removed to know whether // to remove it. - assert_eq!(count_stores(main.entry_block(), &main.dfg), 1); // The first store in b1 is removed since there is another store to the same reference // in the same block, and the store is not needed before the later store. // The rest of the stores are also removed as no loads are done within any blocks // to the stored values. 
- assert_eq!(count_stores(b1, &main.dfg), 0); - - let b1_instructions = main.dfg[b1].instructions(); + let expected = " + acir(inline) fn main f0 { + b0(): + v0 = allocate -> &mut Field + store Field 0 at v0 + v2 = allocate -> &mut &mut Field + jmp b1() + b1(): + return + } + "; - // We expect the last eq to be optimized out - assert_eq!(b1_instructions.len(), 0); + let ssa = ssa.mem2reg(); + assert_normalized_ssa_equals(ssa, expected); } #[test] @@ -933,7 +963,7 @@ mod tests { // v10 = eq v9, Field 2 // constrain v9 == Field 2 // v11 = load v2 - // v12 = load v10 + // v12 = load v11 // v13 = eq v12, Field 2 // constrain v11 == Field 2 // return @@ -992,7 +1022,7 @@ mod tests { let main = ssa.main(); assert_eq!(main.reachable_blocks().len(), 4); - // The store from the original SSA should remain + // The stores from the original SSA should remain assert_eq!(count_stores(main.entry_block(), &main.dfg), 2); assert_eq!(count_stores(b2, &main.dfg), 1); @@ -1039,4 +1069,160 @@ mod tests { let main = ssa.main(); assert_eq!(count_loads(main.entry_block(), &main.dfg), 1); } -} + + #[test] + fn remove_repeat_loads() { + // This tests starts with two loads from the same unknown load. + // Specifically you should look for `load v2` in `b3`. + // We should be able to remove the second repeated load. + let src = " + acir(inline) fn main f0 { + b0(): + v0 = allocate -> &mut Field + store Field 0 at v0 + v2 = allocate -> &mut &mut Field + store v0 at v2 + jmp b1(Field 0) + b1(v3: Field): + v4 = eq v3, Field 0 + jmpif v4 then: b2, else: b3 + b2(): + v5 = load v2 -> &mut Field + store Field 2 at v5 + v8 = add v3, Field 1 + jmp b1(v8) + b3(): + v9 = load v0 -> Field + v10 = eq v9, Field 2 + constrain v9 == Field 2 + v11 = load v2 -> &mut Field + v12 = load v2 -> &mut Field + v13 = load v12 -> Field + v14 = eq v13, Field 2 + constrain v13 == Field 2 + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + + // The repeated load from v3 should be removed + // b3 should only have three loads now rather than four previously + // + // All stores are expected to remain. + let expected = " + acir(inline) fn main f0 { + b0(): + v1 = allocate -> &mut Field + store Field 0 at v1 + v3 = allocate -> &mut &mut Field + store v1 at v3 + jmp b1(Field 0) + b1(v0: Field): + v4 = eq v0, Field 0 + jmpif v4 then: b3, else: b2 + b3(): + v11 = load v3 -> &mut Field + store Field 2 at v11 + v13 = add v0, Field 1 + jmp b1(v13) + b2(): + v5 = load v1 -> Field + v7 = eq v5, Field 2 + constrain v5 == Field 2 + v8 = load v3 -> &mut Field + v9 = load v8 -> Field + v10 = eq v9, Field 2 + constrain v9 == Field 2 + return + } + "; + + let ssa = ssa.mem2reg(); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn keep_repeat_loads_passed_to_a_call() { + // The test is the exact same as `remove_repeat_loads` above except with the call + // to `f1` between the repeated loads. 
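+ // The call acts as a barrier here: mem2reg must assume `f1` may write through the
+ // reference argument, so the load after the call cannot reuse the earlier load's result.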
+ let src = " + acir(inline) fn main f0 { + b0(): + v1 = allocate -> &mut Field + store Field 0 at v1 + v3 = allocate -> &mut &mut Field + store v1 at v3 + jmp b1(Field 0) + b1(v0: Field): + v4 = eq v0, Field 0 + jmpif v4 then: b3, else: b2 + b3(): + v13 = load v3 -> &mut Field + store Field 2 at v13 + v15 = add v0, Field 1 + jmp b1(v15) + b2(): + v5 = load v1 -> Field + v7 = eq v5, Field 2 + constrain v5 == Field 2 + v8 = load v3 -> &mut Field + call f1(v3) + v10 = load v3 -> &mut Field + v11 = load v10 -> Field + v12 = eq v11, Field 2 + constrain v11 == Field 2 + return + } + acir(inline) fn foo f1 { + b0(v0: &mut Field): + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + + let ssa = ssa.mem2reg(); + // We expect the program to be unchanged + assert_normalized_ssa_equals(ssa, src); + } + + #[test] + fn keep_repeat_loads_with_alias_store() { + // v7, v8, and v9 alias one another. We want to make sure that a repeat load to v7 with a store + // to its aliases in between the repeat loads does not remove those loads. + let src = " + acir(inline) fn main f0 { + b0(v0: u1): + jmpif v0 then: b2, else: b1 + b2(): + v6 = allocate -> &mut Field + store Field 0 at v6 + jmp b3(v6, v6, v6) + b3(v1: &mut Field, v2: &mut Field, v3: &mut Field): + v8 = load v1 -> Field + store Field 2 at v2 + v10 = load v1 -> Field + store Field 1 at v3 + v11 = load v1 -> Field + store Field 3 at v3 + v13 = load v1 -> Field + constrain v8 == Field 0 + constrain v10 == Field 2 + constrain v11 == Field 1 + constrain v13 == Field 3 + return + b1(): + v4 = allocate -> &mut Field + store Field 1 at v4 + jmp b3(v4, v4, v4) + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + + let ssa = ssa.mem2reg(); + // We expect the program to be unchanged + assert_normalized_ssa_equals(ssa, src); + } +} \ No newline at end of file From 4c9dd1fddb69bdb9ae51150e7917da7ba4e99b27 Mon Sep 17 00:00:00 2001 From: TomAFrench Date: Fri, 29 Nov 2024 14:52:20 +0000 Subject: [PATCH 7/9] . --- .../src/ssa/opt/flatten_cfg.rs | 558 ++++++++---------- .../noirc_evaluator/src/ssa/ssa_gen/mod.rs | 5 +- 2 files changed, 254 insertions(+), 309 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 5d114672a55..08eb0905914 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -131,8 +131,7 @@ //! v11 = mul v4, Field 12 //! v12 = add v10, v11 //! store v12 at v5 (new store) -use fxhash::FxHashMap as HashMap; -use std::collections::{BTreeMap, HashSet}; +use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use acvm::{acir::AcirField, acir::BlackBoxFunc, FieldElement}; use iter_extended::vecmap; @@ -186,18 +185,6 @@ struct Context<'f> { /// Maps start of branch -> end of branch branch_ends: HashMap, - /// Maps an address to the old and new value of the element at that address - /// These only hold stores for one block at a time and is cleared - /// between inlining of branches. - store_values: HashMap, - - /// Stores all allocations local to the current branch. - /// Since these branches are local to the current branch (ie. only defined within one branch of - /// an if expression), they should not be merged with their previous value or stored value in - /// the other branch since there is no such value. The ValueId here is that which is returned - /// by the allocate instruction. 
- local_allocations: HashSet<ValueId>,
-
/// A stack of each jmpif condition that was taken to reach a particular point in the program.
/// When two branches are merged back into one, this constitutes a join point, and is analogous
/// to the rest of the program after an if statement. When such a join point / end block is
@@ -214,13 +201,15 @@ struct Context<'f> {
/// When processing a block, we pop this stack to get its arguments
/// and at the end we push the arguments for its successor
arguments_stack: Vec<Vec<ValueId>>,
-}
-
-#[derive(Clone)]
-pub(crate) struct Store {
- old_value: ValueId,
- new_value: ValueId,
- call_stack: CallStack,
+
+ /// Stores all allocations local to the current branch.
+ ///
+ /// Since these branches are local to the current branch (i.e. only defined within one branch of
+ /// an if expression), they should not be merged with their previous value or stored value in
+ /// the other branch since there is no such value.
+ ///
+ /// The `ValueId` here is that which is returned by the allocate instruction.
+ local_allocations: HashSet<ValueId>,
}
 
#[derive(Clone)]
@@ -231,8 +220,6 @@ struct ConditionalBranch {
old_condition: ValueId,
// The condition of the branch
condition: ValueId,
- // The store values accumulated when processing the branch
- store_values: HashMap<ValueId, Store>,
// The allocations accumulated when processing the branch
local_allocations: HashSet<ValueId>,
}
@@ -263,12 +250,11 @@ fn flatten_function_cfg(function: &mut Function, no_predicates: &HashMap Context<'f> {
let old_condition = *condition;
let then_condition = self.inserter.resolve(old_condition);
- let old_stores = std::mem::take(&mut self.store_values);
let old_allocations = std::mem::take(&mut self.local_allocations);
let branch = ConditionalBranch {
old_condition,
condition: self.link_condition(then_condition),
- store_values: old_stores,
- local_allocations: old_allocations,
last_block: *then_destination,
+ local_allocations: old_allocations,
};
let cond_context = ConditionalContext {
condition: then_condition,
@@ -473,19 +457,12 @@ impl<'f> Context<'f> {
);
let else_condition = self.link_condition(else_condition);
 
- // Make sure the else branch sees the previous values of each store
- // rather than any values created in the 'then' branch.
- let old_stores = std::mem::take(&mut cond_context.then_branch.store_values); - cond_context.then_branch.store_values = std::mem::take(&mut self.store_values); - self.undo_stores_in_then_branch(&cond_context.then_branch.store_values); - let old_allocations = std::mem::take(&mut self.local_allocations); let else_branch = ConditionalBranch { old_condition: cond_context.then_branch.old_condition, condition: else_condition, - store_values: old_stores, - local_allocations: old_allocations, last_block: *block, + local_allocations: old_allocations, }; cond_context.then_branch.local_allocations.clear(); cond_context.else_branch = Some(else_branch); @@ -509,10 +486,8 @@ impl<'f> Context<'f> { } let mut else_branch = cond_context.else_branch.unwrap(); - let stores_in_branch = std::mem::replace(&mut self.store_values, else_branch.store_values); self.local_allocations = std::mem::take(&mut else_branch.local_allocations); else_branch.last_block = *block; - else_branch.store_values = stores_in_branch; cond_context.else_branch = Some(else_branch); // We must remember to reset whether side effects are enabled when both branches @@ -580,8 +555,6 @@ impl<'f> Context<'f> { .first() }); - let call_stack = cond_context.call_stack; - self.merge_stores(cond_context.then_branch, cond_context.else_branch, call_stack); self.arguments_stack.pop(); self.arguments_stack.pop(); self.arguments_stack.push(args); @@ -636,116 +609,29 @@ impl<'f> Context<'f> { self.insert_instruction_with_typevars(enable_side_effects, None, call_stack); } - /// Merge any store instructions found in each branch. - /// - /// This function relies on the 'then' branch being merged before the 'else' branch of a jmpif - /// instruction. If this ordering is changed, the ordering that store values are merged within - /// this function also needs to be changed to reflect that. 
- fn merge_stores(
- &mut self,
- then_branch: ConditionalBranch,
- else_branch: Option<ConditionalBranch>,
- call_stack: CallStack,
- ) {
- // Address -> (then_value, else_value, value_before_the_if)
- let mut new_map = BTreeMap::new();
-
- for (address, store) in then_branch.store_values {
- new_map.insert(address, (store.new_value, store.old_value, store.old_value));
- }
-
- if else_branch.is_some() {
- for (address, store) in else_branch.clone().unwrap().store_values {
- if let Some(entry) = new_map.get_mut(&address) {
- entry.1 = store.new_value;
- } else {
- new_map.insert(address, (store.old_value, store.new_value, store.old_value));
- }
- }
- }
-
- let then_condition = then_branch.condition;
- let block = self.inserter.function.entry_block();
-
- // Merging must occur in a separate loop as we cannot borrow `self` as mutable while `value_merger` does
- let mut new_values = HashMap::default();
- for (address, (then_case, else_case, _)) in &new_map {
- let instruction = Instruction::IfElse {
- then_condition,
- then_value: *then_case,
- else_value: *else_case,
- };
- let dfg = &mut self.inserter.function.dfg;
- let value = dfg
- .insert_instruction_and_results(instruction, block, None, call_stack.clone())
- .first();
-
- new_values.insert(address, value);
- }
-
- // Replace stores with new merged values
- for (address, (_, _, old_value)) in &new_map {
- let value = new_values[address];
- let address = *address;
- self.insert_instruction_with_typevars(
- Instruction::Store { address, value },
- None,
- call_stack.clone(),
- );
-
- if let Some(store) = self.store_values.get_mut(&address) {
- store.new_value = value;
- } else {
- self.store_values.insert(
- address,
- Store {
- old_value: *old_value,
- new_value: value,
- call_stack: call_stack.clone(),
- },
- );
- }
- }
- }
-
- fn remember_store(&mut self, address: ValueId, new_value: ValueId, call_stack: CallStack) {
- if !self.local_allocations.contains(&address) {
- if let Some(store_value) = self.store_values.get_mut(&address) {
- store_value.new_value = new_value;
- } else {
- let load = Instruction::Load { address };
-
- let load_type = Some(vec![self.inserter.function.dfg.type_of_value(new_value)]);
- let old_value = self
- .insert_instruction_with_typevars(load.clone(), load_type, call_stack.clone())
- .first();
-
- self.store_values.insert(address, Store { old_value, new_value, call_stack });
- }
- }
- }
-
/// Push the given instruction to the end of the entry block of the current function.
///
/// Note that each ValueId of the instruction will be mapped via self.inserter.resolve.
/// As a result, the instruction that will be pushed will actually be a new instruction
/// with a different InstructionId from the original. The results of the given instruction
/// will also be mapped to the results of the new instruction.
- fn push_instruction(&mut self, id: InstructionId) -> Vec<ValueId> {
+ ///
+ /// `previous_allocate_result` should only be set to the result of an allocate instruction
+ /// if that instruction was the instruction immediately previous to this one - if there are
+ /// any instructions in between it should be None.
+ fn push_instruction(&mut self, id: InstructionId) { let (instruction, call_stack) = self.inserter.map_instruction(id); let instruction = self.handle_instruction_side_effects(instruction, call_stack.clone()); - let is_allocate = matches!(instruction, Instruction::Allocate); + let instruction_is_allocate = matches!(&instruction, Instruction::Allocate); let entry = self.inserter.function.entry_block(); let results = self.inserter.push_instruction_value(instruction, id, entry, call_stack); // Remember an allocate was created local to this branch so that we do not try to merge store // values across branches for it later. - if is_allocate { + if instruction_is_allocate { self.local_allocations.insert(results.first()); } - - results.results().into_owned() } /// If we are currently in a branch, we need to modify constrain instructions @@ -782,8 +668,32 @@ impl<'f> Context<'f> { Instruction::Constrain(lhs, rhs, message) } Instruction::Store { address, value } => { - self.remember_store(address, value, call_stack); - Instruction::Store { address, value } + // If this instruction immediately follows an allocate, and stores to that + // address there is no previous value to load and we don't need a merge anyway. + if self.local_allocations.contains(&address) { + Instruction::Store { address, value } + } else { + // Instead of storing `value`, store `if condition { value } else { previous_value }` + let typ = self.inserter.function.dfg.type_of_value(value); + let load = Instruction::Load { address }; + let previous_value = self + .insert_instruction_with_typevars( + load, + Some(vec![typ]), + call_stack.clone(), + ) + .first(); + + let instruction = Instruction::IfElse { + then_condition: condition, + then_value: value, + + else_value: previous_value, + }; + + let updated_value = self.insert_instruction(instruction, call_stack); + Instruction::Store { address, value: updated_value } + } } Instruction::RangeCheck { value, max_bit_size, assert_message } => { // Replace value with `value * predicate` to zero out value when predicate is inactive. @@ -905,23 +815,11 @@ impl<'f> Context<'f> { call_stack, ) } - - fn undo_stores_in_then_branch(&mut self, store_values: &HashMap) { - for (address, store) in store_values { - let address = *address; - let value = store.old_value; - let instruction = Instruction::Store { address, value }; - // Considering the location of undoing a store to be the same as the original store. 
- self.insert_instruction_with_typevars(instruction, None, store.call_stack.clone()); - } - } } #[cfg(test)] mod test { - use std::sync::Arc; - - use acvm::{acir::AcirField, FieldElement}; + use acvm::acir::AcirField; use crate::ssa::{ function_builder::FunctionBuilder, @@ -1023,15 +921,13 @@ mod test { b0(v0: u1, v1: &mut Field): enable_side_effects v0 v2 = load v1 -> Field - store Field 5 at v1 - v4 = not v0 - store v2 at v1 + v3 = cast v0 as Field + v5 = sub Field 5, v2 + v6 = mul v3, v5 + v7 = add v2, v6 + store v7 at v1 + v8 = not v0 enable_side_effects u1 1 - v6 = cast v0 as Field - v7 = sub Field 5, v2 - v8 = mul v6, v7 - v9 = add v2, v8 - store v9 at v1 return } "; @@ -1062,17 +958,20 @@ mod test { b0(v0: u1, v1: &mut Field): enable_side_effects v0 v2 = load v1 -> Field - store Field 5 at v1 - v4 = not v0 - store v2 at v1 - enable_side_effects v4 - v5 = load v1 -> Field - store Field 6 at v1 + v3 = cast v0 as Field + v5 = sub Field 5, v2 + v6 = mul v3, v5 + v7 = add v2, v6 + store v7 at v1 + v8 = not v0 + enable_side_effects v8 + v9 = load v1 -> Field + v10 = cast v8 as Field + v12 = sub Field 6, v9 + v13 = mul v10, v12 + v14 = add v9, v13 + store v14 at v1 enable_side_effects u1 1 - v8 = cast v0 as Field - v10 = mul v8, Field -1 - v11 = add Field 6, v10 - store v11 at v1 return } "; @@ -1115,84 +1014,123 @@ mod test { // b7 b8 // ↘ ↙ // b9 - let src = " - acir(inline) fn main f0 { - b0(v0: u1, v1: u1): - v2 = allocate -> &mut Field - store Field 0 at v2 - v4 = load v2 -> Field - // call v1(Field 0, v4) - jmp b1() - b1(): - store Field 1 at v2 - v6 = load v2 -> Field - // call v1(Field 1, v6) - jmpif v0 then: b2, else: b3 - b2(): - store Field 2 at v2 - v8 = load v2 -> Field - // call v1(Field 2, v8) - jmp b4() - b4(): - v12 = load v2 -> Field - // call v1(Field 4, v12) - jmpif v1 then: b5, else: b6 - b5(): - store Field 5 at v2 - v14 = load v2 -> Field - // call v1(Field 5, v14) - jmp b7() - b7(): - v18 = load v2 -> Field - // call v1(Field 7, v18) - jmp b9() - b9(): - v22 = load v2 -> Field - // call v1(Field 9, v22) - v23 = load v2 -> Field - return v23 - b6(): - store Field 6 at v2 - v16 = load v2 -> Field - // call v1(Field 6, v16) - jmp b7() - b3(): - store Field 3 at v2 - v10 = load v2 -> Field - // call v1(Field 3, v10) - jmp b8() - b8(): - v20 = load v2 -> Field - // call v1(Field 8, v20) - jmp b9() - } - "; + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id); - let ssa = Ssa::from_str(src).unwrap(); - let ssa = ssa.flatten_cfg().mem2reg(); + let b1 = builder.insert_block(); + let b2 = builder.insert_block(); + let b3 = builder.insert_block(); + let b4 = builder.insert_block(); + let b5 = builder.insert_block(); + let b6 = builder.insert_block(); + let b7 = builder.insert_block(); + let b8 = builder.insert_block(); + let b9 = builder.insert_block(); + + let c1 = builder.add_parameter(Type::bool()); + let c4 = builder.add_parameter(Type::bool()); + + let r1 = builder.insert_allocate(Type::field()); + + let store_value = |builder: &mut FunctionBuilder, value: u128| { + let value = builder.field_constant(value); + builder.insert_store(r1, value); + }; + + let test_function = Id::test_new(1); + + let call_test_function = |builder: &mut FunctionBuilder, block: u128| { + let block = builder.field_constant(block); + let load = builder.insert_load(r1, Type::field()); + builder.insert_call(test_function, vec![block, load], Vec::new()); + }; + + let switch_store_and_test_function = + |builder: &mut FunctionBuilder, block, block_number: 
u128| { + builder.switch_to_block(block); + store_value(builder, block_number); + call_test_function(builder, block_number); + }; + + let switch_and_test_function = + |builder: &mut FunctionBuilder, block, block_number: u128| { + builder.switch_to_block(block); + call_test_function(builder, block_number); + }; + + store_value(&mut builder, 0); + call_test_function(&mut builder, 0); + builder.terminate_with_jmp(b1, vec![]); + + switch_store_and_test_function(&mut builder, b1, 1); + builder.terminate_with_jmpif(c1, b2, b3); + + switch_store_and_test_function(&mut builder, b2, 2); + builder.terminate_with_jmp(b4, vec![]); + + switch_store_and_test_function(&mut builder, b3, 3); + builder.terminate_with_jmp(b8, vec![]); + + switch_and_test_function(&mut builder, b4, 4); + builder.terminate_with_jmpif(c4, b5, b6); + + switch_store_and_test_function(&mut builder, b5, 5); + builder.terminate_with_jmp(b7, vec![]); + + switch_store_and_test_function(&mut builder, b6, 6); + builder.terminate_with_jmp(b7, vec![]); + + switch_and_test_function(&mut builder, b7, 7); + builder.terminate_with_jmp(b9, vec![]); + + switch_and_test_function(&mut builder, b8, 8); + builder.terminate_with_jmp(b9, vec![]); + + switch_and_test_function(&mut builder, b9, 9); + let load = builder.insert_load(r1, Type::field()); + builder.terminate_with_return(vec![load]); + + let ssa = builder.finish().flatten_cfg().mem2reg(); // Expected results after mem2reg removes the allocation and each load and store: - let expected = " - acir(inline) fn main f0 { - b0(v0: u1, v1: u1): - v2 = allocate -> &mut Field - enable_side_effects v0 - v3 = mul v0, v1 - enable_side_effects v3 - v4 = not v1 - v5 = mul v0, v4 - enable_side_effects v0 - v6 = cast v3 as Field - v8 = mul v6, Field -1 - v10 = add Field 6, v8 - v11 = not v0 - enable_side_effects u1 1 - v13 = cast v0 as Field - v15 = sub v10, Field 3 - v16 = mul v13, v15 - v17 = add Field 3, v16 - return v17 - }"; + // + // fn main f0 { + // b0(v0: u1, v1: u1): + // call test_function(Field 0, Field 0) + // call test_function(Field 1, Field 1) + // enable_side_effects v0 + // call test_function(Field 2, Field 2) + // call test_function(Field 4, Field 2) + // v29 = and v0, v1 + // enable_side_effects v29 + // call test_function(Field 5, Field 5) + // v32 = not v1 + // v33 = and v0, v32 + // enable_side_effects v33 + // call test_function(Field 6, Field 6) + // enable_side_effects v0 + // v36 = mul v1, Field 5 + // v37 = mul v32, Field 2 + // v38 = add v36, v37 + // v39 = mul v1, Field 5 + // v40 = mul v32, Field 6 + // v41 = add v39, v40 + // call test_function(Field 7, v42) + // v43 = not v0 + // enable_side_effects v43 + // store Field 3 at v2 + // call test_function(Field 3, Field 3) + // call test_function(Field 8, Field 3) + // enable_side_effects Field 1 + // v47 = mul v0, v41 + // v48 = mul v43, Field 1 + // v49 = add v47, v48 + // v50 = mul v0, v44 + // v51 = mul v43, Field 3 + // v52 = add v50, v51 + // call test_function(Field 9, v53) + // return v54 + // } let main = ssa.main(); let ret = match main.dfg[main.entry_block()].terminator() { @@ -1201,12 +1139,7 @@ mod test { }; let merged_values = get_all_constants_reachable_from_instruction(&main.dfg, ret); - assert_eq!( - merged_values, - vec![FieldElement::from(3u128), FieldElement::from(6u128), -FieldElement::from(1u128)] - ); - - assert_normalized_ssa_equals(ssa, expected); + assert_eq!(merged_values, vec![1, 3, 5, 6]); } #[test] @@ -1287,7 +1220,7 @@ mod test { fn get_all_constants_reachable_from_instruction( dfg: &DataFlowGraph, 
         value: ValueId,
-    ) -> Vec<FieldElement> {
+    ) -> Vec<u128> {
         match dfg[value] {
             Value::Instruction { instruction, .. } => {
                 let mut values = vec![];
@@ -1305,7 +1238,7 @@ mod test {
                 values.dedup();
                 values
             }
-            Value::NumericConstant { constant, .. } => vec![constant],
+            Value::NumericConstant { constant, .. } => vec![constant.to_u128()],
             _ => Vec::new(),
         }
     }
@@ -1344,63 +1277,74 @@ mod test {
     fn should_not_merge_incorrectly_to_false() {
         // Regression test for #1792
         // Tests that it does not simplify a true constraint an always-false constraint
-        //  acir(inline) fn main f1 {
-        //    b0(v0: [u8; 2]):
-        //      v5 = array_get v0, index u8 0
-        //      v6 = cast v5 as u32
-        //      v8 = truncate v6 to 1 bits, max_bit_size: 32
-        //      v9 = cast v8 as u1
-        //      v10 = allocate
-        //      store u8 0 at v10
-        //      jmpif v9 then: b2, else: b3
-        //    b2():
-        //      v12 = cast v5 as Field
-        //      v13 = add v12, Field 1
-        //      store v13 at v10
-        //      jmp b4()
-        //    b4():
-        //      constrain v9 == u1 1
-        //      return
-        //    b3():
-        //      store u8 0 at v10
-        //      jmp b4()
-        //  }
-        let main_id = Id::test_new(1);
-        let mut builder = FunctionBuilder::new("main".into(), main_id);
-        builder.insert_block(); // b0
-        let b1 = builder.insert_block();
-        let b2 = builder.insert_block();
-        let b3 = builder.insert_block();
-        let element_type = Arc::new(vec![Type::unsigned(8)]);
-        let array_type = Type::Array(element_type.clone(), 2);
-        let array = builder.add_parameter(array_type);
-        let zero = builder.numeric_constant(0_u128, Type::unsigned(8));
-        let v5 = builder.insert_array_get(array, zero, Type::unsigned(8));
-        let v6 = builder.insert_cast(v5, Type::unsigned(32));
-        let i_two = builder.numeric_constant(2_u128, Type::unsigned(32));
-        let v8 = builder.insert_binary(v6, BinaryOp::Mod, i_two);
-        let v9 = builder.insert_cast(v8, Type::bool());
-        let v10 = builder.insert_allocate(Type::field());
-        builder.insert_store(v10, zero);
-        builder.terminate_with_jmpif(v9, b1, b2);
-        builder.switch_to_block(b1);
-        let one = builder.field_constant(1_u128);
-        let v5b = builder.insert_cast(v5, Type::field());
-        let v13: Id<Value> = builder.insert_binary(v5b, BinaryOp::Add, one);
-        let v14 = builder.insert_cast(v13, Type::unsigned(8));
-        builder.insert_store(v10, v14);
-        builder.terminate_with_jmp(b3, vec![]);
-        builder.switch_to_block(b2);
-        builder.insert_store(v10, zero);
-        builder.terminate_with_jmp(b3, vec![]);
-        builder.switch_to_block(b3);
-        let v_true = builder.numeric_constant(true, Type::bool());
-        let v12 = builder.insert_binary(v9, BinaryOp::Eq, v_true);
-        builder.insert_constrain(v12, v_true, None);
-        builder.terminate_with_return(vec![]);
-        let ssa = builder.finish();
+
+        let src = "
+        acir(inline) fn main f0 {
+          b0(v0: [u8; 2]):
+            v2 = array_get v0, index u8 0 -> u8
+            v3 = cast v2 as u32
+            v4 = truncate v3 to 1 bits, max_bit_size: 32
+            v5 = cast v4 as u1
+            v6 = allocate -> &mut Field
+            store u8 0 at v6
+            jmpif v5 then: b2, else: b1
+          b2():
+            v7 = cast v2 as Field
+            v9 = add v7, Field 1
+            v10 = cast v9 as u8
+            store v10 at v6
+            jmp b3()
+          b3():
+            constrain v5 == u1 1
+            return
+          b1():
+            store u8 0 at v6
+            jmp b3()
+        }
+        ";
+
+        let ssa = Ssa::from_str(src).unwrap();
+
+        let expected = "
+        acir(inline) fn main f0 {
+          b0(v0: [u8; 2]):
+            v2 = array_get v0, index u8 0 -> u8
+            v3 = cast v2 as u32
+            v4 = truncate v3 to 1 bits, max_bit_size: 32
+            v5 = cast v4 as u1
+            v6 = allocate -> &mut Field
+            store u8 0 at v6
+            enable_side_effects v5
+            v7 = cast v2 as Field
+            v9 = add v7, Field 1
+            v10 = cast v9 as u8
+            v11 = load v6 -> u8
+            v12 = cast v4 as Field
+            v13 = cast v11 as Field
+            v14 = sub v9, v13
+            v15 = mul v12, v14
+            v16 = add v13, v15
+            v17 = cast v16 as u8
+            store v17 at v6
+            v18 = not v5
+            enable_side_effects v18
+            v19 = load v6 -> u8
+            v20 = cast v18 as Field
+            v21 = cast v19 as Field
+            v23 = sub Field 0, v21
+            v24 = mul v20, v23
+            v25 = add v21, v24
+            v26 = cast v25 as u8
+            store v26 at v6
+            enable_side_effects u1 1
+            constrain v5 == u1 1
+            return
+        }
+        ";
+
         let flattened_ssa = ssa.flatten_cfg();
         let main = flattened_ssa.main();
+
         // Now assert that there is not an always-false constraint after flattening:
         let mut constrain_count = 0;
         for instruction in main.dfg[main.entry_block()].instructions() {
@@ -1414,6 +1358,8 @@ mod test {
             }
         }
         assert_eq!(constrain_count, 1);
+
+        assert_normalized_ssa_equals(flattened_ssa, expected);
     }
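Aside on the expected SSA above: the `cast`/`sub`/`mul`/`add` run (v12 through v16) is the arithmetic encoding of `if v5 { v9 } else { v11 }`, that is, `merged = prev + condition * (value - prev)`. A minimal standalone sketch of that identity, with a toy modulus and `u128` standing in for the real field type (illustration only, not compiler code):

    // Toy prime modulus; the compiler works over the backend's prime field.
    const P: u128 = 101;

    fn add(a: u128, b: u128) -> u128 { (a + b) % P }
    fn sub(a: u128, b: u128) -> u128 { (a + P - b % P) % P }
    fn mul(a: u128, b: u128) -> u128 { (a * b) % P }

    // merged = prev + condition * (value - prev), with condition in {0, 1}.
    fn merge_store(condition: u128, value: u128, prev: u128) -> u128 {
        add(prev, mul(condition, sub(value, prev)))
    }

    fn main() {
        assert_eq!(merge_store(1, 5, 7), 5); // predicate on: the new value wins
        assert_eq!(merge_store(0, 5, 7), 7); // predicate off: the previous value survives
    }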
 
     #[test]
@@ -1549,10 +1495,10 @@ mod test {
         b2():
             return
         b1():
-            jmp b2()
+            jmp b2()
         }
         ";
         let merged_ssa = Ssa::from_str(src).unwrap();
         let _ = merged_ssa.flatten_cfg();
     }
-}
+}
\ No newline at end of file
diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
index c50f0a7f45c..d28236bd360 100644
--- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
+++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
@@ -665,12 +665,11 @@ impl<'a> FunctionContext<'a> {
 
         values = values.map(|value| {
             let value = value.eval(self);
 
-            // Make sure to increment array reference counts on each let binding
-            self.builder.increment_array_reference_count(value);
-
             Tree::Leaf(if let_expr.mutable {
                 self.new_mutable_variable(value)
             } else {
+                // `new_mutable_variable` already increments rcs internally
+                self.builder.increment_array_reference_count(value);
                 value::Value::Normal(value)
             })
         });

From c0011b6a361fd0d3858ab5eec10967bdf76db2ac Mon Sep 17 00:00:00 2001
From: TomAFrench
Date: Fri, 29 Nov 2024 14:53:27 +0000
Subject: [PATCH 8/9] .

---
 .../src/ssa/opt/constant_folding.rs           | 347 ++++++++++++------
 1 file changed, 233 insertions(+), 114 deletions(-)

diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
index 76c726e679a..fe981b73737 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
@@ -149,7 +149,8 @@ impl Function {
         use_constraint_info: bool,
         brillig_info: Option<BrilligInfo>,
     ) {
-        let mut context = Context::new(self, use_constraint_info, brillig_info);
+        let mut context = Context::new(use_constraint_info, brillig_info);
+        let mut dom = DominatorTree::with_function(self);
         context.block_queue.push_back(self.entry_block());
         while let Some(block) = context.block_queue.pop_front() {
@@ -158,7 +159,7 @@
             }
             context.visited_blocks.insert(block);
-            context.fold_constants_in_block(self, block);
+            context.fold_constants_in_block(&mut self.dfg, &mut dom, block);
         }
     }
 }
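The rewritten driver above threads a fresh `DominatorTree` through a plain breadth-first worklist over blocks. As a rough standalone model of that traversal (plain integers standing in for `BasicBlockId`; the visited check elided between the two hunks is assumed to behave as shown):

    use std::collections::{HashSet, VecDeque};

    fn process_blocks(entry: u32, successors: impl Fn(u32) -> Vec<u32>, mut fold: impl FnMut(u32)) {
        let mut visited = HashSet::new();
        let mut queue = VecDeque::new();
        queue.push_back(entry);
        while let Some(block) = queue.pop_front() {
            // A block reachable along several paths is only folded once.
            if !visited.insert(block) {
                continue;
            }
            fold(block);
            // Successors are queued after the block's own instructions are folded.
            queue.extend(successors(block));
        }
    }

    fn main() {
        // Diamond CFG: b0 -> {b1, b2} -> b3.
        let succ = |b: u32| match b {
            0 => vec![1, 2],
            1 | 2 => vec![3],
            _ => vec![],
        };
        let mut order = Vec::new();
        process_blocks(0, succ, |b| order.push(b));
        assert_eq!(order, vec![0, 1, 2, 3]);
    }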
@@ -172,22 +173,15 @@ struct Context<'a> {
 
     /// Contains sets of values which are constrained to be equivalent to each other.
     ///
-    /// The mapping's structure is `side_effects_enabled_var => (constrained_value => [(block, simplified_value)])`.
+    /// The mapping's structure is `side_effects_enabled_var => (constrained_value => simplified_value)`.
     ///
     /// We partition the maps of constrained values according to the side-effects flag at the point
     /// at which the values are constrained. This prevents constraints which are only sometimes enforced
     /// being used to modify the rest of the program.
-    ///
-    /// We also keep track of how a value was simplified to other values per block. That is,
-    /// a same ValueId could have been simplified to one value in one block and to another value
-    /// in another block.
-    constraint_simplification_mappings:
-        HashMap<ValueId, HashMap<ValueId, Vec<(BasicBlockId, ValueId)>>>,
+    constraint_simplification_mappings: ConstraintSimplificationCache,
 
     // Cache of instructions without any side-effects along with their outputs.
     cached_instruction_results: InstructionResultCache,
-
-    dom: DominatorTree,
 }
 
 #[derive(Copy, Clone)]
@@ -196,9 +190,56 @@ pub(crate) struct BrilligInfo<'a> {
     brillig_functions: &'a BTreeMap<FunctionId, Function>,
 }
 
+/// Records the simplified equivalents of an [`Instruction`] in the blocks
+/// where the constraint that advised the simplification has been encountered.
+///
+/// For more information see [`ConstraintSimplificationCache`].
+#[derive(Default)]
+struct SimplificationCache {
+    /// Simplified expressions where we found them.
+    ///
+    /// It will always have at least one value because `add` is called
+    /// after the default is constructed.
+    simplifications: HashMap<BasicBlockId, ValueId>,
+}
+
+impl SimplificationCache {
+    /// Called with a newly encountered simplification.
+    fn add(&mut self, dfg: &DataFlowGraph, simple: ValueId, block: BasicBlockId) {
+        self.simplifications
+            .entry(block)
+            .and_modify(|existing| {
+                // `SimplificationCache` may already hold a simplification in this block,
+                // so we check whether `simple` is a better simplification than the current one.
+                if let Some((_, simpler)) = simplify(dfg, *existing, simple) {
+                    *existing = simpler;
+                };
+            })
+            .or_insert(simple);
+    }
+
+    /// Try to find a simplification in a visible block.
+    fn get(&self, block: BasicBlockId, dom: &DominatorTree) -> Option<ValueId> {
+        // Deterministically walk up the dominator chain until we encounter a block that contains a simplification.
+        dom.find_map_dominator(block, |b| self.simplifications.get(&b).cloned())
+    }
+}
+
+/// HashMap from `(side_effects_enabled_var, constrained_value)` to a simplified expression that it can
+/// be replaced with based on constraints that testify to their equivalence, stored together
+/// with the set of blocks at which this constraint has been observed.
+///
+/// Only blocks dominated by one in the cache should have access to this information, otherwise
+/// we create a sort of time paradox where we replace an instruction with a constant we believe
+/// it _should_ equal, without ever actually producing and asserting the value.
+type ConstraintSimplificationCache = HashMap<ValueId, HashMap<ValueId, SimplificationCache>>;
+
 /// HashMap from `(Instruction, side_effects_enabled_var)` to the results of the instruction.
 /// Stored as a two-level map to avoid cloning Instructions during the `.get` call.
 ///
+/// The `side_effects_enabled_var` is optional because we only use it when `Instruction::requires_acir_gen_predicate`
+/// is true _and_ the constraint information is also taken into account.
+///
 /// In addition to each result, the original BasicBlockId is stored as well. This allows us
 /// to deduplicate instructions across blocks as long as the new block dominates the original.
 type InstructionResultCache = HashMap<Instruction, HashMap<Option<ValueId>, ResultCache>>;
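The cache lookup above only ever walks up the dominator tree, so a recorded simplification is visible exactly in the blocks that the recording block dominates. A self-contained model of that visibility rule, where a toy immediate-dominator map stands in for the real `DominatorTree` and its `find_map_dominator` (assumptions, not the compiler's API):

    use std::collections::HashMap;

    type Block = u32;

    struct ToyDomTree {
        // Immediate dominator of each block; the entry block maps to None.
        idom: HashMap<Block, Option<Block>>,
    }

    impl ToyDomTree {
        // Visit `block`, idom(block), idom(idom(block)), ... returning the first hit.
        fn find_map_dominator<T>(&self, mut block: Block, f: impl Fn(Block) -> Option<T>) -> Option<T> {
            loop {
                if let Some(found) = f(block) {
                    return Some(found);
                }
                match self.idom.get(&block)? {
                    Some(parent) => block = *parent,
                    None => return None, // reached the entry without a hit
                }
            }
        }
    }

    fn main() {
        // b0 dominates b1, which dominates b2.
        let dom = ToyDomTree { idom: HashMap::from([(0, None), (1, Some(0)), (2, Some(1))]) };
        // A simplification recorded in b1 ...
        let simplifications: HashMap<Block, u32> = HashMap::from([(1, 42)]);
        let get = |b: Block| dom.find_map_dominator(b, |x| simplifications.get(&x).copied());
        // ... is visible from b2, but not from the entry block b0.
        assert_eq!(get(2), Some(42));
        assert_eq!(get(0), None);
    }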
 /// Cached result of a single instruction along with its result cache.
 /// For more information see [`InstructionResultCache`].
 #[derive(Default)]
 struct ResultCache {
-    results: Vec<(BasicBlockId, Vec<ValueId>)>,
+    result: Option<(BasicBlockId, Vec<ValueId>)>,
 }
 
 impl<'brillig> Context<'brillig> {
-    fn new(
-        function: &Function,
-        use_constraint_info: bool,
-        brillig_info: Option<BrilligInfo<'brillig>>,
-    ) -> Self {
+    fn new(use_constraint_info: bool, brillig_info: Option<BrilligInfo<'brillig>>) -> Self {
         Self {
             use_constraint_info,
             brillig_info,
@@ -224,52 +261,65 @@ impl<'brillig> Context<'brillig> {
             block_queue: Default::default(),
             constraint_simplification_mappings: Default::default(),
             cached_instruction_results: Default::default(),
-            dom: DominatorTree::with_function(function),
         }
     }
 
-    fn fold_constants_in_block(&mut self, function: &mut Function, block: BasicBlockId) {
-        let instructions = function.dfg[block].take_instructions();
+    fn fold_constants_in_block(
+        &mut self,
+        dfg: &mut DataFlowGraph,
+        dom: &mut DominatorTree,
+        block: BasicBlockId,
+    ) {
+        let instructions = dfg[block].take_instructions();
 
-        let mut side_effects_enabled_var =
-            function.dfg.make_constant(FieldElement::one(), Type::bool());
+        // Default side effect condition variable with an enabled state.
+        let mut side_effects_enabled_var = dfg.make_constant(FieldElement::one(), Type::bool());
 
         for instruction_id in instructions {
             self.fold_constants_into_instruction(
-                &mut function.dfg,
+                dfg,
+                dom,
                 block,
                 instruction_id,
                 &mut side_effects_enabled_var,
             );
         }
-        self.block_queue.extend(function.dfg[block].successors());
+        self.block_queue.extend(dfg[block].successors());
     }
 
     fn fold_constants_into_instruction(
         &mut self,
         dfg: &mut DataFlowGraph,
-        block: BasicBlockId,
+        dom: &mut DominatorTree,
+        mut block: BasicBlockId,
         id: InstructionId,
         side_effects_enabled_var: &mut ValueId,
     ) {
-        let constraint_simplification_mapping =
-            self.constraint_simplification_mappings.get(side_effects_enabled_var);
-        let instruction = Self::resolve_instruction(
-            id,
-            block,
-            dfg,
-            &mut self.dom,
-            constraint_simplification_mapping,
-        );
+        let constraint_simplification_mapping = self.get_constraint_map(*side_effects_enabled_var);
+
+        let instruction =
+            Self::resolve_instruction(id, block, dfg, dom, constraint_simplification_mapping);
+
         let old_results = dfg.instruction_results(id).to_vec();
 
         // If a copy of this instruction exists earlier in the block, then reuse the previous results.
-        if let Some(cached_results) =
-            self.get_cached(dfg, &instruction, *side_effects_enabled_var, block)
+        if let Some(cache_result) =
+            self.get_cached(dfg, dom, &instruction, *side_effects_enabled_var, block)
         {
-            Self::replace_result_ids(dfg, &old_results, cached_results);
-            return;
-        }
+            match cache_result {
+                CacheResult::Cached(cached) => {
+                    Self::replace_result_ids(dfg, &old_results, cached);
+                    return;
+                }
+                CacheResult::NeedToHoistToCommonBlock(dominator) => {
+                    // Just change the block to insert in the common dominator instead.
+                    // This will only move the current instance of the instruction right now.
+                    // When constant folding is run a second time later on, it'll catch
+                    // that the previous instance can be deduplicated to this instance.
+                    block = dominator;
+                }
            }
        };
 
        let new_results = // First try to inline a call to a brillig function with all constant arguments.
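The `match` above distinguishes reuse from hoisting: a cached instance whose block dominates the current block can be reused outright, while a pure duplicate in a sibling branch is moved up to a block that dominates both copies. A minimal sketch of that three-way decision (toy `u32` block ids; the closures stand in for `DominatorTree::dominates` and `common_dominator`):

    #[derive(Debug, PartialEq)]
    enum Decision {
        ReuseCached,
        HoistToCommonBlock(u32),
    }

    fn decide(
        origin: u32,
        current: u32,
        has_side_effects: bool,
        dominates: impl Fn(u32, u32) -> bool,
        common_dominator: impl Fn(u32, u32) -> u32,
    ) -> Option<Decision> {
        if dominates(origin, current) {
            // The cached copy always runs first: its results are safe to reuse.
            Some(Decision::ReuseCached)
        } else if !has_side_effects {
            // Pure instruction in a non-dominated block: hoist it so a later pass can deduplicate.
            Some(Decision::HoistToCommonBlock(common_dominator(origin, current)))
        } else {
            // Side effects cannot be moved or reused across branches.
            None
        }
    }

    fn main() {
        // Diamond CFG: b0 dominates the sibling branches b1 and b2.
        let dominates = |a: u32, b: u32| a == b || a == 0;
        let common = |_: u32, _: u32| 0;
        assert_eq!(decide(1, 2, false, &dominates, &common), Some(Decision::HoistToCommonBlock(0)));
        assert_eq!(decide(1, 2, true, &dominates, &common), None);
    }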
@@ -314,7 +364,7 @@ impl<'brillig> Context<'brillig> {
         block: BasicBlockId,
         dfg: &DataFlowGraph,
         dom: &mut DominatorTree,
-        constraint_simplification_mapping: Option<&HashMap<ValueId, Vec<(BasicBlockId, ValueId)>>>,
+        constraint_simplification_mapping: &HashMap<ValueId, SimplificationCache>,
     ) -> Instruction {
         let instruction = dfg[instruction_id].clone();
 
         // This allows us to reach a stable final `ValueId` for each instruction input as we add more
         // constraints to the cache.
         fn resolve_cache(
+            block: BasicBlockId,
             dfg: &DataFlowGraph,
             dom: &mut DominatorTree,
-            cache: Option<&HashMap<ValueId, Vec<(BasicBlockId, ValueId)>>>,
+            cache: &HashMap<ValueId, SimplificationCache>,
             value_id: ValueId,
-            block: BasicBlockId,
         ) -> ValueId {
             let resolved_id = dfg.resolve(value_id);
-            let Some(cached_values) = cache.and_then(|cache| cache.get(&resolved_id)) else {
-                return resolved_id;
-            };
-
-            for (cached_block, cached_value) in cached_values {
-                // We can only use the simplified value if it was simplified in a block that dominates the current one
-                if dom.dominates(*cached_block, block) {
-                    return resolve_cache(dfg, dom, cache, *cached_value, block);
+            match cache.get(&resolved_id) {
+                Some(simplification_cache) => {
+                    if let Some(simplified) = simplification_cache.get(block, dom) {
+                        resolve_cache(block, dfg, dom, cache, simplified)
+                    } else {
+                        resolved_id
+                    }
                 }
+                None => resolved_id,
             }
-
-            resolved_id
         }
 
         // Resolve any inputs to ensure that we're comparing like-for-like instructions.
         instruction.map_values(|value_id| {
-            resolve_cache(dfg, dom, constraint_simplification_mapping, value_id, block)
+            resolve_cache(block, dfg, dom, constraint_simplification_mapping, value_id)
         })
     }
 
@@ -398,7 +446,7 @@ impl<'brillig> Context<'brillig> {
             self.get_constraint_map(side_effects_enabled_var)
                 .entry(complex)
                 .or_default()
-                .push((block, simple));
+                .add(dfg, simple, block);
         }
     }
 }
@@ -420,10 +468,12 @@
     }
 
+    /// Get the simplification mapping from complex to simpler values,
+    /// which all depend on the same side effect condition variable.
     fn get_constraint_map(
         &mut self,
         side_effects_enabled_var: ValueId,
-    ) -> &mut HashMap<ValueId, Vec<(BasicBlockId, ValueId)>> {
+    ) -> &mut HashMap<ValueId, SimplificationCache> {
         self.constraint_simplification_mappings.entry(side_effects_enabled_var).or_default()
     }
 
@@ -438,19 +488,20 @@
         }
     }
 
-    fn get_cached<'a>(
-        &'a mut self,
+    /// Get a cached result if it can be used in this context.
+    fn get_cached(
+        &self,
         dfg: &DataFlowGraph,
+        dom: &mut DominatorTree,
         instruction: &Instruction,
         side_effects_enabled_var: ValueId,
         block: BasicBlockId,
-    ) -> Option<&'a [ValueId]> {
+    ) -> Option<CacheResult> {
         let results_for_instruction = self.cached_instruction_results.get(instruction)?;
-
         let predicate = self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg);
         let predicate = predicate.then_some(side_effects_enabled_var);
 
-        results_for_instruction.get(&predicate)?.get(block, &mut self.dom)
+        results_for_instruction.get(&predicate)?.get(block, dom, instruction.has_side_effects(dfg))
     }
 
     /// Checks if the given instruction is a call to a brillig function with all constant arguments.
@@ -617,7 +668,9 @@ impl<'brillig> Context<'brillig> {
 
 impl ResultCache {
     /// Records that an `Instruction` in block `block` produced the result values `results`.
     fn cache(&mut self, block: BasicBlockId, results: Vec<ValueId>) {
-        self.results.push((block, results));
+        if self.result.is_none() {
+            self.result = Some((block, results));
+        }
     }
 
     /// Returns a set of [`ValueId`]s produced from a copy of this [`Instruction`] which sits
@@ -626,13 +679,23 @@ impl ResultCache {
     /// We require that the cached instruction's block dominates `block` in order to avoid
     /// cycles causing issues (e.g. two instructions being replaced with the results of each other
     /// such that neither instruction exists anymore.)
-    fn get(&self, block: BasicBlockId, dom: &mut DominatorTree) -> Option<&[ValueId]> {
-        for (origin_block, results) in &self.results {
+    fn get(
+        &self,
+        block: BasicBlockId,
+        dom: &mut DominatorTree,
+        has_side_effects: bool,
+    ) -> Option<CacheResult> {
+        self.result.as_ref().and_then(|(origin_block, results)| {
             if dom.dominates(*origin_block, block) {
-                return Some(results);
+                Some(CacheResult::Cached(results))
+            } else if !has_side_effects {
+                // Insert a copy of this instruction in the common dominator
+                let dominator = dom.common_dominator(*origin_block, block);
+                Some(CacheResult::NeedToHoistToCommonBlock(dominator))
+            } else {
+                None
             }
-        }
-        None
+        })
     }
 }
 
@@ -946,7 +1009,7 @@ mod test {
         // is disabled (only gets from index 0) and thus returns the wrong result.
         let src = "
         acir(inline) fn main f0 {
-        b0(v0: u1, v1: u64):
+          b0(v0: u1, v1: u64):
             enable_side_effects v0
             v4 = make_array [Field 0, Field 1] : [Field; 2]
             v5 = array_get v4, index v1 -> Field
             v6 = not v0
             enable_side_effects v6
             v7 = array_get v4, index v1 -> Field
             return
-        }";
+        }
+        ";
         let ssa = Ssa::from_str(src).unwrap();
 
         // Expected output is unchanged
@@ -1021,7 +1085,6 @@ mod test {
        //     v5 = call keccakf1600(v1)
        //     v6 = call keccakf1600(v2)
        // }
-       //
        // Here we're checking a situation where two identical arrays are being initialized twice and being assigned separate `ValueId`s.
        // This would result in otherwise identical instructions not being deduplicated.
let main_id = Id::test_new(0); @@ -1072,6 +1135,106 @@ mod test { assert_eq!(ending_instruction_count, 2); } + #[test] + fn deduplicate_across_blocks() { + // fn main f0 { + // b0(v0: u1): + // v1 = not v0 + // jmp b1() + // b1(): + // v2 = not v0 + // return v2 + // } + let main_id = Id::test_new(0); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id); + let b1 = builder.insert_block(); + + let v0 = builder.add_parameter(Type::bool()); + let _v1 = builder.insert_not(v0); + builder.terminate_with_jmp(b1, Vec::new()); + + builder.switch_to_block(b1); + let v2 = builder.insert_not(v0); + builder.terminate_with_return(vec![v2]); + + let ssa = builder.finish(); + let main = ssa.main(); + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); + assert_eq!(main.dfg[b1].instructions().len(), 1); + + // Expected output: + // + // fn main f0 { + // b0(v0: u1): + // v1 = not v0 + // jmp b1() + // b1(): + // return v1 + // } + let ssa = ssa.fold_constants_using_constraints(); + let main = ssa.main(); + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); + assert_eq!(main.dfg[b1].instructions().len(), 0); + } + + #[test] + fn deduplicate_across_non_dominated_blocks() { + let src = " + brillig(inline) fn main f0 { + b0(v0: u32): + v2 = lt u32 1000, v0 + jmpif v2 then: b1, else: b2 + b1(): + v4 = add v0, u32 1 + v5 = lt v0, v4 + constrain v5 == u1 1 + jmp b2() + b2(): + v7 = lt u32 1000, v0 + jmpif v7 then: b3, else: b4 + b3(): + v8 = add v0, u32 1 + v9 = lt v0, v8 + constrain v9 == u1 1 + jmp b4() + b4(): + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + + // v4 has been hoisted, although: + // - v5 has not yet been removed since it was encountered earlier in the program + // - v8 hasn't been recognized as a duplicate of v6 yet since they still reference v4 and + // v5 respectively + let expected = " + brillig(inline) fn main f0 { + b0(v0: u32): + v2 = lt u32 1000, v0 + v4 = add v0, u32 1 + jmpif v2 then: b1, else: b2 + b1(): + v5 = add v0, u32 1 + v6 = lt v0, v5 + constrain v6 == u1 1 + jmp b2() + b2(): + jmpif v2 then: b3, else: b4 + b3(): + v8 = lt v0, v4 + constrain v8 == u1 1 + jmp b4() + b4(): + return + } + "; + + let ssa = ssa.fold_constants_using_constraints(); + assert_normalized_ssa_equals(ssa, expected); + } + #[test] fn inlines_brillig_call_without_arguments() { let src = " @@ -1248,50 +1411,6 @@ mod test { assert_normalized_ssa_equals(ssa, expected); } - #[test] - fn deduplicate_across_blocks() { - // fn main f0 { - // b0(v0: u1): - // v1 = not v0 - // jmp b1() - // b1(): - // v2 = not v0 - // return v2 - // } - let main_id = Id::test_new(0); - - // Compiling main - let mut builder = FunctionBuilder::new("main".into(), main_id); - let b1 = builder.insert_block(); - - let v0 = builder.add_parameter(Type::bool()); - let _v1 = builder.insert_not(v0); - builder.terminate_with_jmp(b1, Vec::new()); - - builder.switch_to_block(b1); - let v2 = builder.insert_not(v0); - builder.terminate_with_return(vec![v2]); - - let ssa = builder.finish(); - let main = ssa.main(); - assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); - assert_eq!(main.dfg[b1].instructions().len(), 1); - - // Expected output: - // - // fn main f0 { - // b0(v0: u1): - // v1 = not v0 - // jmp b1() - // b1(): - // return v1 - // } - let ssa = ssa.fold_constants_using_constraints(); - let main = ssa.main(); - assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); - assert_eq!(main.dfg[b1].instructions().len(), 0); - } - #[test] fn 
does_not_use_cached_constrain_in_block_that_is_not_dominated() { let src = " @@ -1376,4 +1495,4 @@ mod test { let ssa = ssa.fold_constants_using_constraints(); assert_normalized_ssa_equals(ssa, expected); } -} +} \ No newline at end of file From 25a35905cc69c1f373b40ea135c127b5d21fb400 Mon Sep 17 00:00:00 2001 From: TomAFrench Date: Fri, 29 Nov 2024 14:55:14 +0000 Subject: [PATCH 9/9] . --- compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs | 2 +- compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs | 2 +- compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index fe981b73737..41c84c935b1 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -1495,4 +1495,4 @@ mod test { let ssa = ssa.fold_constants_using_constraints(); assert_normalized_ssa_equals(ssa, expected); } -} \ No newline at end of file +} diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 08eb0905914..c8dd0e3c5a3 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -1501,4 +1501,4 @@ mod test { let merged_ssa = Ssa::from_str(src).unwrap(); let _ = merged_ssa.flatten_cfg(); } -} \ No newline at end of file +} diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index f9fb5712892..53a31ae57c1 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -1225,4 +1225,4 @@ mod tests { // We expect the program to be unchanged assert_normalized_ssa_equals(ssa, src); } -} \ No newline at end of file +}
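One more aspect of patch 8 worth a worked illustration: `SimplificationCache::add` keeps only the better of two simplifications that land in the same block, using `simplify` to rank them. A standalone sketch of that policy (toy value type; `simpler` below is a stand-in for the pass's `simplify`, assumed here to prefer constants):

    use std::collections::HashMap;

    #[derive(Clone, Copy, Debug, PartialEq)]
    enum Val {
        Constant(u32),
        Variable(&'static str),
    }

    // Stand-in for `simplify`: a constant is considered simpler than a variable.
    fn simpler(a: Val, b: Val) -> Option<Val> {
        match (a, b) {
            (Val::Constant(_), Val::Variable(_)) => Some(a),
            (Val::Variable(_), Val::Constant(_)) => Some(b),
            _ => None,
        }
    }

    fn add(simplifications: &mut HashMap<u32, Val>, block: u32, simple: Val) {
        simplifications
            .entry(block)
            .and_modify(|existing| {
                // Upgrade the existing entry only if the new value is simpler.
                if let Some(best) = simpler(*existing, simple) {
                    *existing = best;
                }
            })
            .or_insert(simple);
    }

    fn main() {
        let mut cache = HashMap::new();
        add(&mut cache, 0, Val::Variable("v7"));
        add(&mut cache, 0, Val::Constant(1)); // upgrades the block-0 entry
        assert_eq!(cache[&0], Val::Constant(1));
    }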