Skip to content

Commit

Permalink
Merge 666c1bc into 046fec7
Browse files Browse the repository at this point in the history
  • Loading branch information
jfecher authored Nov 11, 2024
2 parents 046fec7 + 666c1bc commit 6a441fd
Showing 1 changed file with 136 additions and 48 deletions.
184 changes: 136 additions & 48 deletions compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//!
//! This is the only pass which removes duplicated pure [`Instruction`]s however and so is needed when
//! different blocks are merged, i.e. after the [`flatten_cfg`][super::flatten_cfg] pass.
use std::collections::HashSet;
use std::collections::{HashSet, VecDeque};

use acvm::{acir::AcirField, FieldElement};
use iter_extended::vecmap;
Expand All @@ -28,6 +28,7 @@ use crate::ssa::{
ir::{
basic_block::BasicBlockId,
dfg::{DataFlowGraph, InsertInstructionResult},
dom::DominatorTree,
function::Function,
instruction::{Instruction, InstructionId},
types::Type,
Expand Down Expand Up @@ -67,10 +68,10 @@ impl Function {
/// The structure of this pass is simple:
/// Go through each block and re-insert all instructions.
pub(crate) fn constant_fold(&mut self, use_constraint_info: bool) {
let mut context = Context { use_constraint_info, ..Default::default() };
context.block_queue.push(self.entry_block());
let mut context = Context::new(self, use_constraint_info);
context.block_queue.push_back(self.entry_block());

while let Some(block) = context.block_queue.pop() {
while let Some(block) = context.block_queue.pop_front() {
if context.visited_blocks.contains(&block) {
continue;
}
Expand All @@ -81,34 +82,57 @@ impl Function {
}
}

#[derive(Default)]
struct Context {
use_constraint_info: bool,
/// Maps pre-folded ValueIds to the new ValueIds obtained by re-inserting the instruction.
visited_blocks: HashSet<BasicBlockId>,
block_queue: Vec<BasicBlockId>,
block_queue: VecDeque<BasicBlockId>,

/// Contains sets of values which are constrained to be equivalent to each other.
///
/// The mapping's structure is `side_effects_enabled_var => (constrained_value => simplified_value)`.
///
/// We partition the maps of constrained values according to the side-effects flag at the point
/// at which the values are constrained. This prevents constraints which are only sometimes enforced
/// being used to modify the rest of the program.
constraint_simplification_mappings: HashMap<ValueId, HashMap<ValueId, ValueId>>,

// Cache of instructions without any side-effects along with their outputs.
cached_instruction_results: InstructionResultCache,

dom: DominatorTree,
}

/// HashMap from (Instruction, side_effects_enabled_var) to the results of the instruction.
/// Stored as a two-level map to avoid cloning Instructions during the `.get` call.
type InstructionResultCache = HashMap<Instruction, HashMap<Option<ValueId>, Vec<ValueId>>>;
///
/// In addition to each result, the original BasicBlockId is stored as well. This allows us
/// to deduplicate instructions across blocks as long as the new block dominates the original.
type InstructionResultCache = HashMap<Instruction, HashMap<Option<ValueId>, ResultCache>>;

/// Records the results of all duplicate [`Instruction`]s along with the blocks in which they sit.
///
/// For more information see [`InstructionResultCache`].
#[derive(Default)]
struct ResultCache {
results: Vec<(BasicBlockId, Vec<ValueId>)>,
}

impl Context {
fn new(function: &Function, use_constraint_info: bool) -> Self {
Self {
use_constraint_info,
visited_blocks: Default::default(),
block_queue: Default::default(),
constraint_simplification_mappings: Default::default(),
cached_instruction_results: Default::default(),
dom: DominatorTree::with_function(function),
}
}

fn fold_constants_in_block(&mut self, function: &mut Function, block: BasicBlockId) {
let instructions = function.dfg[block].take_instructions();

// Cache of instructions without any side-effects along with their outputs.
let mut cached_instruction_results = HashMap::default();

// Contains sets of values which are constrained to be equivalent to each other.
//
// The mapping's structure is `side_effects_enabled_var => (constrained_value => simplified_value)`.
//
// We partition the maps of constrained values according to the side-effects flag at the point
// at which the values are constrained. This prevents constraints which are only sometimes enforced
// being used to modify the rest of the program.
let mut constraint_simplification_mappings: HashMap<ValueId, HashMap<ValueId, ValueId>> =
HashMap::default();
let mut side_effects_enabled_var =
function.dfg.make_constant(FieldElement::one(), Type::bool());

Expand All @@ -117,31 +141,26 @@ impl Context {
&mut function.dfg,
block,
instruction_id,
&mut cached_instruction_results,
&mut constraint_simplification_mappings,
&mut side_effects_enabled_var,
);
}
self.block_queue.extend(function.dfg[block].successors());
}

fn fold_constants_into_instruction(
&self,
&mut self,
dfg: &mut DataFlowGraph,
block: BasicBlockId,
id: InstructionId,
instruction_result_cache: &mut InstructionResultCache,
constraint_simplification_mappings: &mut HashMap<ValueId, HashMap<ValueId, ValueId>>,
side_effects_enabled_var: &mut ValueId,
) {
let constraint_simplification_mapping =
constraint_simplification_mappings.entry(*side_effects_enabled_var).or_default();
let constraint_simplification_mapping = self.get_constraint_map(*side_effects_enabled_var);
let instruction = Self::resolve_instruction(id, dfg, constraint_simplification_mapping);
let old_results = dfg.instruction_results(id).to_vec();

// If a copy of this instruction exists earlier in the block, then reuse the previous results.
if let Some(cached_results) =
Self::get_cached(dfg, instruction_result_cache, &instruction, *side_effects_enabled_var)
self.get_cached(dfg, &instruction, *side_effects_enabled_var, block)
{
Self::replace_result_ids(dfg, &old_results, cached_results);
return;
Expand All @@ -156,9 +175,8 @@ impl Context {
instruction.clone(),
new_results,
dfg,
instruction_result_cache,
constraint_simplification_mapping,
*side_effects_enabled_var,
block,
);

// If we just inserted an `Instruction::EnableSideEffectsIf`, we need to update `side_effects_enabled_var`
Expand Down Expand Up @@ -229,13 +247,12 @@ impl Context {
}

fn cache_instruction(
&self,
&mut self,
instruction: Instruction,
instruction_results: Vec<ValueId>,
dfg: &DataFlowGraph,
instruction_result_cache: &mut InstructionResultCache,
constraint_simplification_mapping: &mut HashMap<ValueId, ValueId>,
side_effects_enabled_var: ValueId,
block: BasicBlockId,
) {
if self.use_constraint_info {
// If the instruction was a constraint, then create a link between the two `ValueId`s
Expand All @@ -248,18 +265,18 @@ impl Context {

// Prefer replacing with constants where possible.
(Value::NumericConstant { .. }, _) => {
constraint_simplification_mapping.insert(rhs, lhs);
self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs);
}
(_, Value::NumericConstant { .. }) => {
constraint_simplification_mapping.insert(lhs, rhs);
self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs);
}
// Otherwise prefer block parameters over instruction results.
// This is as block parameters are more likely to be a single witness rather than a full expression.
(Value::Param { .. }, Value::Instruction { .. }) => {
constraint_simplification_mapping.insert(rhs, lhs);
self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs);
}
(Value::Instruction { .. }, Value::Param { .. }) => {
constraint_simplification_mapping.insert(lhs, rhs);
self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs);
}
(_, _) => (),
}
Expand All @@ -273,13 +290,22 @@ impl Context {
self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg);
let predicate = use_predicate.then_some(side_effects_enabled_var);

instruction_result_cache
self.cached_instruction_results
.entry(instruction)
.or_default()
.insert(predicate, instruction_results);
.entry(predicate)
.or_default()
.cache(block, instruction_results);
}
}

fn get_constraint_map(
&mut self,
side_effects_enabled_var: ValueId,
) -> &mut HashMap<ValueId, ValueId> {
self.constraint_simplification_mappings.entry(side_effects_enabled_var).or_default()
}

/// Replaces a set of [`ValueId`]s inside the [`DataFlowGraph`] with another.
fn replace_result_ids(
dfg: &mut DataFlowGraph,
Expand All @@ -292,22 +318,40 @@ impl Context {
}

fn get_cached<'a>(
&'a mut self,
dfg: &DataFlowGraph,
instruction_result_cache: &'a mut InstructionResultCache,
instruction: &Instruction,
side_effects_enabled_var: ValueId,
) -> Option<&'a Vec<ValueId>> {
let results_for_instruction = instruction_result_cache.get(instruction);
block: BasicBlockId,
) -> Option<&'a [ValueId]> {
let results_for_instruction = self.cached_instruction_results.get(instruction)?;

// See if there's a cached version with no predicate first
if let Some(results) = results_for_instruction.and_then(|map| map.get(&None)) {
return Some(results);
}
let predicate = self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg);
let predicate = predicate.then_some(side_effects_enabled_var);

results_for_instruction.get(&predicate)?.get(block, &mut self.dom)
}
}

