diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs index 75b1047167b..7d55c53ccad 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs @@ -73,8 +73,13 @@ struct PerFunctionContext<'function> { /// argument values. values: HashMap, - /// Maps BasicBlockIds in the function being inlined to the new BasicBlockIds to use in the - /// function being inlined into. + /// Maps blocks in the source function to blocks in the function being inlined into, where + /// each mapping is from the start of a source block to an inlined block in which the + /// analogous program point occurs. + /// + /// Note that the starts of multiple source blocks can map into a single inlined block. + /// Conversely the whole of a source block is not guaranteed to map into a single inlined + /// block. blocks: HashMap, /// Maps InstructionIds from the function being inlined to the function being inlined into. @@ -220,7 +225,9 @@ impl<'function> PerFunctionContext<'function> { new_value } - /// Translate a block id from the source function to one of the target function. + /// Translates the program point representing the start of the given `source_block` to the + /// inlined block in which the analogous program point occurs. (Once inlined, the source + /// block's analogous program region may span multiple inlined blocks.) /// /// If the block isn't already known, this will insert a new block into the target function /// with the same parameter types as the source block. @@ -280,11 +287,14 @@ impl<'function> PerFunctionContext<'function> { let mut function_returns = vec![]; while let Some(source_block_id) = block_queue.pop() { + if seen_blocks.contains(&source_block_id) { + continue; + } let translated_block_id = self.translate_block(source_block_id, &mut block_queue); self.context.builder.switch_to_block(translated_block_id); seen_blocks.insert(source_block_id); - self.inline_block(ssa, source_block_id); + self.inline_block_instructions(ssa, source_block_id); if let Some((block, values)) = self.handle_terminator_instruction(source_block_id, &mut block_queue) @@ -329,7 +339,7 @@ impl<'function> PerFunctionContext<'function> { /// Inline each instruction in the given block into the function being inlined into. /// This may recurse if it finds another function to inline if a call instruction is within this block. - fn inline_block(&mut self, ssa: &Ssa, block_id: BasicBlockId) { + fn inline_block_instructions(&mut self, ssa: &Ssa, block_id: BasicBlockId) { let block = &self.source_function.dfg[block_id]; for id in block.instructions() { match &self.source_function.dfg[*id] { @@ -446,7 +456,11 @@ impl<'function> PerFunctionContext<'function> { if self.inlining_main { self.context.builder.terminate_with_return(return_values.clone()); } - let block_id = self.translate_block(block_id, block_queue); + // Note that `translate_block` would take us back to the point at which the + // inlining of this source block began. Since additional blocks may have been + // inlined since, we are interested in the block representing the current program + // point, obtained via `current_block`. + let block_id = self.context.builder.current_block(); Some((block_id, return_values)) } } @@ -455,10 +469,13 @@ impl<'function> PerFunctionContext<'function> { #[cfg(test)] mod test { + use acvm::FieldElement; + use crate::ssa_refactor::{ ir::{ + basic_block::BasicBlockId, function::RuntimeType, - instruction::{BinaryOp, TerminatorInstruction}, + instruction::{BinaryOp, Intrinsic, TerminatorInstruction}, map::Id, types::Type, }, @@ -620,15 +637,26 @@ mod test { // b0(): // jmp b1() // b1(): + // jmp b2() + // b2(): + // jmp b3() + // b3(): + // jmp b4() + // b4(): + // jmp b5() + // b5(): + // jmp b6() + // b6(): // return Field 120 // } let inlined = ssa.inline_functions(); assert_eq!(inlined.functions.len(), 1); let main = inlined.main(); - let b1 = &main.dfg[b1]; + let b6_id: BasicBlockId = Id::test_new(6); + let b6 = &main.dfg[b6_id]; - match b1.terminator() { + match b6.terminator() { Some(TerminatorInstruction::Return { return_values }) => { assert_eq!(return_values.len(), 1); let value = main @@ -641,4 +669,83 @@ mod test { other => unreachable!("Unexpected terminator {other:?}"), } } + + #[test] + fn displaced_return_mapping() { + // This test is designed specifically to catch a regression in which the ids of blocks + // terminated by returns are badly tracked. As a result, the continuation of a source + // block after a call instruction could but inlined into a block that's already been + // terminated, producing an incorrect order and orphaning successors. + + // fn main f0 { + // b0(v0: u1): + // v2 = call f1(v0) + // call println(v2) + // return + // } + // fn inner1 f1 { + // b0(v0: u1): + // v2 = call f2(v0) + // return v2 + // } + // fn inner2 f2 { + // b0(v0: u1): + // jmpif v0 then: b1, else: b2 + // b1(): + // jmp b3(Field 1) + // b3(v3: Field): + // return v3 + // b2(): + // jmp b3(Field 2) + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir); + + let main_cond = builder.add_parameter(Type::bool()); + let inner1_id = Id::test_new(1); + let inner1 = builder.import_function(inner1_id); + let main_v2 = builder.insert_call(inner1, vec![main_cond], vec![Type::field()])[0]; + let println = builder.import_intrinsic_id(Intrinsic::Println); + builder.insert_call(println, vec![main_v2], vec![]); + builder.terminate_with_return(vec![]); + + builder.new_function("inner1".into(), inner1_id); + let inner1_cond = builder.add_parameter(Type::bool()); + let inner2_id = Id::test_new(2); + let inner2 = builder.import_function(inner2_id); + let inner1_v2 = builder.insert_call(inner2, vec![inner1_cond], vec![Type::field()])[0]; + builder.terminate_with_return(vec![inner1_v2]); + + builder.new_function("inner2".into(), inner2_id); + let inner2_cond = builder.add_parameter(Type::bool()); + let then_block = builder.insert_block(); + let else_block = builder.insert_block(); + let join_block = builder.insert_block(); + builder.terminate_with_jmpif(inner2_cond, then_block, else_block); + builder.switch_to_block(then_block); + let one = builder.numeric_constant(FieldElement::one(), Type::field()); + builder.terminate_with_jmp(join_block, vec![one]); + builder.switch_to_block(else_block); + let two = builder.numeric_constant(FieldElement::from(2_u128), Type::field()); + builder.terminate_with_jmp(join_block, vec![two]); + let join_param = builder.add_block_parameter(join_block, Type::field()); + builder.switch_to_block(join_block); + builder.terminate_with_return(vec![join_param]); + + let ssa = builder.finish().inline_functions(); + // Expected result: + // fn main f3 { + // b0(v0: u1): + // jmpif v0 then: b1, else: b2 + // b1(): + // jmp b3(Field 1) + // b3(v3: Field): + // call println(v3) + // return + // b2(): + // jmp b3(Field 2) + // } + let main = ssa.main(); + assert_eq!(main.reachable_blocks().len(), 4); + } }