Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssa): Check if result of array set is used in value of another array set #6197

Merged
merged 7 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
};
use acvm::acir::{brillig::MemoryAddress, AcirField};

pub(crate) const MAX_STACK_SIZE: usize = 2048;
pub(crate) const MAX_STACK_SIZE: usize = 32768;
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
pub(crate) const MAX_SCRATCH_SPACE: usize = 64;

impl<F: AcirField + DebugToString> BrilligContext<F, Stack> {
Expand Down Expand Up @@ -158,7 +158,7 @@
}

for (i, bit_size) in arguments.iter().flat_map(flat_bit_sizes).enumerate() {
// Calldatacopy tags everything with field type, so when downcast when necessary

Check warning on line 161 in compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Calldatacopy)
if bit_size < F::max_num_bits() {
self.cast_instruction(
SingleAddrVariable::new(
Expand Down
93 changes: 59 additions & 34 deletions compiler/noirc_evaluator/src/ssa/opt/array_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::ssa::{
ir::{
basic_block::BasicBlockId,
dfg::DataFlowGraph,
function::Function,
function::{Function, RuntimeType},
instruction::{Instruction, InstructionId, TerminatorInstruction},
types::Type::{Array, Slice},
value::ValueId,
Expand Down Expand Up @@ -34,14 +34,17 @@ impl Function {
let mut array_to_last_use = HashMap::default();
let mut instructions_to_update = HashSet::default();
let mut arrays_from_load = HashSet::default();
let mut inner_nested_arrays = HashMap::default();

for block in reachable_blocks.iter() {
analyze_last_uses(
&self.dfg,
*block,
matches!(self.runtime(), RuntimeType::Brillig),
&mut array_to_last_use,
&mut instructions_to_update,
&mut arrays_from_load,
&mut inner_nested_arrays,
);
}
for block in reachable_blocks {
Expand All @@ -55,9 +58,11 @@ impl Function {
fn analyze_last_uses(
dfg: &DataFlowGraph,
block_id: BasicBlockId,
is_brillig_func: bool,
array_to_last_use: &mut HashMap<ValueId, InstructionId>,
instructions_that_can_be_made_mutable: &mut HashSet<InstructionId>,
arrays_from_load: &mut HashSet<ValueId>,
inner_nested_arrays: &mut HashMap<ValueId, InstructionId>,
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
) {
let block = &dfg[block_id];

Expand All @@ -70,12 +75,22 @@ fn analyze_last_uses(
instructions_that_can_be_made_mutable.remove(&existing);
}
}
Instruction::ArraySet { array, .. } => {
Instruction::ArraySet { array, value, .. } => {
let array = dfg.resolve(*array);

if let Some(existing) = array_to_last_use.insert(array, *instruction_id) {
instructions_that_can_be_made_mutable.remove(&existing);
}
if is_brillig_func {
let value = dfg.resolve(*value);

if let Some(existing) = inner_nested_arrays.get(&value) {
instructions_that_can_be_made_mutable.remove(existing);
}
let result = dfg.instruction_results(*instruction_id)[0];
inner_nested_arrays.insert(result, *instruction_id);
}

// If the array we are setting does not come from a load we can safely mark it mutable.
// If the array comes from a load we may potentially being mutating an array at a reference
// that is loaded from by other values.
Expand Down Expand Up @@ -169,29 +184,31 @@ mod tests {
// from and cloned in a loop. If the array is inadvertently marked mutable, and is cloned in a previous iteration
// of the loop, its clone will also be altered.
//
// acir(inline) fn main f0 {
// brillig fn main f0 {
// b0():
// v2 = allocate
// store [Field 0, Field 0, Field 0, Field 0, Field 0] at v2
// v3 = allocate
// store [Field 0, Field 0, Field 0, Field 0, Field 0] at v3
// store [[Field 0, Field 0, Field 0, Field 0, Field 0], [Field 0, Field 0, Field 0, Field 0, Field 0]] at v3
// v4 = allocate
// store [[Field 0, Field 0, Field 0, Field 0, Field 0], [Field 0, Field 0, Field 0, Field 0, Field 0]] at v4
// jmp b1(u32 0)
// b1(v5: u32):
// v7 = lt v5, u32 5
// jmpif v7 then: b3, else: b2
// b1(v6: u32):
// v8 = lt v6, u32 5
// jmpif v8 then: b3, else: b2
// b3():
// v8 = eq v5, u32 5
// jmpif v8 then: b4, else: b5
// v9 = eq v6, u32 5
// jmpif v9 then: b4, else: b5
// b4():
// v9 = load v2
// store v9 at v3
// v10 = load v3
// store v10 at v4
// jmp b5()
// b5():
// v10 = load v2
// v12 = array_set v10, index v5, value Field 20
// store v12 at v2
// v14 = add v5, u32 1
// jmp b1(v14)
// v11 = load v3
// v13 = array_get v11, index Field 0
// v14 = array_set v13, index v6, value Field 20
// v15 = array_set v11, index v6, value v14
// store v15 at v3
// v17 = add v6, u32 1
// jmp b1(v17)
// b2():
// return
// }
Expand All @@ -203,13 +220,16 @@ mod tests {
let zero = builder.field_constant(0u128);
let array_constant =
builder.array_constant(vector![zero, zero, zero, zero, zero], array_type.clone());
let nested_array_type = Type::Array(Arc::new(vec![array_type.clone()]), 2);
let nested_array_constant = builder
.array_constant(vector![array_constant, array_constant], nested_array_type.clone());

let v2 = builder.insert_allocate(array_type.clone());
let v3 = builder.insert_allocate(array_type.clone());

builder.insert_store(v2, array_constant);
builder.insert_store(v3, nested_array_constant);

let v3 = builder.insert_allocate(array_type.clone());
builder.insert_store(v3, array_constant);
let v4 = builder.insert_allocate(array_type.clone());
builder.insert_store(v4, nested_array_constant);

let b1 = builder.insert_block();
let zero_u32 = builder.numeric_constant(0u128, Type::unsigned(32));
Expand All @@ -219,42 +239,47 @@ mod tests {
builder.switch_to_block(b1);
let v5 = builder.add_block_parameter(b1, Type::unsigned(32));
let five = builder.numeric_constant(5u128, Type::unsigned(32));
let v7 = builder.insert_binary(v5, BinaryOp::Lt, five);
let v8 = builder.insert_binary(v5, BinaryOp::Lt, five);

let b2 = builder.insert_block();
let b3 = builder.insert_block();
let b4 = builder.insert_block();
let b5 = builder.insert_block();
builder.terminate_with_jmpif(v7, b3, b2);
builder.terminate_with_jmpif(v8, b3, b2);

// Loop body
// b3 is the if statement conditional
builder.switch_to_block(b3);
let two = builder.numeric_constant(5u128, Type::unsigned(32));
let v8 = builder.insert_binary(v5, BinaryOp::Eq, two);
builder.terminate_with_jmpif(v8, b4, b5);
let v9 = builder.insert_binary(v5, BinaryOp::Eq, two);
builder.terminate_with_jmpif(v9, b4, b5);

// b4 is the rest of the loop after the if statement
builder.switch_to_block(b4);
let v9 = builder.insert_load(v2, array_type.clone());
builder.insert_store(v3, v9);
let v10 = builder.insert_load(v3, nested_array_type.clone());
builder.insert_store(v4, v10);
builder.terminate_with_jmp(b5, vec![]);

builder.switch_to_block(b5);
let v10 = builder.insert_load(v2, array_type.clone());
let v11 = builder.insert_load(v3, nested_array_type.clone());
let twenty = builder.field_constant(20u128);
let v12 = builder.insert_array_set(v10, v5, twenty);
builder.insert_store(v2, v12);
let v13 = builder.insert_array_get(v11, zero, array_type.clone());
let v14 = builder.insert_array_set(v13, v5, twenty);
let v15 = builder.insert_array_set(v11, v5, v14);

builder.insert_store(v3, v15);
let one = builder.numeric_constant(1u128, Type::unsigned(32));
let v14 = builder.insert_binary(v5, BinaryOp::Add, one);
builder.terminate_with_jmp(b1, vec![v14]);
let v17 = builder.insert_binary(v5, BinaryOp::Add, one);
builder.terminate_with_jmp(b1, vec![v17]);

builder.switch_to_block(b2);
builder.terminate_with_return(vec![]);

let ssa = builder.finish();
println!("{}", ssa);
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
// We expect the same result as above
let ssa = ssa.array_set_optimization();
println!("{}", ssa);
vezenovm marked this conversation as resolved.
Show resolved Hide resolved

let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 6);
Expand All @@ -265,7 +290,7 @@ mod tests {
.filter(|instruction| matches!(&main.dfg[**instruction], Instruction::ArraySet { .. }))
.collect::<Vec<_>>();

assert_eq!(array_set_instructions.len(), 1);
assert_eq!(array_set_instructions.len(), 2);
if let Instruction::ArraySet { mutable, .. } = &main.dfg[*array_set_instructions[0]] {
// The single array set should not be marked mutable
assert!(!mutable);
Expand Down
Loading