Skip to content

Commit

Permalink
feat(perf): mem2reg function state for value loads to optimize across…
Browse files Browse the repository at this point in the history
… blocks (#5757)

# Description

## Problem\*

Resolves <!-- Link to GitHub Issue -->

Part of a general effort to reduce Brillig bytecode sizes. There is no
linked issue, as this was found while working on other Brillig optimizations.

## Summary\*

Now that we have a brillig gates diff I want to pull out individual
changes to live on their own to see only the individual impact across
all of our brillig tests.

This PR simply records the last load performed on each address across
all blocks in a function. If an address is never loaded from, but some
blocks still contain store instructions to that address, we can safely
remove those store instructions.

## Additional Context



## Documentation\*

Check one:
- [X] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [X] I have tested the changes locally.
- [X] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <[email protected]>
Co-authored-by: jfecher <[email protected]>
  • Loading branch information
3 people authored Aug 19, 2024
1 parent f44e0b3 commit 0b297b3
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ mod block;

use std::collections::{BTreeMap, BTreeSet};

use fxhash::FxHashMap as HashMap;

use crate::ssa::{
ir::{
basic_block::BasicBlockId,
Expand Down Expand Up @@ -111,6 +113,10 @@ struct PerFunctionContext<'f> {
/// We avoid removing individual instructions as we go since removing elements
/// from the middle of Vecs many times will be slower than a single call to `retain`.
instructions_to_remove: BTreeSet<InstructionId>,

/// Track a value's last load across all blocks.
/// If a value is not used in any more loads, we can remove the last store to that value.
last_loads: HashMap<ValueId, InstructionId>,
}

impl<'f> PerFunctionContext<'f> {
Expand All @@ -124,6 +130,7 @@ impl<'f> PerFunctionContext<'f> {
inserter: FunctionInserter::new(function),
blocks: BTreeMap::new(),
instructions_to_remove: BTreeSet::new(),
last_loads: HashMap::default(),
}
}

Expand All @@ -140,6 +147,18 @@ impl<'f> PerFunctionContext<'f> {
let references = self.find_starting_references(block);
self.analyze_block(block, references);
}

// If we never load from an address within a function, we can remove all stores to that address.
// This rule does not apply to reference parameters, which we must check for before removing these stores.
for (block_id, block) in self.blocks.iter() {
let block_params = self.inserter.function.dfg.block_parameters(*block_id);
for (value, store_instruction) in block.last_stores.iter() {
let is_reference_param = block_params.contains(value);
if self.last_loads.get(value).is_none() && !is_reference_param {
self.instructions_to_remove.insert(*store_instruction);
}
}
}
}

/// The value of each reference at the start of the given block is the unification
Expand Down Expand Up @@ -239,6 +258,8 @@ impl<'f> PerFunctionContext<'f> {
self.instructions_to_remove.insert(instruction);
} else {
references.mark_value_used(address, self.inserter.function);

self.last_loads.insert(address, instruction);
}
}
Instruction::Store { address, value } => {
Expand Down Expand Up @@ -594,10 +615,8 @@ mod tests {
// acir fn main f0 {
// b0():
// v7 = allocate
// store Field 5 at v7
// jmp b1(Field 5)
// b1(v3: Field):
// store Field 6 at v7
// return v3, Field 5, Field 6
// }
let ssa = ssa.mem2reg();
Expand All @@ -609,9 +628,9 @@ mod tests {
assert_eq!(count_loads(main.entry_block(), &main.dfg), 0);
assert_eq!(count_loads(b1, &main.dfg), 0);

// Neither store is removed since they are each the last in the block and there are multiple blocks
assert_eq!(count_stores(main.entry_block(), &main.dfg), 1);
assert_eq!(count_stores(b1, &main.dfg), 1);
// All stores are removed as there are no loads to the values being stored anywhere in the function.
assert_eq!(count_stores(main.entry_block(), &main.dfg), 0);
assert_eq!(count_stores(b1, &main.dfg), 0);

// The jmp to b1 should also be a constant 5 now
match main.dfg[main.entry_block()].terminator() {
Expand Down Expand Up @@ -641,8 +660,8 @@ mod tests {
// b1():
// store Field 1 at v3
// store Field 2 at v4
// v8 = load v3
// v9 = eq v8, Field 2
// v7 = load v3
// v8 = eq v7, Field 2
// return
// }
let main_id = Id::test_new(0);
Expand Down Expand Up @@ -681,12 +700,9 @@ mod tests {
// acir fn main f0 {
// b0():
// v9 = allocate
// store Field 0 at v9
// v10 = allocate
// store v9 at v10
// jmp b1()
// b1():
// store Field 2 at v9
// return
// }
let ssa = ssa.mem2reg();
Expand All @@ -698,14 +714,17 @@ mod tests {
assert_eq!(count_loads(main.entry_block(), &main.dfg), 0);
assert_eq!(count_loads(b1, &main.dfg), 0);

// Only the first store in b1 is removed since there is another store to the same reference
// All stores should be removed.
// The first store in b1 is removed since there is another store to the same reference
// in the same block, and the store is not needed before the later store.
assert_eq!(count_stores(main.entry_block(), &main.dfg), 2);
assert_eq!(count_stores(b1, &main.dfg), 1);
// The rest of the stores are also removed as no loads are done within any blocks
// to the stored values.
assert_eq!(count_stores(main.entry_block(), &main.dfg), 0);
assert_eq!(count_stores(b1, &main.dfg), 0);

let b1_instructions = main.dfg[b1].instructions();

// We expect the last eq to be optimized out
assert_eq!(b1_instructions.len(), 1);
assert_eq!(b1_instructions.len(), 0);
}
}

0 comments on commit 0b297b3

Please sign in to comment.