Skip to content

Commit

Permalink
Merge eb34ceb into 6acef6d
Browse files Browse the repository at this point in the history
  • Loading branch information
vezenovm authored Dec 2, 2024
2 parents 6acef6d + eb34ceb commit 2a2472f
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 10 deletions.
201 changes: 193 additions & 8 deletions compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
//! - Already marked as loop invariants
//!
//! We also check that we are not hoisting instructions with side effects.
use fxhash::FxHashSet as HashSet;
use acvm::{acir::AcirField, FieldElement};
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};

use crate::ssa::{
ir::{
basic_block::BasicBlockId,
function::{Function, RuntimeType},
function_inserter::FunctionInserter,
instruction::InstructionId,
instruction::{Instruction, InstructionId},
types::Type,
value::ValueId,
},
Ssa,
Expand Down Expand Up @@ -45,25 +47,51 @@ impl Function {
}

impl Loops {
fn hoist_loop_invariants(self, function: &mut Function) {
fn hoist_loop_invariants(mut self, function: &mut Function) {
let mut context = LoopInvariantContext::new(function);

for loop_ in self.yet_to_unroll.iter() {
// The loops should be sorted by the number of blocks.
// We want to access outer nested loops first, which we do by popping
// from the top of the list.
while let Some(loop_) = self.yet_to_unroll.pop() {
let Ok(pre_header) = loop_.get_pre_header(context.inserter.function, &self.cfg) else {
// If the loop does not have a preheader we skip hoisting loop invariants for this loop
continue;
};
context.hoist_loop_invariants(loop_, pre_header);

context.hoist_loop_invariants(&loop_, pre_header);
}

context.map_dependent_instructions();
}
}

impl Loop {
/// Find the value that controls whether to perform a loop iteration.
/// This is going to be the block parameter of the loop header.
///
/// Consider the following example of a `for i in 0..4` loop:
/// ```text
/// brillig(inline) fn main f0 {
/// b0(v0: u32):
/// ...
/// jmp b1(u32 0)
/// b1(v1: u32): // Loop header
/// v5 = lt v1, u32 4 // Upper bound
/// jmpif v5 then: b3, else: b2
/// ```
/// In the example above, `v1` is the induction variable
fn get_induction_variable(&self, function: &Function) -> ValueId {
function.dfg.block_parameters(self.header)[0]
}
}

struct LoopInvariantContext<'f> {
inserter: FunctionInserter<'f>,
defined_in_loop: HashSet<ValueId>,
loop_invariants: HashSet<ValueId>,
// Maps induction variable -> fixed upper loop bound
outer_induction_variables: HashMap<ValueId, FieldElement>,
}

impl<'f> LoopInvariantContext<'f> {
Expand All @@ -72,6 +100,7 @@ impl<'f> LoopInvariantContext<'f> {
inserter: FunctionInserter::new(function),
defined_in_loop: HashSet::default(),
loop_invariants: HashSet::default(),
outer_induction_variables: HashMap::default(),
}
}

Expand All @@ -88,13 +117,29 @@ impl<'f> LoopInvariantContext<'f> {
self.inserter.push_instruction(instruction_id, *block);
}

self.update_values_defined_in_loop_and_invariants(instruction_id, hoist_invariant);
self.extend_values_defined_in_loop_and_invariants(instruction_id, hoist_invariant);
}
}

// Keep track of a loop induction variable and respective upper bound.
// This will be used by later loops to determine whether they have operations
// reliant upon the maximum induction variable.
let upper_bound = loop_.get_const_upper_bound(self.inserter.function);
if let Some(upper_bound) = upper_bound {
let induction_variable = loop_.get_induction_variable(self.inserter.function);
let induction_variable = self.inserter.resolve(induction_variable);
self.outer_induction_variables.insert(induction_variable, upper_bound);
}
}

