Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Deduplicate instructions across blocks #6499

Merged
merged 5 commits into from
Nov 11, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 125 additions & 47 deletions compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//!
//! This is the only pass which removes duplicated pure [`Instruction`]s however and so is needed when
//! different blocks are merged, i.e. after the [`flatten_cfg`][super::flatten_cfg] pass.
use std::collections::HashSet;
use std::collections::{HashSet, VecDeque};

use acvm::{acir::AcirField, FieldElement};
use iter_extended::vecmap;
Expand All @@ -28,6 +28,7 @@ use crate::ssa::{
ir::{
basic_block::BasicBlockId,
dfg::{DataFlowGraph, InsertInstructionResult},
dom::DominatorTree,
function::Function,
instruction::{Instruction, InstructionId},
types::Type,
Expand Down Expand Up @@ -67,10 +68,10 @@ impl Function {
/// The structure of this pass is simple:
/// Go through each block and re-insert all instructions.
pub(crate) fn constant_fold(&mut self, use_constraint_info: bool) {
let mut context = Context { use_constraint_info, ..Default::default() };
context.block_queue.push(self.entry_block());
let mut context = Context::new(self, use_constraint_info);
context.block_queue.push_back(self.entry_block());

while let Some(block) = context.block_queue.pop() {
while let Some(block) = context.block_queue.pop_front() {
if context.visited_blocks.contains(&block) {
continue;
}
Expand All @@ -81,34 +82,54 @@ impl Function {
}
}

#[derive(Default)]
struct Context {
use_constraint_info: bool,
/// Maps pre-folded ValueIds to the new ValueIds obtained by re-inserting the instruction.
visited_blocks: HashSet<BasicBlockId>,
block_queue: Vec<BasicBlockId>,
block_queue: VecDeque<BasicBlockId>,

// Contains sets of values which are constrained to be equivalent to each other.
//
// The mapping's structure is `side_effects_enabled_var => (constrained_value => simplified_value)`.
//
// We partition the maps of constrained values according to the side-effects flag at the point
// at which the values are constrained. This prevents constraints which are only sometimes enforced
// being used to modify the rest of the program.
constraint_simplification_mappings: HashMap<ValueId, HashMap<ValueId, ValueId>>,

// Cache of instructions without any side-effects along with their outputs.
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
cached_instruction_results: InstructionResultCache,

dom: DominatorTree,
}

/// HashMap from (Instruction, side_effects_enabled_var) to the results of the instruction.
/// Stored as a two-level map to avoid cloning Instructions during the `.get` call.
type InstructionResultCache = HashMap<Instruction, HashMap<Option<ValueId>, Vec<ValueId>>>;
///
/// In addition to each result, the original BasicBlockId is stored as well. This allows us
/// to deduplicate instructions across blocks as long as the new block dominates the original.
type InstructionResultCache = HashMap<Instruction, HashMap<Option<ValueId>, ResultCache>>;

#[derive(Default)]
struct ResultCache {
results: Vec<(BasicBlockId, Vec<ValueId>)>,
}

impl Context {
fn new(function: &Function, use_constraint_info: bool) -> Self {
Self {
use_constraint_info,
visited_blocks: Default::default(),
block_queue: Default::default(),
constraint_simplification_mappings: Default::default(),
cached_instruction_results: Default::default(),
dom: DominatorTree::with_function(function),
}
}

fn fold_constants_in_block(&mut self, function: &mut Function, block: BasicBlockId) {
let instructions = function.dfg[block].take_instructions();

// Cache of instructions without any side-effects along with their outputs.
let mut cached_instruction_results = HashMap::default();

// Contains sets of values which are constrained to be equivalent to each other.
//
// The mapping's structure is `side_effects_enabled_var => (constrained_value => simplified_value)`.
//
// We partition the maps of constrained values according to the side-effects flag at the point
// at which the values are constrained. This prevents constraints which are only sometimes enforced
// being used to modify the rest of the program.
let mut constraint_simplification_mappings: HashMap<ValueId, HashMap<ValueId, ValueId>> =
HashMap::default();
let mut side_effects_enabled_var =
function.dfg.make_constant(FieldElement::one(), Type::bool());

Expand All @@ -117,31 +138,26 @@ impl Context {
&mut function.dfg,
block,
instruction_id,
&mut cached_instruction_results,
&mut constraint_simplification_mappings,
&mut side_effects_enabled_var,
);
}
self.block_queue.extend(function.dfg[block].successors());
}

fn fold_constants_into_instruction(
&self,
&mut self,
dfg: &mut DataFlowGraph,
block: BasicBlockId,
id: InstructionId,
instruction_result_cache: &mut InstructionResultCache,
constraint_simplification_mappings: &mut HashMap<ValueId, HashMap<ValueId, ValueId>>,
side_effects_enabled_var: &mut ValueId,
) {
let constraint_simplification_mapping =
constraint_simplification_mappings.entry(*side_effects_enabled_var).or_default();
let constraint_simplification_mapping = self.get_constraint_map(*side_effects_enabled_var);
let instruction = Self::resolve_instruction(id, dfg, constraint_simplification_mapping);
let old_results = dfg.instruction_results(id).to_vec();

// If a copy of this instruction exists earlier in the block, then reuse the previous results.
if let Some(cached_results) =
Self::get_cached(dfg, instruction_result_cache, &instruction, *side_effects_enabled_var)
self.get_cached(dfg, &instruction, *side_effects_enabled_var, block)
{
Self::replace_result_ids(dfg, &old_results, cached_results);
return;
Expand All @@ -156,9 +172,8 @@ impl Context {
instruction.clone(),
new_results,
dfg,
instruction_result_cache,
constraint_simplification_mapping,
*side_effects_enabled_var,
block,
);

// If we just inserted an `Instruction::EnableSideEffectsIf`, we need to update `side_effects_enabled_var`
Expand Down Expand Up @@ -229,13 +244,12 @@ impl Context {
}

fn cache_instruction(
&self,
&mut self,
instruction: Instruction,
instruction_results: Vec<ValueId>,
dfg: &DataFlowGraph,
instruction_result_cache: &mut InstructionResultCache,
constraint_simplification_mapping: &mut HashMap<ValueId, ValueId>,
side_effects_enabled_var: ValueId,
block: BasicBlockId,
) {
if self.use_constraint_info {
// If the instruction was a constraint, then create a link between the two `ValueId`s
Expand All @@ -248,18 +262,18 @@ impl Context {

// Prefer replacing with constants where possible.
(Value::NumericConstant { .. }, _) => {
constraint_simplification_mapping.insert(rhs, lhs);
self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs);
}
(_, Value::NumericConstant { .. }) => {
constraint_simplification_mapping.insert(lhs, rhs);
self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs);
}
// Otherwise prefer block parameters over instruction results.
// This is as block parameters are more likely to be a single witness rather than a full expression.
(Value::Param { .. }, Value::Instruction { .. }) => {
constraint_simplification_mapping.insert(rhs, lhs);
self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs);
}
(Value::Instruction { .. }, Value::Param { .. }) => {
constraint_simplification_mapping.insert(lhs, rhs);
self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs);
}
(_, _) => (),
}
Expand All @@ -273,13 +287,22 @@ impl Context {
self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg);
let predicate = use_predicate.then_some(side_effects_enabled_var);

instruction_result_cache
self.cached_instruction_results
.entry(instruction)
.or_default()
.insert(predicate, instruction_results);
.entry(predicate)
.or_default()
.cache(block, instruction_results);
}
}

fn get_constraint_map(
&mut self,
side_effects_enabled_var: ValueId,
) -> &mut HashMap<ValueId, ValueId> {
self.constraint_simplification_mappings.entry(side_effects_enabled_var).or_default()
}

/// Replaces a set of [`ValueId`]s inside the [`DataFlowGraph`] with another.
fn replace_result_ids(
dfg: &mut DataFlowGraph,
Expand All @@ -292,22 +315,33 @@ impl Context {
}

fn get_cached<'a>(
&'a mut self,
dfg: &DataFlowGraph,
instruction_result_cache: &'a mut InstructionResultCache,
instruction: &Instruction,
side_effects_enabled_var: ValueId,
block: BasicBlockId,
) -> Option<&'a Vec<ValueId>> {
let results_for_instruction = instruction_result_cache.get(instruction);
let results_for_instruction = self.cached_instruction_results.get(instruction)?;

// See if there's a cached version with no predicate first
if let Some(results) = results_for_instruction.and_then(|map| map.get(&None)) {
return Some(results);
}
let predicate = self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg);
let predicate = predicate.then_some(side_effects_enabled_var);

results_for_instruction.get(&predicate)?.get(block, &mut self.dom)
}
}

