diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 943a57c1bc0..46f1e7a2765 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -199,22 +199,47 @@ struct Context<'f> { /// found, the top of this conditions stack is popped since we are no longer under that /// condition. If we are under multiple conditions (a nested if), the topmost condition is /// the most recent condition combined with all previous conditions via `And` instructions. - conditions: Vec<(BasicBlockId, ValueId)>, + condition_stack: Vec, /// Maps SSA array values with a slice type to their size. /// This is maintained by appropriate calls to the `SliceCapacityTracker` and is used by the `ValueMerger`. slice_sizes: HashMap, + + /// Stack of block arguments + /// When processing a block, we pop this stack to get its arguments + /// and at the end we push the arguments for its successor + arguments_stack: Vec>, } +#[derive(Clone)] pub(crate) struct Store { old_value: ValueId, new_value: ValueId, } -struct Branch { - condition: ValueId, +#[derive(Clone)] +struct ConditionalBranch { + // Contains the last processed block during the processing of the branch. 
last_block: BasicBlockId, + // The unresolved condition of the branch + old_condition: ValueId, + // The condition of the branch + condition: ValueId, + // The store values accumulated when processing the branch store_values: HashMap, + // The allocations accumulated when processing the branch + local_allocations: HashSet, +} + +struct ConditionalContext { + // Condition from the conditional statement + condition: ValueId, + // Block containing the conditional statement + entry_block: BasicBlockId, + // State tracked for the 'then' branch + then_branch: ConditionalBranch, + // State tracked for the 'else' branch (None until the 'then' branch is done) + else_branch: Option, } fn flatten_function_cfg(function: &mut Function) { @@ -233,90 +258,117 @@ fn flatten_function_cfg(function: &mut Function) { store_values: HashMap::default(), local_allocations: HashSet::new(), branch_ends, - conditions: Vec::new(), slice_sizes: HashMap::default(), + condition_stack: Vec::new(), + arguments_stack: Vec::new(), }; context.flatten(); } impl<'f> Context<'f> { fn flatten(&mut self) { - // Start with following the terminator of the entry block since we don't - // need to flatten the entry block into itself. - self.handle_terminator(self.inserter.function.entry_block()); + // Flatten the CFG by inlining all instructions from the queued blocks + // until all blocks have been flattened. + // We follow the terminator of each block to determine which blocks to + // process next + let mut queue = vec![self.inserter.function.entry_block()]; + while let Some(block) = queue.pop() { + self.inline_block(block); + let to_process = self.handle_terminator(block, &queue); + for incoming_block in to_process { + if !queue.contains(&incoming_block) { + queue.push(incoming_block); + } + } + } } - /// Check the terminator of the given block and recursively inline any blocks reachable from - /// it. 
Since each block from a jmpif terminator is inlined successively, we must handle - /// instructions with side effects like constrain and store specially to preserve correctness. - /// For these instructions we must keep track of what the current condition is and modify - /// the instructions according to the module-level comment at the top of this file. Note that - /// the current condition is all the jmpif conditions required to reach the current block, - /// combined via `And` instructions. - /// - /// Returns the last block to be inlined. This is either the return block of the function or, - /// if self.conditions is not empty, the end block of the most recent condition. - fn handle_terminator(&mut self, block: BasicBlockId) -> BasicBlockId { - // As we recursively flatten inner blocks, we need to track the slice information - // for the outer block before we start recursively inlining - let outer_block_instructions = self.inserter.function.dfg[block].instructions(); - let mut capacity_tracker = SliceCapacityTracker::new(&self.inserter.function.dfg); - for instruction in outer_block_instructions { - let results = self.inserter.function.dfg.instruction_results(*instruction); - let instruction = &self.inserter.function.dfg[*instruction]; + /// Returns the updated condition so that + /// it is 'AND-ed' with the previous condition (if any) + fn link_condition(&mut self, condition: ValueId) -> ValueId { + // Retrieve the previous condition + if let Some(context) = self.condition_stack.last() { + let previous_branch = context.else_branch.as_ref().unwrap_or(&context.then_branch); + let and = Instruction::binary(BinaryOp::And, previous_branch.condition, condition); + self.insert_instruction(and, CallStack::new()) + } else { + condition + } + } + + /// Returns the current condition + fn get_last_condition(&self) -> Option { + self.condition_stack.last().map(|context| match &context.else_branch { + Some(else_branch) => else_branch.condition, + None => 
context.then_branch.condition, + }) + } + + // Inline all instructions from the given block into the entry block, and track slice capacities + fn inline_block(&mut self, block: BasicBlockId) { + if self.inserter.function.entry_block() == block { + // The entry block is not inlined into itself; + // we only track its slice information before we start inlining + let outer_block_instructions = self.inserter.function.dfg[block].instructions(); + let mut capacity_tracker = SliceCapacityTracker::new(&self.inserter.function.dfg); + for instruction in outer_block_instructions { + let results = self.inserter.function.dfg.instruction_results(*instruction); + let instruction = &self.inserter.function.dfg[*instruction]; + capacity_tracker.collect_slice_information( + instruction, + &mut self.slice_sizes, + results.to_vec(), + ); + } + + return; + } + + let arguments = self.arguments_stack.pop().unwrap(); + self.inserter.remember_block_params(block, &arguments); + + // If this is not a separate variable, clippy gets confused and says the to_vec is + // unnecessary, when removing it actually causes an aliasing/mutability error. 
+ let instructions = self.inserter.function.dfg[block].instructions().to_vec(); + for instruction in instructions.iter() { + let results = self.push_instruction(*instruction); + let (instruction, _) = self.inserter.map_instruction(*instruction); + let mut capacity_tracker = SliceCapacityTracker::new(&self.inserter.function.dfg); capacity_tracker.collect_slice_information( - instruction, + &instruction, &mut self.slice_sizes, - results.to_vec(), + results, ); } + } - match self.inserter.function.dfg[block].unwrap_terminator() { + /// Returns the list of blocks that need to be processed after the given block + /// For a normal block, it would be its successor + /// For blocks related to a conditional statement, we ensure to process + /// the 'then-branch', then the 'else-branch' (if it exists), and finally the end block + fn handle_terminator( + &mut self, + block: BasicBlockId, + work_list: &[BasicBlockId], + ) -> Vec { + let terminator = self.inserter.function.dfg[block].unwrap_terminator().clone(); + match &terminator { TerminatorInstruction::JmpIf { condition, then_destination, else_destination } => { - let old_condition = *condition; - let then_block = *then_destination; - let else_block = *else_destination; - let then_condition = self.inserter.resolve(old_condition); - - let one = FieldElement::one(); - let then_branch = - self.inline_branch(block, then_block, old_condition, then_condition, one); - - let else_condition = - self.insert_instruction(Instruction::Not(then_condition), CallStack::new()); - let zero = FieldElement::zero(); - - // Make sure the else branch sees the previous values of each store - // rather than any values created in the 'then' branch. 
- self.undo_stores_in_then_branch(&then_branch); - - let else_branch = - self.inline_branch(block, else_block, old_condition, else_condition, zero); - - // We must remember to reset whether side effects are enabled when both branches - // end, in addition to resetting the value of old_condition since it is set to - // known to be true/false within the then/else branch respectively. - self.insert_current_side_effects_enabled(); - - // We must map back to `then_condition` here. Mapping `old_condition` to itself would - // lose any previous mappings. - self.inserter.map_value(old_condition, then_condition); - - // While there is a condition on the stack we don't compile outside the condition - // until it is popped. This ensures we inline the full then and else branches - // before continuing from the end of the conditional here where they can be merged properly. - let end = self.branch_ends[&block]; - self.inline_branch_end(end, then_branch, else_branch) + self.arguments_stack.push(vec![]); + self.if_start(condition, then_destination, else_destination, &block) } TerminatorInstruction::Jmp { destination, arguments, call_stack: _ } => { - if let Some((end_block, _)) = self.conditions.last() { - if destination == end_block { - return block; + let arguments = vecmap(arguments.clone(), |value| self.inserter.resolve(value)); + self.arguments_stack.push(arguments); + if work_list.contains(destination) { + if work_list.last() == Some(destination) { + self.else_stop(&block) + } else { + self.then_stop(&block) } + } else { + vec![*destination] } - let destination = *destination; - let arguments = vecmap(arguments.clone(), |value| self.inserter.resolve(value)); - self.inline_block(destination, &arguments) } TerminatorInstruction::Return { return_values, call_stack } => { let call_stack = call_stack.clone(); @@ -326,133 +378,133 @@ impl<'f> Context<'f> { let entry = self.inserter.function.entry_block(); self.inserter.function.dfg.set_block_terminator(entry, new_return); - block 
+ vec![] } } } - /// Push a condition to the stack of conditions. - /// - /// This condition should be present while we're inlining each block reachable from the 'then' - /// branch of a jmpif instruction, until the branches eventually join back together. Likewise, - /// !condition should be present while we're inlining each block reachable from the 'else' - /// branch of a jmpif instruction until the join block. - fn push_condition(&mut self, start_block: BasicBlockId, condition: ValueId) { - let end_block = self.branch_ends[&start_block]; - - if let Some((_, previous_condition)) = self.conditions.last() { - let and = Instruction::binary(BinaryOp::And, *previous_condition, condition); - let new_condition = self.insert_instruction(and, CallStack::new()); - self.conditions.push((end_block, new_condition)); - } else { - self.conditions.push((end_block, condition)); + /// Process a conditional statement + fn if_start( + &mut self, + condition: &ValueId, + then_destination: &BasicBlockId, + else_destination: &BasicBlockId, + if_entry: &BasicBlockId, + ) -> Vec { + // manage conditions + let old_condition = *condition; + let then_condition = self.inserter.resolve(old_condition); + + let one = FieldElement::one(); + let old_stores = std::mem::take(&mut self.store_values); + let old_allocations = std::mem::take(&mut self.local_allocations); + let branch = ConditionalBranch { + old_condition, + condition: self.link_condition(then_condition), + store_values: old_stores, + local_allocations: old_allocations, + last_block: *then_destination, + }; + let cond_context = ConditionalContext { + condition: then_condition, + entry_block: *if_entry, + then_branch: branch, + else_branch: None, + }; + self.condition_stack.push(cond_context); + self.insert_current_side_effects_enabled(); + // Optimization: within the then branch we know the condition to be true, so replace + // any references of it within this branch with true. Likewise, do the same with false + // with the else branch. 
We must be careful not to replace the condition if it is a + // known constant, otherwise we can end up setting 1 = 0 or vice-versa. + if self.inserter.function.dfg.get_numeric_constant(old_condition).is_none() { + let known_value = self.inserter.function.dfg.make_constant(one, Type::bool()); + + self.inserter.map_value(old_condition, known_value); } + vec![self.branch_ends[if_entry], *else_destination, *then_destination] } - /// Insert a new instruction into the function's entry block. - /// Unlike push_instruction, this function will not map any ValueIds. - /// within the given instruction, nor will it modify self.values in any way. - fn insert_instruction(&mut self, instruction: Instruction, call_stack: CallStack) -> ValueId { - let block = self.inserter.function.entry_block(); - self.inserter - .function - .dfg - .insert_instruction_and_results(instruction, block, None, call_stack) - .first() + /// Switch context to the 'else-branch' + fn then_stop(&mut self, block: &BasicBlockId) -> Vec { + let mut cond_context = self.condition_stack.pop().unwrap(); + cond_context.then_branch.last_block = *block; + + let else_condition = + self.insert_instruction(Instruction::Not(cond_context.condition), CallStack::new()); + let else_condition = self.link_condition(else_condition); + + let zero = FieldElement::zero(); + // Make sure the else branch sees the previous values of each store + // rather than any values created in the 'then' branch. 
+ let old_stores = std::mem::take(&mut cond_context.then_branch.store_values); + cond_context.then_branch.store_values = std::mem::take(&mut self.store_values); + self.undo_stores_in_then_branch(&cond_context.then_branch.store_values); + + let old_allocations = std::mem::take(&mut self.local_allocations); + let else_branch = ConditionalBranch { + old_condition: cond_context.then_branch.old_condition, + condition: else_condition, + store_values: old_stores, + local_allocations: old_allocations, + last_block: *block, + }; + let old_condition = else_branch.old_condition; + cond_context.then_branch.local_allocations.clear(); + cond_context.else_branch = Some(else_branch); + self.condition_stack.push(cond_context); + + self.insert_current_side_effects_enabled(); + // Optimization: within the then branch we know the condition to be true, so replace + // any references of it within this branch with true. Likewise, do the same with false + // with the else branch. We must be careful not to replace the condition if it is a + // known constant, otherwise we can end up setting 1 = 0 or vice-versa. + if self.inserter.function.dfg.get_numeric_constant(old_condition).is_none() { + let known_value = self.inserter.function.dfg.make_constant(zero, Type::bool()); + + self.inserter.map_value(old_condition, known_value); + } + assert_eq!(self.cfg.successors(*block).len(), 1); + vec![self.cfg.successors(*block).next().unwrap()] } - /// Inserts a new instruction into the function's entry block, using the given - /// control type variables to specify result types if needed. - /// Unlike push_instruction, this function will not map any ValueIds. - /// within the given instruction, nor will it modify self.values in any way. 
- fn insert_instruction_with_typevars( - &mut self, - instruction: Instruction, - ctrl_typevars: Option>, - ) -> InsertInstructionResult { - let block = self.inserter.function.entry_block(); - self.inserter.function.dfg.insert_instruction_and_results( - instruction, - block, - ctrl_typevars, - CallStack::new(), - ) - } + /// Process the 'exit' block of a conditional statement + fn else_stop(&mut self, block: &BasicBlockId) -> Vec { + let mut cond_context = self.condition_stack.pop().unwrap(); + if cond_context.else_branch.is_none() { + // then_stop() has not been called, this means that the conditional statement has no else branch + // so we simply do the then_stop() now + self.condition_stack.push(cond_context); + self.then_stop(block); + cond_context = self.condition_stack.pop().unwrap(); + } - /// Checks the branch condition on the top of the stack and uses it to build and insert an - /// `EnableSideEffects` instruction into the entry block. - /// - /// If the stack is empty, a "true" u1 constant is taken to be the active condition. This is - /// necessary for re-enabling side-effects when re-emerging to a branch depth of 0. - fn insert_current_side_effects_enabled(&mut self) { - let condition = match self.conditions.last() { - Some((_, cond)) => *cond, - None => { - self.inserter.function.dfg.make_constant(FieldElement::one(), Type::unsigned(1)) - } - }; - let enable_side_effects = Instruction::EnableSideEffects { condition }; - self.insert_instruction_with_typevars(enable_side_effects, None); - } + let mut else_branch = cond_context.else_branch.unwrap(); + let stores_in_branch = std::mem::replace(&mut self.store_values, else_branch.store_values); + self.local_allocations = std::mem::take(&mut else_branch.local_allocations); + else_branch.last_block = *block; + else_branch.store_values = stores_in_branch; + cond_context.else_branch = Some(else_branch); - /// Inline one branch of a jmpif instruction. 
- /// - /// This will continue inlining recursively until the next end block is reached where each branch - /// of the jmpif instruction is joined back into a single block. - /// - /// Within a branch of a jmpif instruction, we can assume the condition of the jmpif to be - /// always true or false, depending on which branch we're in. - /// - /// Returns the ending block / join block of this branch. - fn inline_branch( - &mut self, - jmpif_block: BasicBlockId, - destination: BasicBlockId, - old_condition: ValueId, - new_condition: ValueId, - condition_value: FieldElement, - ) -> Branch { - if destination == self.branch_ends[&jmpif_block] { - // If the branch destination is the same as the end of the branch, this must be the - // 'else' case of an if with no else - so there is no else branch. - Branch { - condition: new_condition, - // The last block here is somewhat arbitrary. It only matters that it has no Jmp - // args that will be merged by inline_branch_end. Since jmpifs don't have - // block arguments, it is safe to use the jmpif block here. - last_block: jmpif_block, - store_values: HashMap::default(), - } - } else { - self.push_condition(jmpif_block, new_condition); - self.insert_current_side_effects_enabled(); - let old_stores = std::mem::take(&mut self.store_values); - let old_allocations = std::mem::take(&mut self.local_allocations); - - // Optimization: within the then branch we know the condition to be true, so replace - // any references of it within this branch with true. Likewise, do the same with false - // with the else branch. We must be careful not to replace the condition if it is a - // known constant, otherwise we can end up setting 1 = 0 or vice-versa. 
- if self.inserter.function.dfg.get_numeric_constant(old_condition).is_none() { - let known_value = - self.inserter.function.dfg.make_constant(condition_value, Type::bool()); - - self.inserter.map_value(old_condition, known_value); - } + // We must remember to reset whether side effects are enabled when both branches + // end, in addition to resetting the value of old_condition since it is set to + // known to be true/false within the then/else branch respectively. + self.insert_current_side_effects_enabled(); - let final_block = self.inline_block(destination, &[]); + // We must map back to `then_condition` here. Mapping `old_condition` to itself would + // lose any previous mappings. + self.inserter + .map_value(cond_context.then_branch.old_condition, cond_context.then_branch.condition); - self.conditions.pop(); + // While there is a condition on the stack we don't compile outside the condition + // until it is popped. This ensures we inline the full then and else branches + // before continuing from the end of the conditional here where they can be merged properly. 
+ let end = self.branch_ends[&cond_context.entry_block]; + // Merge arguments and stores from the then/else branches + self.inline_branch_end(end, cond_context); + vec![self.cfg.successors(*block).next().unwrap()] } /// Inline the ending block of a branch, the point where all blocks from a jmpif instruction @@ -467,15 +519,17 @@ impl<'f> Context<'f> { fn inline_branch_end( &mut self, destination: BasicBlockId, - then_branch: Branch, - else_branch: Branch, + cond_context: ConditionalContext, ) -> BasicBlockId { assert_eq!(self.cfg.predecessors(destination).len(), 2); + let last_then = cond_context.then_branch.last_block; + let mut else_args = Vec::new(); + if cond_context.else_branch.is_some() { + let last_else = cond_context.else_branch.clone().unwrap().last_block; + else_args = self.inserter.function.dfg[last_else].terminator_arguments().to_vec(); + } - let then_args = - self.inserter.function.dfg[then_branch.last_block].terminator_arguments().to_vec(); - let else_args = - self.inserter.function.dfg[else_branch.last_block].terminator_arguments().to_vec(); + let then_args = self.inserter.function.dfg[last_then].terminator_arguments().to_vec(); let params = self.inserter.function.dfg.block_parameters(destination); assert_eq!(params.len(), then_args.len()); @@ -500,17 +554,64 @@ impl<'f> Context<'f> { // Cannot include this in the previous vecmap since it requires exclusive access to self let args = vecmap(args, |(then_arg, else_arg)| { value_merger.merge_values( - then_branch.condition, - else_branch.condition, + cond_context.then_branch.condition, + cond_context.else_branch.clone().unwrap().condition, then_arg, else_arg, ) }); - self.merge_stores(then_branch, else_branch); + self.merge_stores(cond_context.then_branch, 
cond_context.else_branch); + self.arguments_stack.pop(); + self.arguments_stack.pop(); + self.arguments_stack.push(args); + destination + } - // insert merge instruction - self.inline_block(destination, &args) + /// Insert a new instruction into the function's entry block. + /// Unlike push_instruction, this function will not map any ValueIds. + /// within the given instruction, nor will it modify self.values in any way. + fn insert_instruction(&mut self, instruction: Instruction, call_stack: CallStack) -> ValueId { + let block = self.inserter.function.entry_block(); + self.inserter + .function + .dfg + .insert_instruction_and_results(instruction, block, None, call_stack) + .first() + } + + /// Inserts a new instruction into the function's entry block, using the given + /// control type variables to specify result types if needed. + /// Unlike push_instruction, this function will not map any ValueIds. + /// within the given instruction, nor will it modify self.values in any way. + fn insert_instruction_with_typevars( + &mut self, + instruction: Instruction, + ctrl_typevars: Option>, + ) -> InsertInstructionResult { + let block = self.inserter.function.entry_block(); + self.inserter.function.dfg.insert_instruction_and_results( + instruction, + block, + ctrl_typevars, + CallStack::new(), + ) + } + + /// Checks the branch condition on the top of the stack and uses it to build and insert an + /// `EnableSideEffects` instruction into the entry block. + /// + /// If the stack is empty, a "true" u1 constant is taken to be the active condition. This is + /// necessary for re-enabling side-effects when re-emerging to a branch depth of 0. 
+ fn insert_current_side_effects_enabled(&mut self) { + let condition = match self.get_last_condition() { + Some(cond) => cond, + None => { + self.inserter.function.dfg.make_constant(FieldElement::one(), Type::unsigned(1)) + } + }; + let enable_side_effects = Instruction::EnableSideEffects { condition }; + self.insert_instruction_with_typevars(enable_side_effects, None); } /// Merge any store instructions found in each branch. @@ -518,7 +619,11 @@ impl<'f> Context<'f> { /// This function relies on the 'then' branch being merged before the 'else' branch of a jmpif /// instruction. If this ordering is changed, the ordering that store values are merged within /// this function also needs to be changed to reflect that. - fn merge_stores(&mut self, then_branch: Branch, else_branch: Branch) { + fn merge_stores( + &mut self, + then_branch: ConditionalBranch, + else_branch: Option, + ) { // Address -> (then_value, else_value, value_before_the_if) let mut new_map = BTreeMap::new(); @@ -526,11 +631,13 @@ impl<'f> Context<'f> { new_map.insert(address, (store.new_value, store.old_value, store.old_value)); } - for (address, store) in else_branch.store_values { - if let Some(entry) = new_map.get_mut(&address) { - entry.1 = store.new_value; - } else { - new_map.insert(address, (store.old_value, store.new_value, store.old_value)); + if else_branch.is_some() { + for (address, store) in else_branch.clone().unwrap().store_values { + if let Some(entry) = new_map.get_mut(&address) { + entry.1 = store.new_value; + } else { + new_map.insert(address, (store.old_value, store.new_value, store.old_value)); + } } } @@ -544,8 +651,11 @@ impl<'f> Context<'f> { } let then_condition = then_branch.condition; - let else_condition = else_branch.condition; - + let else_condition = if let Some(branch) = else_branch { + branch.condition + } else { + self.inserter.function.dfg.make_constant(FieldElement::zero(), Type::bool()) + }; let block = self.inserter.function.entry_block(); let mut value_merger = 
@@ -607,35 +717,6 @@ impl<'f> Context<'f> { } } - /// Inline all instructions from the given destination block into the entry block. - /// Afterwards, check the block's terminator and continue inlining recursively. - /// - /// Returns the final block that was inlined. - /// - /// Expects that the `arguments` given are already translated via self.inserter.resolve. - /// If they are not, it is possible some values which no longer exist, such as block - /// parameters, will be kept in the program. - fn inline_block(&mut self, destination: BasicBlockId, arguments: &[ValueId]) -> BasicBlockId { - self.inserter.remember_block_params(destination, arguments); - - // If this is not a separate variable, clippy gets confused and says the to_vec is - // unnecessary, when removing it actually causes an aliasing/mutability error. - let instructions = self.inserter.function.dfg[destination].instructions().to_vec(); - - for instruction in instructions.iter() { - let results = self.push_instruction(*instruction); - let (instruction, _) = self.inserter.map_instruction(*instruction); - let mut capacity_tracker = SliceCapacityTracker::new(&self.inserter.function.dfg); - capacity_tracker.collect_slice_information( - &instruction, - &mut self.slice_sizes, - results, - ); - } - - self.handle_terminator(destination) - } - /// Push the given instruction to the end of the entry block of the current function. /// /// Note that each ValueId of the instruction will be mapped via self.inserter.resolve. @@ -666,7 +747,7 @@ impl<'f> Context<'f> { instruction: Instruction, call_stack: CallStack, ) -> Instruction { - if let Some((_, condition)) = self.conditions.last().copied() { + if let Some(condition) = self.get_last_condition() { match instruction { Instruction::Constrain(lhs, rhs, message) => { // Replace constraint `lhs == rhs` with `condition * lhs == condition * rhs`. 
@@ -741,8 +822,8 @@ impl<'f> Context<'f> { } } - fn undo_stores_in_then_branch(&mut self, then_branch: &Branch) { - for (address, store) in &then_branch.store_values { + fn undo_stores_in_then_branch(&mut self, store_values: &HashMap) { + for (address, store) in store_values { let address = *address; let value = store.old_value; self.insert_instruction_with_typevars(Instruction::Store { address, value }, None);