Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssa refactor): Fix flattening pass inserting loads before stores occur #1783

Merged
merged 1 commit into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions crates/noirc_evaluator/src/ssa_refactor/ir/function_inserter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl<'f> FunctionInserter<'f> {
instruction: Instruction,
id: InstructionId,
block: BasicBlockId,
) {
) -> InsertInstructionResult {
let results = self.function.dfg.instruction_results(id);
let results = vecmap(results, |id| self.function.dfg.resolve(*id));

Expand All @@ -79,24 +79,25 @@ impl<'f> FunctionInserter<'f> {
let new_results =
self.function.dfg.insert_instruction_and_results(instruction, block, ctrl_typevars);

Self::insert_new_instruction_results(&mut self.values, &results, new_results);
Self::insert_new_instruction_results(&mut self.values, &results, &new_results);
new_results
}

/// Modify the values HashMap to remember the mapping between an instruction result's previous
/// ValueId (from the source_function) and its new ValueId in the destination function.
pub(crate) fn insert_new_instruction_results(
values: &mut HashMap<ValueId, ValueId>,
old_results: &[ValueId],
new_results: InsertInstructionResult,
new_results: &InsertInstructionResult,
) {
assert_eq!(old_results.len(), new_results.len());

match new_results {
InsertInstructionResult::SimplifiedTo(new_result) => {
values.insert(old_results[0], new_result);
values.insert(old_results[0], *new_result);
}
InsertInstructionResult::Results(new_results) => {
for (old_result, new_result) in old_results.iter().zip(new_results) {
for (old_result, new_result) in old_results.iter().zip(*new_results) {
values.insert(*old_result, *new_result);
}
}
Expand Down
208 changes: 169 additions & 39 deletions crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@
//! v11 = mul v4, Field 12
//! v12 = add v10, v11
//! store v12 at v5 (new store)
use std::collections::HashMap;
use std::{
collections::{HashMap, HashSet},
rc::Rc,
};

use acvm::FieldElement;
use iter_extended::vecmap;
Expand All @@ -144,7 +147,7 @@ use crate::ssa_refactor::{
function::Function,
function_inserter::FunctionInserter,
instruction::{BinaryOp, Instruction, InstructionId, TerminatorInstruction},
types::Type,
types::{CompositeType, Type},
value::ValueId,
},
ssa_gen::Ssa,
Expand Down Expand Up @@ -179,6 +182,13 @@ struct Context<'f> {
/// Maps an address to the old and new value of the element at that address
store_values: HashMap<ValueId, Store>,

/// Stores all allocations local to the current branch.
/// Since these allocations are local to the current branch (i.e. only defined within one branch of
/// an if expression), they should not be merged with their previous value or stored value in
/// the other branch since there is no such value. The ValueId here is that which is returned
/// by the allocate instruction.
local_allocations: HashSet<ValueId>,

/// A stack of each jmpif condition that was taken to reach a particular point in the program.
/// When two branches are merged back into one, this constitutes a join point, and is analogous
/// to the rest of the program after an if statement. When such a join point / end block is
Expand All @@ -197,6 +207,7 @@ struct Branch {
condition: ValueId,
last_block: BasicBlockId,
store_values: HashMap<ValueId, Store>,
local_allocations: HashSet<ValueId>,
}

fn flatten_function_cfg(function: &mut Function) {
Expand All @@ -211,10 +222,12 @@ fn flatten_function_cfg(function: &mut Function) {
}
let cfg = ControlFlowGraph::with_function(function);
let branch_ends = branch_analysis::find_branch_ends(function, &cfg);

let mut context = Context {
inserter: FunctionInserter::new(function),
cfg,
store_values: HashMap::new(),
local_allocations: HashSet::new(),
branch_ends,
conditions: Vec::new(),
};
Expand Down Expand Up @@ -359,40 +372,60 @@ impl<'f> Context<'f> {
Type::Numeric(_) => {
self.merge_numeric_values(then_condition, else_condition, then_value, else_value)
}
Type::Array(element_types, len) => {
let mut merged = im::Vector::new();

for i in 0..len {
for (element_index, element_type) in element_types.iter().enumerate() {
let index = ((i * element_types.len() + element_index) as u128).into();
let index = self.inserter.function.dfg.make_constant(index, Type::field());

let typevars = Some(vec![element_type.clone()]);

let mut get_element = |array, typevars| {
let get = Instruction::ArrayGet { array, index };
self.insert_instruction_with_typevars(get, typevars).first()
};

let then_element = get_element(then_value, typevars.clone());
let else_element = get_element(else_value, typevars);

merged.push_back(self.merge_values(
then_condition,
else_condition,
then_element,
else_element,
));
}
}

self.inserter.function.dfg.make_array(merged, element_types)
}
Type::Array(element_types, len) => self.merge_array_values(
element_types,
len,
then_condition,
else_condition,
then_value,
else_value,
),
Type::Reference => panic!("Cannot return references from an if expression"),
Type::Function => panic!("Cannot return functions from an if expression"),
}
}

/// Merges the two array results of an if expression, `if c { array1 } else { array2 }`,
/// into a single array value.
///
/// For each of the `len` items and each entry of `element_types`, the element at the
/// corresponding flattened index is read from both `then_value` and `else_value` with an
/// `ArrayGet`, and the pair is combined through `self.merge_values` (which recurses back
/// into this function for nested arrays). The merged elements are collected, in index
/// order, into a new array built with the same `element_types`.
fn merge_array_values(
    &mut self,
    element_types: Rc<CompositeType>,
    len: usize,
    then_condition: ValueId,
    else_condition: ValueId,
    then_value: ValueId,
    else_value: ValueId,
) -> ValueId {
    let mut merged_elements = im::Vector::new();
    // Number of flattened slots occupied by each of the `len` items.
    let item_width = element_types.len();

    for item in 0..len {
        for (offset, element_type) in element_types.iter().enumerate() {
            let flat_index = ((item * item_width + offset) as u128).into();
            let index = self.inserter.function.dfg.make_constant(flat_index, Type::field());

            // Insert the `then` read before the `else` read so the instruction order
            // matches what callers of this pass have always observed.
            let then_get = Instruction::ArrayGet { array: then_value, index };
            let then_typevars = Some(vec![element_type.clone()]);
            let then_element =
                self.insert_instruction_with_typevars(then_get, then_typevars).first();

            let else_get = Instruction::ArrayGet { array: else_value, index };
            let else_typevars = Some(vec![element_type.clone()]);
            let else_element =
                self.insert_instruction_with_typevars(else_get, else_typevars).first();

            let merged_element =
                self.merge_values(then_condition, else_condition, then_element, else_element);
            merged_elements.push_back(merged_element);
        }
    }

    self.inserter.function.dfg.make_array(merged_elements, element_types)
}

/// Merge two numeric values a and b from separate basic blocks to a single value. This
/// function would return the result of `if c { a } else { b }` as `c*a + (!c)*b`.
fn merge_numeric_values(
Expand Down Expand Up @@ -437,13 +470,18 @@ impl<'f> Context<'f> {
// 'else' case of an if with no else - so there is no else branch.
Branch {
condition: new_condition,
// The last block here is somewhat arbitrary. It only matters that it has no Jmp
// args that will be merged by inline_branch_end. Since jmpifs don't have
// block arguments, it is safe to use the jmpif block here.
last_block: jmpif_block,
store_values: HashMap::new(),
local_allocations: HashSet::new(),
}
} else {
self.push_condition(jmpif_block, new_condition);
self.insert_current_side_effects_enabled();
let old_stores = std::mem::take(&mut self.store_values);
let old_allocations = std::mem::take(&mut self.local_allocations);

// Remember the old condition value is now known to be true/false within this branch
let known_value =
Expand All @@ -453,12 +491,15 @@ impl<'f> Context<'f> {
let final_block = self.inline_block(destination, &[]);

self.conditions.pop();

let stores_in_branch = std::mem::replace(&mut self.store_values, old_stores);
let local_allocations = std::mem::replace(&mut self.local_allocations, old_allocations);

Branch {
condition: new_condition,
last_block: final_block,
store_values: stores_in_branch,
local_allocations,
}
}
}
Expand Down Expand Up @@ -533,14 +574,16 @@ impl<'f> Context<'f> {
}

fn remember_store(&mut self, address: ValueId, new_value: ValueId) {
if let Some(store_value) = self.store_values.get_mut(&address) {
store_value.new_value = new_value;
} else {
let load = Instruction::Load { address };
let load_type = Some(vec![self.inserter.function.dfg.type_of_value(new_value)]);
let old_value = self.insert_instruction_with_typevars(load, load_type).first();
if !self.local_allocations.contains(&address) {
if let Some(store_value) = self.store_values.get_mut(&address) {
store_value.new_value = new_value;
} else {
let load = Instruction::Load { address };
let load_type = Some(vec![self.inserter.function.dfg.type_of_value(new_value)]);
let old_value = self.insert_instruction_with_typevars(load, load_type).first();

self.store_values.insert(address, Store { old_value, new_value });
self.store_values.insert(address, Store { old_value, new_value });
}
}
}

Expand All @@ -558,6 +601,7 @@ impl<'f> Context<'f> {
// If this is not a separate variable, clippy gets confused and says the to_vec is
// unnecessary, when removing it actually causes an aliasing/mutability error.
let instructions = self.inserter.function.dfg[destination].instructions().to_vec();

for instruction in instructions {
self.push_instruction(instruction);
}
Expand All @@ -574,8 +618,16 @@ impl<'f> Context<'f> {
fn push_instruction(&mut self, id: InstructionId) {
let instruction = self.inserter.map_instruction(id);
let instruction = self.handle_instruction_side_effects(instruction);
let is_allocate = matches!(instruction, Instruction::Allocate);

let entry = self.inserter.function.entry_block();
self.inserter.push_instruction_value(instruction, id, entry);
let results = self.inserter.push_instruction_value(instruction, id, entry);

// Remember an allocate was created local to this branch so that we do not try to merge store
// values across branches for it later.
if is_allocate {
self.local_allocations.insert(results.first());
}
}

/// If we are currently in a branch, we need to modify constrain instructions
Expand Down Expand Up @@ -1020,6 +1072,84 @@ mod test {
assert_eq!(merged_values, vec![3, 5, 6]);
}

#[test]
fn allocate_in_single_branch() {
    // Regression test for #1756. The source program is:
    //
    // fn foo() -> Field {
    //     let mut x = 0;
    //     x
    // }
    //
    // fn main(cond:bool) {
    //     if cond {
    //         foo();
    //     };
    // }
    //
    // which translates to the following SSA before the flattening pass:
    //
    // fn main f2 {
    //   b0(v0: u1):
    //     jmpif v0 then: b1, else: b2
    //   b1():
    //     v2 = allocate
    //     store Field 0 at v2
    //     v4 = load v2
    //     jmp b2()
    //   b2():
    //     return
    // }
    //
    // The flattening pass used to insert a load of the freshly allocated reference
    // ahead of the first store to it, which read an uninitialized value. This test
    // checks that the flattened block orders them strictly as allocate, store, load.
    let main_id = Id::test_new(0);
    let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);

    let then_block = builder.insert_block();
    let end_block = builder.insert_block();

    // b0: branch on the boolean parameter.
    let cond = builder.add_parameter(Type::bool());
    builder.terminate_with_jmpif(cond, then_block, end_block);

    // b1: allocate, store to and load from a reference that only exists in this branch.
    builder.switch_to_block(then_block);
    let address = builder.insert_allocate();
    let zero = builder.field_constant(0u128);
    builder.insert_store(address, zero);
    let _loaded = builder.insert_load(address, Type::field());
    builder.terminate_with_jmp(end_block, vec![]);

    // b2: the join point simply returns.
    builder.switch_to_block(end_block);
    builder.terminate_with_return(vec![]);

    let ssa = builder.finish().flatten_cfg();
    let main = ssa.main();

    // There must be no load between the allocate and its first store.
    // The expected IR is:
    //
    // fn main f2 {
    //   b0(v0: u1):
    //     enable_side_effects v0
    //     v6 = allocate
    //     store Field 0 at v6
    //     v7 = load v6
    //     v8 = not v0
    //     enable_side_effects u1 1
    //     return
    // }
    let entry_instructions = main.dfg[main.entry_block()].instructions();

    // Index of the first instruction in the entry block that satisfies `matches_kind`.
    let first_index_of = |matches_kind: fn(&Instruction) -> bool| {
        entry_instructions.iter().position(|id| matches_kind(&main.dfg[*id])).unwrap()
    };

    let allocate_index = first_index_of(|i| matches!(i, Instruction::Allocate));
    let store_index = first_index_of(|i| matches!(i, Instruction::Store { .. }));
    let load_index = first_index_of(|i| matches!(i, Instruction::Load { .. }));

    assert!(allocate_index < store_index);
    assert!(store_index < load_index);
}

/// Work backwards from an instruction to find all the constant values
/// that were used to construct it. E.g for:
///
Expand Down