Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssa refactor): function inlining orphans calls #1747

Merged
merged 7 commits into from
Jun 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 116 additions & 9 deletions crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,13 @@ struct PerFunctionContext<'function> {
/// argument values.
values: HashMap<ValueId, ValueId>,

/// Maps BasicBlockIds in the function being inlined to the new BasicBlockIds to use in the
/// function being inlined into.
/// Maps blocks in the source function to blocks in the function being inlined into, where
/// each mapping is from the start of a source block to an inlined block in which the
/// analogous program point occurs.
///
/// Note that the starts of multiple source blocks can map into a single inlined block.
/// Conversely the whole of a source block is not guaranteed to map into a single inlined
/// block.
blocks: HashMap<BasicBlockId, BasicBlockId>,

/// Maps InstructionIds from the function being inlined to the function being inlined into.
Expand Down Expand Up @@ -220,7 +225,9 @@ impl<'function> PerFunctionContext<'function> {
new_value
}

/// Translate a block id from the source function to one of the target function.
/// Translates the program point representing the start of the given `source_block` to the
/// inlined block in which the analogous program point occurs. (Once inlined, the source
/// block's analogous program region may span multiple inlined blocks.)
///
/// If the block isn't already known, this will insert a new block into the target function
/// with the same parameter types as the source block.
Expand Down Expand Up @@ -280,11 +287,14 @@ impl<'function> PerFunctionContext<'function> {
let mut function_returns = vec![];

while let Some(source_block_id) = block_queue.pop() {
if seen_blocks.contains(&source_block_id) {
continue;
}
jfecher marked this conversation as resolved.
Show resolved Hide resolved
let translated_block_id = self.translate_block(source_block_id, &mut block_queue);
self.context.builder.switch_to_block(translated_block_id);

seen_blocks.insert(source_block_id);
self.inline_block(ssa, source_block_id);
self.inline_block_instructions(ssa, source_block_id);

if let Some((block, values)) =
self.handle_terminator_instruction(source_block_id, &mut block_queue)
Expand Down Expand Up @@ -329,7 +339,7 @@ impl<'function> PerFunctionContext<'function> {

/// Inline each instruction in the given block into the function being inlined into.
/// This may recurse if it finds another function to inline if a call instruction is within this block.
fn inline_block(&mut self, ssa: &Ssa, block_id: BasicBlockId) {
fn inline_block_instructions(&mut self, ssa: &Ssa, block_id: BasicBlockId) {
jfecher marked this conversation as resolved.
Show resolved Hide resolved
let block = &self.source_function.dfg[block_id];
for id in block.instructions() {
match &self.source_function.dfg[*id] {
Expand Down Expand Up @@ -446,7 +456,11 @@ impl<'function> PerFunctionContext<'function> {
if self.inlining_main {
self.context.builder.terminate_with_return(return_values.clone());
}
let block_id = self.translate_block(block_id, block_queue);
// Note that `translate_block` would take us back to the point at which the
// inlining of this source block began. Since additional blocks may have been
// inlined since, we are interested in the block representing the current program
// point, obtained via `current_block`.
let block_id = self.context.builder.current_block();
jfecher marked this conversation as resolved.
Show resolved Hide resolved
Some((block_id, return_values))
}
}
Expand All @@ -455,10 +469,13 @@ impl<'function> PerFunctionContext<'function> {

#[cfg(test)]
mod test {
use acvm::FieldElement;

use crate::ssa_refactor::{
ir::{
basic_block::BasicBlockId,
function::RuntimeType,
instruction::{BinaryOp, TerminatorInstruction},
instruction::{BinaryOp, Intrinsic, TerminatorInstruction},
map::Id,
types::Type,
},
Expand Down Expand Up @@ -620,15 +637,26 @@ mod test {
// b0():
// jmp b1()
// b1():
// jmp b2()
// b2():
// jmp b3()
// b3():
// jmp b4()
// b4():
// jmp b5()
// b5():
// jmp b6()
// b6():
// return Field 120
// }
let inlined = ssa.inline_functions();
assert_eq!(inlined.functions.len(), 1);

let main = inlined.main();
let b1 = &main.dfg[b1];
let b6_id: BasicBlockId = Id::test_new(6);
let b6 = &main.dfg[b6_id];

match b1.terminator() {
match b6.terminator() {
Some(TerminatorInstruction::Return { return_values }) => {
assert_eq!(return_values.len(), 1);
let value = main
Expand All @@ -641,4 +669,83 @@ mod test {
other => unreachable!("Unexpected terminator {other:?}"),
}
}

#[test]
fn displaced_return_mapping() {
// This test is designed specifically to catch a regression in which the ids of blocks
// terminated by returns are badly tracked. As a result, the continuation of a source
// block after a call instruction could but inlined into a block that's already been
// terminated, producing an incorrect order and orphaning successors.

// fn main f0 {
// b0(v0: u1):
// v2 = call f1(v0)
// call println(v2)
// return
// }
// fn inner1 f1 {
// b0(v0: u1):
// v2 = call f2(v0)
// return v2
// }
// fn inner2 f2 {
// b0(v0: u1):
// jmpif v0 then: b1, else: b2
// b1():
// jmp b3(Field 1)
// b3(v3: Field):
// return v3
// b2():
// jmp b3(Field 2)
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);

let main_cond = builder.add_parameter(Type::bool());
let inner1_id = Id::test_new(1);
let inner1 = builder.import_function(inner1_id);
let main_v2 = builder.insert_call(inner1, vec![main_cond], vec![Type::field()])[0];
let println = builder.import_intrinsic_id(Intrinsic::Println);
builder.insert_call(println, vec![main_v2], vec![]);
builder.terminate_with_return(vec![]);

builder.new_function("inner1".into(), inner1_id);
let inner1_cond = builder.add_parameter(Type::bool());
let inner2_id = Id::test_new(2);
let inner2 = builder.import_function(inner2_id);
let inner1_v2 = builder.insert_call(inner2, vec![inner1_cond], vec![Type::field()])[0];
builder.terminate_with_return(vec![inner1_v2]);

builder.new_function("inner2".into(), inner2_id);
let inner2_cond = builder.add_parameter(Type::bool());
let then_block = builder.insert_block();
let else_block = builder.insert_block();
let join_block = builder.insert_block();
builder.terminate_with_jmpif(inner2_cond, then_block, else_block);
builder.switch_to_block(then_block);
let one = builder.numeric_constant(FieldElement::one(), Type::field());
builder.terminate_with_jmp(join_block, vec![one]);
builder.switch_to_block(else_block);
let two = builder.numeric_constant(FieldElement::from(2_u128), Type::field());
builder.terminate_with_jmp(join_block, vec![two]);
let join_param = builder.add_block_parameter(join_block, Type::field());
builder.switch_to_block(join_block);
builder.terminate_with_return(vec![join_param]);

let ssa = builder.finish().inline_functions();
// Expected result:
// fn main f3 {
// b0(v0: u1):
// jmpif v0 then: b1, else: b2
// b1():
// jmp b3(Field 1)
// b3(v3: Field):
// call println(v3)
// return
// b2():
// jmp b3(Field 2)
// }
let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 4);
}
}