/// Gather the variables declared within the loop
fn set_values_defined_in_loop(&mut self, loop_: &Loop) {
// Clear any values that may be defined in previous loops, as the context is per function.
self.defined_in_loop.clear();
// These are safe to keep per function, but we want to be clear that these values
// are used per loop.
self.loop_invariants.clear();

for block in loop_.blocks.iter() {
let params = self.inserter.function.dfg.block_parameters(*block);
self.defined_in_loop.extend(params);
Expand All @@ -107,7 +152,7 @@ impl<'f> LoopInvariantContext<'f> {

/// Update any values defined in the loop and loop invariants after a
/// analyzing and re-inserting a loop's instruction.
fn update_values_defined_in_loop_and_invariants(
fn extend_values_defined_in_loop_and_invariants(
&mut self,
instruction_id: InstructionId,
hoist_invariant: bool,
Expand Down Expand Up @@ -143,9 +188,45 @@ impl<'f> LoopInvariantContext<'f> {
is_loop_invariant &=
!self.defined_in_loop.contains(&value) || self.loop_invariants.contains(&value);
});
is_loop_invariant && instruction.can_be_deduplicated(&self.inserter.function.dfg, false)

let can_be_deduplicated = instruction
.can_be_deduplicated(&self.inserter.function.dfg, false)
|| self.can_be_deduplicated_from_upper_bound(&instruction);

is_loop_invariant && can_be_deduplicated
}

/// Certain instructions can take advantage of that our induction variable has a fixed maximum.
///
/// For example, an array access can usually only be safely deduplicated when we have a constant
/// index that is below the length of the array.
/// Checking an array get where the index is the loop's induction variable on its own
/// would determine that the instruction is not safe for hoisting.
/// However, if we know that the induction variable's upper bound will always be in bounds of the array
/// we can safely hoist the array access.
fn can_be_deduplicated_from_upper_bound(&self, instruction: &Instruction) -> bool {
match instruction {
Instruction::ArrayGet { array, index } => {
let array_typ = self.inserter.function.dfg.type_of_value(*array);
let upper_bound = self.outer_induction_variables.get(index);
if let (Type::Array(_, len), Some(upper_bound)) = (array_typ, upper_bound) {
upper_bound.to_u128() as usize <= len
} else {
false
}
}
_ => false,
}
}

/// Loop invariant hoisting only operates over loop instructions.
/// The `FunctionInserter` is used for mapping old values to new values after
/// re-inserting loop invariant instructions.
/// However, there may be instructions which are not within loops that are
/// still reliant upon the instruction results altered during the pass.
/// This method re-inserts all instructions so that all instructions have
/// correct new value IDs based upon the `FunctionInserter` internal map.
/// Leaving out this mapping could lead to instructions with values that do not exist.
fn map_dependent_instructions(&mut self) {
let blocks = self.inserter.function.reachable_blocks();
for block in blocks {
Expand Down Expand Up @@ -375,4 +456,108 @@ mod test {
// The code should be unchanged
assert_normalized_ssa_equals(ssa, src);
}

#[test]
fn hoist_array_gets_using_induction_variable_with_const_bound() {
// SSA for the following program:
//
// fn triple_loop(x: u32) {
// let arr = [2; 5];
// for i in 0..4 {
// for j in 0..4 {
// for _ in 0..4 {
// assert_eq(arr[i], x);
// assert_eq(arr[j], x);
// }
// }
// }
// }
//
// `arr[i]` and `arr[j]` are safe to hoist as we know the maximum possible index
// to be used for both array accesses.
// We want to make sure `arr[i]` is hoisted to the outermost loop body and that
// `arr[j]` is hoisted to the second outermost loop body.
let src = "
brillig(inline) fn main f0 {
b0(v0: u32, v1: u32):
v6 = make_array [u32 2, u32 2, u32 2, u32 2, u32 2] : [u32; 5]
inc_rc v6
jmp b1(u32 0)
b1(v2: u32):
v9 = lt v2, u32 4
jmpif v9 then: b3, else: b2
b3():
jmp b4(u32 0)
b4(v3: u32):
v10 = lt v3, u32 4
jmpif v10 then: b6, else: b5
b6():
jmp b7(u32 0)
b7(v4: u32):
v13 = lt v4, u32 4
jmpif v13 then: b9, else: b8
b9():
v15 = array_get v6, index v2 -> u32
v16 = eq v15, v0
constrain v15 == v0
v17 = array_get v6, index v3 -> u32
v18 = eq v17, v0
constrain v17 == v0
v19 = add v4, u32 1
jmp b7(v19)
b8():
v14 = add v3, u32 1
jmp b4(v14)
b5():
v12 = add v2, u32 1
jmp b1(v12)
b2():
return
}
";

let ssa = Ssa::from_str(src).unwrap();

let expected = "
brillig(inline) fn main f0 {
b0(v0: u32, v1: u32):
v6 = make_array [u32 2, u32 2, u32 2, u32 2, u32 2] : [u32; 5]
inc_rc v6
jmp b1(u32 0)
b1(v2: u32):
v9 = lt v2, u32 4
jmpif v9 then: b3, else: b2
b3():
v10 = array_get v6, index v2 -> u32
v11 = eq v10, v0
jmp b4(u32 0)
b4(v3: u32):
v12 = lt v3, u32 4
jmpif v12 then: b6, else: b5
b6():
v15 = array_get v6, index v3 -> u32
v16 = eq v15, v0
jmp b7(u32 0)
b7(v4: u32):
v17 = lt v4, u32 4
jmpif v17 then: b9, else: b8
b9():
constrain v10 == v0
constrain v15 == v0
v19 = add v4, u32 1
jmp b7(v19)
b8():
v18 = add v3, u32 1
jmp b4(v18)
b5():
v14 = add v2, u32 1
jmp b1(v14)
b2():
return
}
";

let ssa = ssa.loop_invariant_code_motion();
assert_normalized_ssa_equals(ssa, expected);
}
}
4 changes: 2 additions & 2 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl Function {
pub(super) struct Loop {
/// The header block of a loop is the block which dominates all the
/// other blocks in the loop.
header: BasicBlockId,
pub(super) header: BasicBlockId,

/// The start of the back_edge n -> d is the block n at the end of
/// the loop that jumps back to the header block d which restarts the loop.
Expand Down Expand Up @@ -299,7 +299,7 @@ impl Loop {
/// v5 = lt v1, u32 4 // Upper bound
/// jmpif v5 then: b3, else: b2
/// ```
fn get_const_upper_bound(&self, function: &Function) -> Option<FieldElement> {
pub(super) fn get_const_upper_bound(&self, function: &Function) -> Option<FieldElement> {
let block = &function.dfg[self.header];
let instructions = block.instructions();
assert_eq!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// to be hoisted to the loop's pre-header block.
fn main(x: u32, y: u32) {
loop(4, x, y);
array_read_loop(4, x);
}

fn loop(upper_bound: u32, x: u32, y: u32) {
Expand All @@ -11,3 +12,15 @@ fn loop(upper_bound: u32, x: u32, y: u32) {
assert_eq(z, 12);
}
}

fn array_read_loop(upper_bound: u32, x: u32) {
let arr = [2; 5];
for i in 0..upper_bound {
for j in 0..upper_bound {
for _ in 0..upper_bound {
assert_eq(arr[i], x);
assert_eq(arr[j], x);
}
}
}
}

0 comments on commit 2a2472f

Please sign in to comment.