Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssa refactor): Prevent stores in 'then' branch from affecting the 'else' branch #1827

Merged
merged 3 commits into from
Jun 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 181 additions & 46 deletions crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ impl<'f> Context<'f> {
let else_condition = self.insert_instruction(Instruction::Not(then_condition));
let zero = FieldElement::zero();

// Make sure the else branch sees the previous values of each store
// rather than any values created in the 'then' branch.
self.undo_stores_in_then_branch(&then_branch);

let else_branch =
self.inline_branch(block, else_block, old_condition, else_condition, zero);

Expand Down Expand Up @@ -572,9 +576,25 @@ impl<'f> Context<'f> {
/// instruction. If this ordering is changed, the ordering that store values are merged within
/// this function also needs to be changed to reflect that.
fn merge_stores(&mut self, then_branch: Branch, else_branch: Branch) {
let mut merge_store = |address, then_case, else_case, old_value| {
let then_condition = then_branch.condition;
let else_condition = else_branch.condition;
// Address -> (then_value, else_value, value_before_the_if)
let mut new_map = HashMap::with_capacity(then_branch.store_values.len());

for (address, store) in then_branch.store_values {
new_map.insert(address, (store.new_value, store.old_value, store.old_value));
}

for (address, store) in else_branch.store_values {
if let Some(entry) = new_map.get_mut(&address) {
entry.1 = store.new_value;
} else {
new_map.insert(address, (store.old_value, store.new_value, store.old_value));
}
}

let then_condition = then_branch.condition;
let else_condition = else_branch.condition;

for (address, (then_case, else_case, old_value)) in new_map {
let value = self.merge_values(then_condition, else_condition, then_case, else_case);
self.insert_instruction_with_typevars(Instruction::Store { address, value }, None);

Expand All @@ -583,14 +603,6 @@ impl<'f> Context<'f> {
} else {
self.store_values.insert(address, Store { old_value, new_value: value });
}
};

for (address, store) in then_branch.store_values {
merge_store(address, store.new_value, store.old_value, store.old_value);
}

for (address, store) in else_branch.store_values {
merge_store(address, store.old_value, store.new_value, store.old_value);
}
}

Expand Down Expand Up @@ -676,6 +688,14 @@ impl<'f> Context<'f> {
instruction
}
}

fn undo_stores_in_then_branch(&mut self, then_branch: &Branch) {
for (address, store) in &then_branch.store_values {
let address = *address;
let value = store.old_value;
self.insert_instruction_with_typevars(Instruction::Store { address, value }, None);
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -836,36 +856,35 @@ mod test {
// v4 = load v1
// store Field 5 at v1
// v5 = not v0
// enable_side_effects v5
// store v4 at v1
// enable_side_effects u1 1
// v7 = mul v0, Field 5
// v8 = mul v5, v4
// v9 = add v7, v8
// store v9 at v1
// v6 = cast v0 as Field
// v7 = cast v5 as Field
// v8 = mul v6, Field 5
// v9 = mul v7, v4
// v10 = add v8, v9
// store v10 at v1
// return
// }
let ssa = ssa.flatten_cfg();
let main = ssa.main();

assert_eq!(main.reachable_blocks().len(), 1);

let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. }));
assert_eq!(store_count, 2);
assert_eq!(store_count, 3);
}

// Currently failing since the offsets create additions with different ValueIds which are
// treated wrongly as different addresses.
#[test]
fn merge_stores_with_offsets() {
fn merge_stores_with_else_block() {
// fn main f0 {
// b0(v0: u1, v1: ref):
// jmpif v0, then: b1, else: b2
// b1():
// v2 = add v1, 1
// store v2, Field 5
// store Field 5 in v1
// jmp b3()
// b2():
// v3 = add v1, 1
// store v3, Field 6
// store Field 6 in v1
// jmp b3()
// b3():
// return
Expand All @@ -883,16 +902,13 @@ mod test {
builder.terminate_with_jmpif(v0, b1, b2);

builder.switch_to_block(b1);
let one = builder.field_constant(1u128);
let v2 = builder.insert_binary(v1, BinaryOp::Add, one);
let five = builder.field_constant(5u128);
builder.insert_store(v2, five);
builder.insert_store(v1, five);
builder.terminate_with_jmp(b3, vec![]);

builder.switch_to_block(b2);
let v3 = builder.insert_binary(v1, BinaryOp::Add, one);
let six = builder.field_constant(6u128);
builder.insert_store(v3, six);
builder.insert_store(v1, six);
builder.terminate_with_jmp(b3, vec![]);

builder.switch_to_block(b3);
Expand All @@ -904,27 +920,25 @@ mod test {
// fn main f0 {
// b0(v0: u1, v1: reference):
// enable_side_effects v0
// v7 = add v1, Field 1
// v8 = load v7
// store Field 5 at v7
// v9 = not v0
// enable_side_effects v9
// v11 = add v1, Field 1
// v12 = load v11
// store Field 6 at v11
// enable_side_effects Field 1
// v13 = mul v0, Field 5
// v14 = mul v9, v8
// v15 = add v13, v14
// store v15 at v7
// v16 = mul v0, v12
// v17 = mul v9, Field 6
// v18 = add v16, v17
// store v18 at v11
// v5 = load v1
// store Field 5 at v1
// v6 = not v0
// store v5 at v1
// enable_side_effects v6
// v8 = load v1
// store Field 6 at v1
// enable_side_effects u1 1
// v9 = cast v0 as Field
// v10 = cast v6 as Field
// v11 = mul v9, Field 5
// v12 = mul v10, Field 6
// v13 = add v11, v12
// store v13 at v1
// return
// }
let ssa = ssa.flatten_cfg();
let main = ssa.main();
println!("{ssa}");
assert_eq!(main.reachable_blocks().len(), 1);

let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. }));
Expand Down Expand Up @@ -1333,4 +1347,125 @@ mod test {
}
assert_eq!(constrain_count, 1);
}

#[test]
fn undo_stores() {
// Regression test for #1826. Ensures the `else` branch does not see the stores of the
// `then` branch.
//
// fn main f1 {
// b0():
// v0 = allocate
// store Field 0 at v0
// v2 = allocate
// store Field 2 at v2
// v4 = load v2
// v5 = lt v4, Field 2
// jmpif v5 then: b1, else: b2
// b1():
// v24 = load v0
// v25 = load v2
// v26 = mul v25, Field 10
// v27 = add v24, v26
// store v27 at v0
// v28 = load v2
// v29 = add v28, Field 1
// store v29 at v2
// jmp b5()
// b5():
// v14 = load v0
// return v14
// b2():
// v6 = load v2
// v8 = lt v6, Field 4
// jmpif v8 then: b3, else: b4
// b3():
// v16 = load v0
// v17 = load v2
// v19 = mul v17, Field 100
// v20 = add v16, v19
// store v20 at v0
// v21 = load v2
// v23 = add v21, Field 1
// store v23 at v2
// jmp b4()
// b4():
// jmp b5()
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);

let b1 = builder.insert_block();
let b2 = builder.insert_block();
let b3 = builder.insert_block();
let b4 = builder.insert_block();
let b5 = builder.insert_block();

let zero = builder.field_constant(0u128);
let one = builder.field_constant(1u128);
let two = builder.field_constant(2u128);
let four = builder.field_constant(4u128);
let ten = builder.field_constant(10u128);
let one_hundred = builder.field_constant(100u128);

let v0 = builder.insert_allocate();
builder.insert_store(v0, zero);
let v2 = builder.insert_allocate();
builder.insert_store(v2, two);
let v4 = builder.insert_load(v2, Type::field());
let v5 = builder.insert_binary(v4, BinaryOp::Lt, two);
builder.terminate_with_jmpif(v5, b1, b2);

builder.switch_to_block(b1);
let v24 = builder.insert_load(v0, Type::field());
let v25 = builder.insert_load(v2, Type::field());
let v26 = builder.insert_binary(v25, BinaryOp::Mul, ten);
let v27 = builder.insert_binary(v24, BinaryOp::Add, v26);
builder.insert_store(v0, v27);
let v28 = builder.insert_load(v2, Type::field());
let v29 = builder.insert_binary(v28, BinaryOp::Add, one);
builder.insert_store(v2, v29);
builder.terminate_with_jmp(b5, vec![]);

builder.switch_to_block(b5);
let v14 = builder.insert_load(v0, Type::field());
builder.terminate_with_return(vec![v14]);

builder.switch_to_block(b2);
let v6 = builder.insert_load(v2, Type::field());
let v8 = builder.insert_binary(v6, BinaryOp::Lt, four);
builder.terminate_with_jmpif(v8, b3, b4);

builder.switch_to_block(b3);
let v16 = builder.insert_load(v0, Type::field());
let v17 = builder.insert_load(v2, Type::field());
let v19 = builder.insert_binary(v17, BinaryOp::Mul, one_hundred);
let v20 = builder.insert_binary(v16, BinaryOp::Add, v19);
builder.insert_store(v0, v20);
let v21 = builder.insert_load(v2, Type::field());
let v23 = builder.insert_binary(v21, BinaryOp::Add, one);
builder.insert_store(v2, v23);
builder.terminate_with_jmp(b4, vec![]);

builder.switch_to_block(b4);
builder.terminate_with_jmp(b5, vec![]);

let ssa = builder.finish().flatten_cfg().mem2reg().fold_constants();

let main = ssa.main();

// The return value should be 200, not 310
match main.dfg[main.entry_block()].terminator() {
Some(TerminatorInstruction::Return { return_values }) => {
match main.dfg.get_numeric_constant(return_values[0]) {
Some(constant) => {
let value = constant.to_u128();
assert_eq!(value, 200);
}
None => unreachable!("Expected constant 200 for return value"),
}
}
_ => unreachable!(),
}
}
}