Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssa refactor): function inlining orphans calls #1747

Merged
merged 7 commits into from
Jun 21, 2023
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 116 additions & 8 deletions crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,13 @@ struct PerFunctionContext<'function> {
/// argument values.
values: HashMap<ValueId, ValueId>,

/// Maps BasicBlockIds in the function being inlined to the new BasicBlockIds to use in the
/// function being inlined into.
/// Maps BasicBlockIds in the function being inlined from to the new BasicBlockIds to use in
/// the function being inlined into.
///
/// Note that this 1-to-1 mapping only holds true in respect to the start of a source block.
/// This is because the inlining process potentially splits a single block into many. It is
/// not safe to query this mapping in respect to a program point midway or at the end of a
/// source block because the mapping may be stale.
jfecher marked this conversation as resolved.
Show resolved Hide resolved
blocks: HashMap<BasicBlockId, BasicBlockId>,

/// Maps InstructionIds from the function being inlined to the function being inlined into.
Expand Down Expand Up @@ -280,11 +285,14 @@ impl<'function> PerFunctionContext<'function> {
let mut function_returns = vec![];

while let Some(source_block_id) = block_queue.pop() {
if seen_blocks.contains(&source_block_id) {
continue;
}
jfecher marked this conversation as resolved.
Show resolved Hide resolved
let translated_block_id = self.translate_block(source_block_id, &mut block_queue);
self.context.builder.switch_to_block(translated_block_id);

seen_blocks.insert(source_block_id);
self.inline_block(ssa, source_block_id);
self.inline_block_instructions(ssa, source_block_id);

if let Some((block, values)) =
self.handle_terminator_instruction(source_block_id, &mut block_queue)
Expand Down Expand Up @@ -329,7 +337,7 @@ impl<'function> PerFunctionContext<'function> {

/// Inline each instruction in the given block into the function being inlined into.
/// This may recurse if it finds another function to inline if a call instruction is within this block.
fn inline_block(&mut self, ssa: &Ssa, block_id: BasicBlockId) {
fn inline_block_instructions(&mut self, ssa: &Ssa, block_id: BasicBlockId) {
jfecher marked this conversation as resolved.
Show resolved Hide resolved
let block = &self.source_function.dfg[block_id];
for id in block.instructions() {
match &self.source_function.dfg[*id] {
Expand Down Expand Up @@ -446,7 +454,14 @@ impl<'function> PerFunctionContext<'function> {
if self.inlining_main {
self.context.builder.terminate_with_return(return_values.clone());
}
let block_id = self.translate_block(block_id, block_queue);
// The decision of how to continue this block is handled by
// `handle_function_returns` once we know whether there are multiple returns to
// consider. Note that we identify the block containing the return terminator
// using `current_block` instead of `translate_block`. This is because the
// inlining of other calls within the this block can split it into multiple
// blocks. As such, the mapping provided by `translate_block` is considered stale
// and unusable for this annotation.
let block_id = self.context.builder.current_block();
jfecher marked this conversation as resolved.
Show resolved Hide resolved
Some((block_id, return_values))
}
}
Expand All @@ -455,10 +470,13 @@ impl<'function> PerFunctionContext<'function> {

#[cfg(test)]
mod test {
use acvm::FieldElement;

use crate::ssa_refactor::{
ir::{
basic_block::BasicBlockId,
function::RuntimeType,
instruction::{BinaryOp, TerminatorInstruction},
instruction::{BinaryOp, Intrinsic, TerminatorInstruction},
map::Id,
types::Type,
},
Expand Down Expand Up @@ -620,15 +638,26 @@ mod test {
// b0():
// jmp b1()
// b1():
// jmp b2()
// b2():
// jmp b3()
// b3():
// jmp b4()
// b4():
// jmp b5()
// b5():
// jmp b6()
// b6():
// return Field 120
// }
let inlined = ssa.inline_functions();
assert_eq!(inlined.functions.len(), 1);

let main = inlined.main();
let b1 = &main.dfg[b1];
let b6_id: BasicBlockId = Id::test_new(6);
let b6 = &main.dfg[b6_id];

match b1.terminator() {
match b6.terminator() {
Some(TerminatorInstruction::Return { return_values }) => {
assert_eq!(return_values.len(), 1);
let value = main
Expand All @@ -641,4 +670,83 @@ mod test {
other => unreachable!("Unexpected terminator {other:?}"),
}
}

#[test]
fn displaced_return_mapping() {
// This test is designed specifically to catch a regression in which the ids of blocks
// terminated by returns are badly tracked. As a result, the continuation of a source
// block after a call instruction could but inlined into a block that's already been
// terminated, producing an incorrect order and orphaning successors.

// fn main f0 {
// b0(v0: u1):
// v2 = call f1(v0)
// call println(v2)
// return
// }
// fn inner1 f1 {
// b0(v0: u1):
// v2 = call f2(v0)
// return v2
// }
// fn inner2 f2 {
// b0(v0: u1):
// jmpif v0 then: b1, else: b2
// b1():
// jmp b3(Field 1)
// b3(v3: Field):
// return v3
// b2():
// jmp b3(Field 2)
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);

let main_cond = builder.add_parameter(Type::bool());
let inner1_id = Id::test_new(1);
let inner1 = builder.import_function(inner1_id);
let main_v2 = builder.insert_call(inner1, vec![main_cond], vec![Type::field()])[0];
let println = builder.import_intrinsic_id(Intrinsic::Println);
builder.insert_call(println, vec![main_v2], vec![]);
builder.terminate_with_return(vec![]);

builder.new_function("inner1".into(), inner1_id);
let inner1_cond = builder.add_parameter(Type::bool());
let inner2_id = Id::test_new(2);
let inner2 = builder.import_function(inner2_id);
let inner1_v2 = builder.insert_call(inner2, vec![inner1_cond], vec![Type::field()])[0];
builder.terminate_with_return(vec![inner1_v2]);

builder.new_function("inner2".into(), inner2_id);
let inner2_cond = builder.add_parameter(Type::bool());
let then_block = builder.insert_block();
let else_block = builder.insert_block();
let join_block = builder.insert_block();
builder.terminate_with_jmpif(inner2_cond, then_block, else_block);
builder.switch_to_block(then_block);
let one = builder.numeric_constant(FieldElement::one(), Type::field());
builder.terminate_with_jmp(join_block, vec![one]);
builder.switch_to_block(else_block);
let two = builder.numeric_constant(FieldElement::from(2_u128), Type::field());
builder.terminate_with_jmp(join_block, vec![two]);
let join_param = builder.add_block_parameter(join_block, Type::field());
builder.switch_to_block(join_block);
builder.terminate_with_return(vec![join_param]);

let ssa = builder.finish().inline_functions();
// Expected result:
// fn main f3 {
// b0(v0: u1):
// jmpif v0 then: b1, else: b2
// b1():
// jmp b3(Field 1)
// b3(v3: Field):
// call println(v3)
// return
// b2():
// jmp b3(Field 2)
// }
let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 4);
}
}