diff --git a/compiler/noirc_evaluator/src/ssa/opt/array_set.rs b/compiler/noirc_evaluator/src/ssa/opt/array_set.rs index 491a17adb66..6d48b8c0d67 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/array_set.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/array_set.rs @@ -2,12 +2,13 @@ use crate::ssa::{ ir::{ basic_block::BasicBlockId, dfg::DataFlowGraph, - instruction::{Instruction, InstructionId}, + instruction::{Instruction, InstructionId, TerminatorInstruction}, types::Type::{Array, Slice}, + value::ValueId, }, ssa_gen::Ssa, }; -use fxhash::{FxHashMap as HashMap, FxHashSet}; +use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; impl Ssa { /// Map arrays with the last instruction that uses it @@ -16,28 +17,42 @@ impl Ssa { #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn array_set_optimization(mut self) -> Self { for func in self.functions.values_mut() { - let mut reachable_blocks = func.reachable_blocks(); - let block = if !func.runtime().is_entry_point() { + let reachable_blocks = func.reachable_blocks(); + + if !func.runtime().is_entry_point() { assert_eq!(reachable_blocks.len(), 1, "Expected there to be 1 block remaining in Acir function for array_set optimization"); - reachable_blocks.pop_first().unwrap() - } else { - // We only apply the array set optimization in the return block of Brillig functions - func.find_last_block() - }; + } + let mut array_to_last_use = HashMap::default(); + let mut instructions_to_update = HashSet::default(); + let mut arrays_from_load = HashSet::default(); - let instructions_to_update = analyze_last_uses(&func.dfg, block); - make_mutable(&mut func.dfg, block, instructions_to_update); + for block in reachable_blocks.iter() { + analyze_last_uses( + &func.dfg, + *block, + &mut array_to_last_use, + &mut instructions_to_update, + &mut arrays_from_load, + ); + } + for block in reachable_blocks { + make_mutable(&mut func.dfg, block, &instructions_to_update); + } } self } } -/// Returns the set of ArraySet instructions that can be made mutable +/// Builds the set of ArraySet instructions that can be made mutable /// because their input value is unused elsewhere afterward. -fn analyze_last_uses(dfg: &DataFlowGraph, block_id: BasicBlockId) -> FxHashSet { +fn analyze_last_uses( + dfg: &DataFlowGraph, + block_id: BasicBlockId, + array_to_last_use: &mut HashMap, + instructions_that_can_be_made_mutable: &mut HashSet, + arrays_from_load: &mut HashSet, +) { let block = &dfg[block_id]; - let mut array_to_last_use = HashMap::default(); - let mut instructions_that_can_be_made_mutable = FxHashSet::default(); for instruction_id in block.instructions() { match &dfg[*instruction_id] { @@ -54,7 +69,22 @@ fn analyze_last_uses(dfg: &DataFlowGraph, block_id: BasicBlockId) -> FxHashSet { for argument in arguments { @@ -68,18 +98,22 @@ fn analyze_last_uses(dfg: &DataFlowGraph, block_id: BasicBlockId) -> FxHashSet { + let result = dfg.instruction_results(*instruction_id)[0]; + if matches!(dfg.type_of_value(result), Array { .. } | Slice { .. }) { + arrays_from_load.insert(result); + } + } _ => (), } } - - instructions_that_can_be_made_mutable } /// Make each ArraySet instruction in `instructions_to_update` mutable. fn make_mutable( dfg: &mut DataFlowGraph, block_id: BasicBlockId, - instructions_to_update: FxHashSet, + instructions_to_update: &HashSet, ) { if instructions_to_update.is_empty() { return; @@ -105,3 +139,129 @@ fn make_mutable( *dfg[block_id].instructions_mut() = instructions; } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use im::vector; + + use crate::ssa::{ + function_builder::FunctionBuilder, + ir::{ + function::RuntimeType, + instruction::{BinaryOp, Instruction}, + map::Id, + types::Type, + }, + }; + + #[test] + fn array_set_in_loop_with_conditional_clone() { + // We want to make sure that we do not mark a single array set mutable which is loaded + // from and cloned in a loop. If the array is inadvertently marked mutable, and is cloned in a previous iteration + // of the loop, its clone will also be altered. + // + // acir(inline) fn main f0 { + // b0(): + // v2 = allocate + // store [Field 0, Field 0, Field 0, Field 0, Field 0] at v2 + // v3 = allocate + // store [Field 0, Field 0, Field 0, Field 0, Field 0] at v3 + // jmp b1(u32 0) + // b1(v5: u32): + // v7 = lt v5, u32 5 + // jmpif v7 then: b3, else: b2 + // b3(): + // v8 = eq v5, u32 5 + // jmpif v8 then: b4, else: b5 + // b4(): + // v9 = load v2 + // store v9 at v3 + // jmp b5() + // b5(): + // v10 = load v2 + // v12 = array_set v10, index v5, value Field 20 + // store v12 at v2 + // v14 = add v5, u32 1 + // jmp b1(v14) + // b2(): + // return + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id); + builder.set_runtime(RuntimeType::Brillig); + + let array_type = Type::Array(Arc::new(vec![Type::field()]), 5); + let zero = builder.field_constant(0u128); + let array_constant = + builder.array_constant(vector![zero, zero, zero, zero, zero], array_type.clone()); + + let v2 = builder.insert_allocate(array_type.clone()); + + builder.insert_store(v2, array_constant); + + let v3 = builder.insert_allocate(array_type.clone()); + builder.insert_store(v3, array_constant); + + let b1 = builder.insert_block(); + let zero_u32 = builder.numeric_constant(0u128, Type::unsigned(32)); + builder.terminate_with_jmp(b1, vec![zero_u32]); + + // Loop header + builder.switch_to_block(b1); + let v5 = builder.add_block_parameter(b1, Type::unsigned(32)); + let five = builder.numeric_constant(5u128, Type::unsigned(32)); + let v7 = builder.insert_binary(v5, BinaryOp::Lt, five); + + let b2 = builder.insert_block(); + let b3 = builder.insert_block(); + let b4 = builder.insert_block(); + let b5 = builder.insert_block(); + builder.terminate_with_jmpif(v7, b3, b2); + + // Loop body + // b3 is the if statement conditional + builder.switch_to_block(b3); + let two = builder.numeric_constant(5u128, Type::unsigned(32)); + let v8 = builder.insert_binary(v5, BinaryOp::Eq, two); + builder.terminate_with_jmpif(v8, b4, b5); + + // b4 is the rest of the loop after the if statement + builder.switch_to_block(b4); + let v9 = builder.insert_load(v2, array_type.clone()); + builder.insert_store(v3, v9); + builder.terminate_with_jmp(b5, vec![]); + + builder.switch_to_block(b5); + let v10 = builder.insert_load(v2, array_type.clone()); + let twenty = builder.field_constant(20u128); + let v12 = builder.insert_array_set(v10, v5, twenty); + builder.insert_store(v2, v12); + let one = builder.numeric_constant(1u128, Type::unsigned(32)); + let v14 = builder.insert_binary(v5, BinaryOp::Add, one); + builder.terminate_with_jmp(b1, vec![v14]); + + builder.switch_to_block(b2); + builder.terminate_with_return(vec![]); + + let ssa = builder.finish(); + // We expect the same result as above + let ssa = ssa.array_set_optimization(); + + let main = ssa.main(); + assert_eq!(main.reachable_blocks().len(), 6); + + let array_set_instructions = main.dfg[b5] + .instructions() + .iter() + .filter(|instruction| matches!(&main.dfg[**instruction], Instruction::ArraySet { .. })) + .collect::>(); + + assert_eq!(array_set_instructions.len(), 1); + if let Instruction::ArraySet { mutable, .. } = &main.dfg[*array_set_instructions[0]] { + // The single array set should not be marked mutable + assert!(!mutable); + } + } +}