Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(ssa refactor): Simplify inlining pass and fix inlining failure #1337

Merged
merged 4 commits into from
May 12, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 127 additions & 27 deletions crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ struct PerFunctionContext<'function> {
/// Maps InstructionIds from the function being inlined to the function being inlined into.
instructions: HashMap<InstructionId, InstructionId>,

/// The TerminatorInstruction::Return in the source_function will be mapped to a jmp to
/// this block in the destination function instead.
return_destination: BasicBlockId,

/// True if we're currently working on the main function.
inlining_main: bool,
}
Expand Down Expand Up @@ -124,7 +120,12 @@ impl InlineContext {

/// Inlines a function into the current function and returns the translated return values
/// of the inlined function.
fn inline_function(&mut self, ssa: &Ssa, id: FunctionId, arguments: &[ValueId]) -> &[ValueId] {
fn inline_function(
&mut self,
ssa: &Ssa,
id: FunctionId,
arguments: &[ValueId],
) -> Vec<ValueId> {
self.recursion_level += 1;

if self.recursion_level > RECURSION_LIMIT {
Expand All @@ -143,9 +144,7 @@ impl InlineContext {
let current_block = context.context.builder.current_block();
context.blocks.insert(source_function.entry_block(), current_block);

context.inline_blocks(ssa);
let return_destination = context.return_destination;
self.builder.block_parameters(return_destination)
context.inline_blocks(ssa)
}

/// Finish inlining and return the new Ssa struct with the inlined version of main.
Expand Down Expand Up @@ -175,10 +174,7 @@ impl<'function> PerFunctionContext<'function> {
/// for containing the mapping between parameters in the source_function and
/// the arguments of the destination function.
fn new(context: &'function mut InlineContext, source_function: &'function Function) -> Self {
// Create the block to return to but don't insert its parameters until we
// have the types of the actual return values later.
Self {
return_destination: context.builder.insert_block(),
context,
source_function,
blocks: HashMap::new(),
Expand Down Expand Up @@ -265,20 +261,60 @@ impl<'function> PerFunctionContext<'function> {
}

/// Inline all reachable blocks within the source_function into the destination function.
fn inline_blocks(&mut self, ssa: &Ssa) {
fn inline_blocks(&mut self, ssa: &Ssa) -> Vec<ValueId> {
let mut seen_blocks = HashSet::new();
let mut block_queue = vec![self.source_function.entry_block()];

// This Vec will contain each block with a Return instruction along with the
// returned values of that block.
let mut function_returns = vec![];

while let Some(source_block_id) = block_queue.pop() {
let translated_block_id = self.translate_block(source_block_id, &mut block_queue);
self.context.builder.switch_to_block(translated_block_id);

seen_blocks.insert(source_block_id);
self.inline_block(ssa, source_block_id);
self.handle_terminator_instruction(source_block_id, &mut block_queue);

if let Some((block, values)) =
self.handle_terminator_instruction(source_block_id, &mut block_queue)
{
function_returns.push((block, values));
}
}

self.context.builder.switch_to_block(self.return_destination);
self.handle_function_returns(function_returns)
}

/// Handle inlining a function's possibly multiple return instructions.
/// If there is only 1 return we can just continue inserting into that block.
/// If there are multiple, we'll need to create a join block to jump to with each value.
fn handle_function_returns(
&mut self,
mut returns: Vec<(BasicBlockId, Vec<ValueId>)>,
) -> Vec<ValueId> {
// Clippy complains if this were written as an if statement
match returns.len() {
1 => {
let (return_block, return_values) = returns.remove(0);
self.context.builder.switch_to_block(return_block);
return_values
}
n if n > 1 => {
// If there is more than 1 return instruction we'll need to create a single block we
// can return to and continue inserting in afterwards.
let return_block = self.context.builder.insert_block();

for (block, return_values) in returns {
self.context.builder.switch_to_block(block);
self.context.builder.terminate_with_jmp(return_block, return_values);
}

self.context.builder.switch_to_block(return_block);
self.context.builder.block_parameters(return_block).to_vec()
}
_ => unreachable!("Inlined function had no return values"),
}
}

/// Inline each instruction in the given block into the function being inlined into.
Expand Down Expand Up @@ -307,7 +343,7 @@ impl<'function> PerFunctionContext<'function> {
let old_results = self.source_function.dfg.instruction_results(call_id);
let arguments = vecmap(arguments, |arg| self.translate_value(*arg));
let new_results = self.context.inline_function(ssa, function, &arguments);
Self::insert_new_instruction_results(&mut self.values, old_results, new_results);
Self::insert_new_instruction_results(&mut self.values, old_results, &new_results);
}

/// Push the given instruction from the source_function into the current block of the
Expand Down Expand Up @@ -340,16 +376,20 @@ impl<'function> PerFunctionContext<'function> {
/// Handle the given terminator instruction from the given source function block.
/// This will push any new blocks to the destination function as needed, add them
/// to the block queue, and set the terminator instruction for the current block.
///
/// If the terminator instruction was a Return, this will return the block this instruction
/// was in as well as the values that were returned.
fn handle_terminator_instruction(
&mut self,
block_id: BasicBlockId,
block_queue: &mut Vec<BasicBlockId>,
) {
) -> Option<(BasicBlockId, Vec<ValueId>)> {
match self.source_function.dfg[block_id].terminator() {
Some(TerminatorInstruction::Jmp { destination, arguments }) => {
let destination = self.translate_block(*destination, block_queue);
let arguments = vecmap(arguments, |arg| self.translate_value(*arg));
self.context.builder.terminate_with_jmp(destination, arguments);
None
}
Some(TerminatorInstruction::JmpIf {
condition,
Expand All @@ -360,21 +400,15 @@ impl<'function> PerFunctionContext<'function> {
let then_block = self.translate_block(*then_destination, block_queue);
let else_block = self.translate_block(*else_destination, block_queue);
self.context.builder.terminate_with_jmpif(condition, then_block, else_block);
None
}
Some(TerminatorInstruction::Return { return_values }) => {
let return_values = vecmap(return_values, |value| self.translate_value(*value));

if self.inlining_main {
self.context.builder.terminate_with_return(return_values);
} else {
for value in &return_values {
// Add the block parameters for the return block here since we don't do
// it when inserting the block in PerFunctionContext::new
let typ = self.context.builder.current_function.dfg.type_of_value(*value);
self.context.builder.add_block_parameter(self.return_destination, typ);
}
self.context.builder.terminate_with_jmp(self.return_destination, return_values);
self.context.builder.terminate_with_return(return_values.clone());
}
let block_id = self.translate_block(block_id, block_queue);
Some((block_id, return_values))
}
None => unreachable!("Block has no terminator instruction"),
}
Expand All @@ -384,7 +418,7 @@ impl<'function> PerFunctionContext<'function> {
#[cfg(test)]
mod test {
use crate::ssa_refactor::{
ir::{map::Id, types::Type},
ir::{instruction::BinaryOp, map::Id, types::Type},
ssa_builder::FunctionBuilder,
};

Expand Down Expand Up @@ -418,4 +452,70 @@ mod test {
let inlined = ssa.inline_functions();
assert_eq!(inlined.functions.len(), 1);
}

#[test]
fn complex_inlining() {
// This SSA is from issue #1327 which previously failed to inline properly
//
// fn main f0 {
// b0(v0: Field):
// v7 = call f2(f1)
// v13 = call f3(v7)
// v16 = call v13(v0)
// return v16
// }
// fn square f1 {
// b0(v0: Field):
// v2 = mul v0, v0
// return v2
// }
// fn id1 f2 {
// b0(v0: function):
// return v0
// }
// fn id2 f3 {
// b0(v0: function):
// return v0
// }
let main_id = Id::test_new(0);
let square_id = Id::test_new(1);
let id1_id = Id::test_new(2);
let id2_id = Id::test_new(3);

// Compiling main
let mut builder = FunctionBuilder::new("main".into(), main_id);
let main_v0 = builder.add_parameter(Type::field());

let main_f1 = builder.import_function(square_id);
let main_f2 = builder.import_function(id1_id);
let main_f3 = builder.import_function(id2_id);

let main_v7 = builder.insert_call(main_f2, vec![main_f1], vec![Type::Function])[0];
let main_v13 = builder.insert_call(main_f3, vec![main_v7], vec![Type::Function])[0];
let main_v16 = builder.insert_call(main_v13, vec![main_v0], vec![Type::field()])[0];
builder.terminate_with_return(vec![main_v16]);

// Compiling square f1
builder.new_function("square".into(), square_id);
let square_v0 = builder.add_parameter(Type::field());
let square_v2 = builder.insert_binary(square_v0, BinaryOp::Mul, square_v0);
builder.terminate_with_return(vec![square_v2]);

// Compiling id1 f2
builder.new_function("id1".into(), id1_id);
let id1_v0 = builder.add_parameter(Type::Function);
builder.terminate_with_return(vec![id1_v0]);

// Compiling id2 f3
builder.new_function("id2".into(), id2_id);
let id2_v0 = builder.add_parameter(Type::Function);
builder.terminate_with_return(vec![id2_v0]);

// Done, now we test that we can successfully inline all functions.
let ssa = builder.finish();
assert_eq!(ssa.functions.len(), 4);

let inlined = ssa.inline_functions();
assert_eq!(inlined.functions.len(), 1);
}
}