let predicate =
instruction.requires_acir_gen_predicate(dfg).then_some(side_effects_enabled_var);
impl ResultCache {
fn cache(&mut self, block: BasicBlockId, results: Vec<ValueId>) {
self.results.push((block, results));
}

results_for_instruction.and_then(|map| map.get(&predicate))
fn get(&self, block: BasicBlockId, dom: &mut DominatorTree) -> Option<&Vec<ValueId>> {
for (origin_block, results) in &self.results {
if dom.dominates(*origin_block, block) {
return Some(results);
}
}
None
}
}

Expand Down Expand Up @@ -896,4 +930,48 @@ mod test {
let ending_instruction_count = instructions.len();
assert_eq!(ending_instruction_count, 1);
}

#[test]
fn deduplicate_across_blocks() {
// fn main f0 {
// b0(v0: u1):
// v1 = not v0
// jmp b1()
// b1():
// v2 = not v0
// return v2
// }
let main_id = Id::test_new(0);

// Compiling main
let mut builder = FunctionBuilder::new("main".into(), main_id);
let b1 = builder.insert_block();

let v0 = builder.add_parameter(Type::bool());
let _v1 = builder.insert_not(v0);
builder.terminate_with_jmp(b1, Vec::new());

builder.switch_to_block(b1);
let v2 = builder.insert_not(v0);
builder.terminate_with_return(vec![v2]);

let ssa = builder.finish();
let main = ssa.main();
assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1);
assert_eq!(main.dfg[b1].instructions().len(), 1);

// Expected output:
//
// fn main f0 {
// b0(v0: u1):
// v1 = not v0
// jmp b1()
// b1():
// return v1
// }
let ssa = ssa.fold_constants_using_constraints();
let main = ssa.main();
assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1);
assert_eq!(main.dfg[b1].instructions().len(), 0);
}
}
Loading