let predicate =
instruction.requires_acir_gen_predicate(dfg).then_some(side_effects_enabled_var);
impl ResultCache {
/// Records that an `Instruction` in block `block` produced the result values `results`.
fn cache(&mut self, block: BasicBlockId, results: Vec<ValueId>) {
self.results.push((block, results));
}

results_for_instruction.and_then(|map| map.get(&predicate))
/// Returns a set of [`ValueId`]s produced from a copy of this [`Instruction`] which sits
/// within a block which dominates `block`.
///
/// We require that the cached instruction's block dominates `block` in order to avoid
/// cycles causing issues (e.g. two instructions being replaced with the results of each other
/// such that neither instruction exists anymore.)
fn get(&self, block: BasicBlockId, dom: &mut DominatorTree) -> Option<&[ValueId]> {
for (origin_block, results) in &self.results {
if dom.dominates(*origin_block, block) {
return Some(results);
}
}
None
}
}

Expand Down Expand Up @@ -896,4 +940,48 @@ mod test {
let ending_instruction_count = instructions.len();
assert_eq!(ending_instruction_count, 1);
}

#[test]
fn deduplicate_across_blocks() {
// fn main f0 {
// b0(v0: u1):
// v1 = not v0
// jmp b1()
// b1():
// v2 = not v0
// return v2
// }
let main_id = Id::test_new(0);

// Compiling main
let mut builder = FunctionBuilder::new("main".into(), main_id);
let b1 = builder.insert_block();

let v0 = builder.add_parameter(Type::bool());
let _v1 = builder.insert_not(v0);
builder.terminate_with_jmp(b1, Vec::new());

builder.switch_to_block(b1);
let v2 = builder.insert_not(v0);
builder.terminate_with_return(vec![v2]);

let ssa = builder.finish();
let main = ssa.main();
assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1);
assert_eq!(main.dfg[b1].instructions().len(), 1);

// Expected output:
//
// fn main f0 {
// b0(v0: u1):
// v1 = not v0
// jmp b1()
// b1():
// return v1
// }
let ssa = ssa.fold_constants_using_constraints();
let main = ssa.main();
assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1);
assert_eq!(main.dfg[b1].instructions().len(), 0);
}
}

0 comments on commit 6a441fd

Please sign in to comment.