Skip to content

Commit

Permalink
chore: Cleanup mem2reg pass (#2531)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfecher authored Sep 5, 2023
1 parent 255febd commit 8af53bf
Show file tree
Hide file tree
Showing 3 changed files with 350 additions and 233 deletions.
264 changes: 31 additions & 233 deletions crates/noirc_evaluator/src/ssa/opt/mem2reg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
//! SSA optimization pipeline, although it will be more successful the simpler the program's CFG is.
//! This pass is currently performed several times to enable other passes - most notably being
//! performed before loop unrolling to try to allow for mutable variables used for loop indices.
mod alias_set;
mod block;

use std::collections::{BTreeMap, BTreeSet};

use crate::ssa::{
Expand All @@ -78,6 +81,9 @@ use crate::ssa::{
ssa_gen::Ssa,
};

use self::alias_set::AliasSet;
use self::block::{Block, Expression};

impl Ssa {
/// Attempts to remove any load instructions that recover values that are already available in
/// scope, and attempts to remove stores that are subsequently redundant.
Expand Down Expand Up @@ -107,45 +113,6 @@ struct PerFunctionContext<'f> {
instructions_to_remove: BTreeSet<InstructionId>,
}

/// The alias and reference-value state that this pass tracks while walking
/// the instructions of a block.
///
/// The four maps are keyed so that a `ValueId` is first canonicalized to an
/// `Expression` (`expressions`), the expression is resolved to the set of
/// values it may alias (`aliases`), and each alias is then looked up for its
/// last known stored value (`references`).
#[derive(Debug, Default, Clone)]
struct Block {
/// Maps a ValueId to the Expression it represents.
/// Multiple ValueIds can map to the same Expression, e.g.
/// dereferences to the same allocation.
expressions: BTreeMap<ValueId, Expression>,

/// Each expression is tracked as to how many aliases it
/// may have. If there is only 1, we can attempt to optimize
/// out any known loads to that alias. Note that "alias" here
/// includes the original reference as well.
///
/// An empty set is treated as "could alias anything": storing through
/// such an expression invalidates every tracked reference.
aliases: BTreeMap<Expression, BTreeSet<ValueId>>,

/// Each allocate instruction result (and some reference block parameters)
/// will map to a Reference value which tracks whether the last value stored
/// to the reference is known.
references: BTreeMap<ValueId, ReferenceValue>,

/// The last instance of a `Store` instruction to each address in this block.
/// An entry is removed once the store is known to be needed (its address or
/// one of its aliases is used), and the whole map is cleared whenever all
/// references are invalidated.
last_stores: BTreeMap<ValueId, InstructionId>,
}

/// An `Expression` here is used to represent a canonical key
/// into the aliases map since otherwise two dereferences of the
/// same address will be given different ValueIds.
///
/// `PartialOrd`/`Ord` are derived because the type is used as a `BTreeMap` key.
#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)]
enum Expression {
/// The result of dereferencing (loading from) the inner expression.
Dereference(Box<Expression>),
/// An element of the array denoted by the inner expression. No index is
/// recorded, so every element of the same array shares this one key.
ArrayElement(Box<Expression>),
/// Any other value, identified directly by its `ValueId`
/// (e.g. the result of an allocation).
Other(ValueId),
}

/// Every reference's value is either Known and can be optimized away, or Unknown.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ReferenceValue {
/// Nothing can be assumed about what the reference currently holds.
Unknown,
/// The reference is known to currently hold this value
/// (the value most recently recorded for it).
Known(ValueId),
}

impl<'f> PerFunctionContext<'f> {
fn new(function: &'f mut Function) -> Self {
let cfg = ControlFlowGraph::with_function(function);
Expand Down Expand Up @@ -234,9 +201,10 @@ impl<'f> PerFunctionContext<'f> {
if let Some(expression) = references.expressions.get(allocation) {
if let Some(aliases) = references.aliases.get(expression) {
let allocation_aliases_parameter =
aliases.iter().any(|alias| reference_parameters.contains(alias));
aliases.any(|alias| reference_parameters.contains(&alias));

if !aliases.is_empty() && !allocation_aliases_parameter {
// If `allocation_aliases_parameter` is known to be false
if allocation_aliases_parameter == Some(false) {
self.instructions_to_remove.insert(*instruction);
}
}
Expand Down Expand Up @@ -290,17 +258,11 @@ impl<'f> PerFunctionContext<'f> {
references.set_known_value(address, value);
references.last_stores.insert(address, instruction);
}
Instruction::Call { arguments, .. } => {
self.mark_all_unknown(arguments, references);
}
Instruction::Allocate => {
// Register the new reference
let result = self.inserter.function.dfg.instruction_results(instruction)[0];
references.expressions.insert(result, Expression::Other(result));

let mut aliases = BTreeSet::new();
aliases.insert(result);
references.aliases.insert(Expression::Other(result), aliases);
references.aliases.insert(Expression::Other(result), AliasSet::known(result));
}
Instruction::ArrayGet { array, .. } => {
let result = self.inserter.function.dfg.instruction_results(instruction)[0];
Expand All @@ -317,28 +279,34 @@ impl<'f> PerFunctionContext<'f> {
}
Instruction::ArraySet { array, value, .. } => {
references.mark_value_used(*array, self.inserter.function);
let element_type = self.inserter.function.dfg.type_of_value(*value);

if self.inserter.function.dfg.value_is_reference(*value) {
if Self::contains_references(&element_type) {
let result = self.inserter.function.dfg.instruction_results(instruction)[0];
let array = self.inserter.function.dfg.resolve(*array);

let expression = Expression::ArrayElement(Box::new(Expression::Other(array)));

if let Some(aliases) = references.aliases.get_mut(&expression) {
aliases.insert(result);
let mut aliases = if let Some(aliases) = references.aliases.get_mut(&expression)
{
aliases.clone()
} else if let Some((elements, _)) =
self.inserter.function.dfg.get_array_constant(array)
{
// TODO: This should be a unification of each alias set
// If any are empty, the whole should be as well.
for reference in elements {
self.try_add_alias(references, reference, array);
}
}
let aliases = references.collect_all_aliases(elements);
self.set_aliases(references, array, aliases.clone());
aliases
} else {
AliasSet::unknown()
};

aliases.unify(&references.get_aliases_for_value(*value));

references.expressions.insert(result, expression);
references.expressions.insert(result, expression.clone());
references.aliases.insert(expression, aliases);
}
}
Instruction::Call { arguments, .. } => self.mark_all_unknown(arguments, references),
_ => (),
}
}
Expand Down Expand Up @@ -369,12 +337,11 @@ impl<'f> PerFunctionContext<'f> {
}
}

fn try_add_alias(&self, references: &mut Block, reference: ValueId, alias: ValueId) {
if let Some(expression) = references.expressions.get(&reference) {
if let Some(aliases) = references.aliases.get_mut(expression) {
aliases.insert(alias);
}
}
/// Replaces the alias set recorded for `address` with `new_aliases`.
///
/// If `address` has no canonical expression yet, it is registered as
/// `Expression::Other(address)` first; any alias set previously stored under
/// that expression is overwritten.
fn set_aliases(&self, references: &mut Block, address: ValueId, new_aliases: AliasSet) {
    // Canonical key for this address, created on first use.
    let key = references
        .expressions
        .entry(address)
        .or_insert(Expression::Other(address))
        .clone();

    // Inserting over an existing key simply swaps in the new alias set.
    references.aliases.insert(key, new_aliases);
}

fn mark_all_unknown(&self, values: &[ValueId], references: &mut Block) {
Expand Down Expand Up @@ -432,175 +399,6 @@ impl<'f> PerFunctionContext<'f> {
}
}

impl Block {
    /// Returns the value currently held by the reference `address`, if it is known.
    ///
    /// A value is only reported when `address` resolves to exactly one alias
    /// and that alias has a `ReferenceValue::Known` entry.
    fn get_known_value(&self, address: ValueId) -> Option<ValueId> {
        let expression = self.expressions.get(&address)?;
        let aliases = self.aliases.get(expression)?;

        // We could allow multiple aliases if we check that the reference
        // value in each is equal.
        if aliases.len() != 1 {
            return None;
        }

        let alias = aliases.first().expect("There should be exactly 1 alias");
        match self.references.get(alias) {
            Some(ReferenceValue::Known(value)) => Some(*value),
            _ => None,
        }
    }

    /// Records that `address` now holds `value` (see `set_value` for how
    /// aliasing affects what is actually recorded).
    fn set_known_value(&mut self, address: ValueId, value: ValueId) {
        self.set_value(address, ReferenceValue::Known(value));
    }

    /// Records that the value held by `address` is no longer known.
    fn set_unknown(&mut self, address: ValueId) {
        self.set_value(address, ReferenceValue::Unknown);
    }

    /// Updates the tracked value of `address`, taking its alias set into account.
    ///
    /// `address` is registered as `Expression::Other(address)` with an empty
    /// alias set if it was not tracked before.
    fn set_value(&mut self, address: ValueId, value: ReferenceValue) {
        let expression = self.expressions.entry(address).or_insert(Expression::Other(address));
        let aliases = self.aliases.entry(expression.clone()).or_default();

        match aliases.len() {
            // uh-oh, we don't know at all what this reference refers to, could be anything.
            // Now we have to invalidate every reference we know of
            0 => self.invalidate_all_references(),
            // Exactly one candidate: the new value can be recorded precisely.
            1 => {
                let alias = aliases.first().expect("There should be exactly 1 alias");
                self.references.insert(*alias, value);
            }
            // More than one alias. We're not sure which it refers to so we have to
            // conservatively invalidate all references it may refer to.
            _ => {
                for alias in aliases.iter() {
                    if let Some(reference_value) = self.references.get_mut(alias) {
                        *reference_value = ReferenceValue::Unknown;
                    }
                }
            }
        }
    }

    /// Marks every tracked reference as `Unknown` and forgets all pending last stores.
    fn invalidate_all_references(&mut self) {
        self.references
            .values_mut()
            .for_each(|reference_value| *reference_value = ReferenceValue::Unknown);

        self.last_stores.clear();
    }

    /// Merges the state of `other` into `self` and returns the result.
    ///
    /// Expressions and alias sets are unioned; reference values are kept only
    /// for references present on both sides, and become `Unknown` where the
    /// two sides disagree.
    ///
    /// # Panics
    /// Panics if both sides map the same `ValueId` to different expressions.
    fn unify(mut self, other: &Self) -> Self {
        for (value_id, expression) in &other.expressions {
            match self.expressions.get(value_id) {
                Some(existing) => {
                    assert_eq!(
                        existing, expression,
                        "Expected expressions for {value_id} to be equal"
                    );
                }
                None => {
                    self.expressions.insert(*value_id, expression.clone());
                }
            }
        }

        // Union the alias sets expression by expression.
        for (expression, new_aliases) in &other.aliases {
            self.aliases
                .entry(expression.clone())
                .or_default()
                .extend(new_aliases.iter().copied());
        }

        // Keep only the references present in both maps.
        self.references = other
            .references
            .iter()
            .filter_map(|(value_id, reference)| {
                self.references
                    .get(value_id)
                    .map(|existing| (*value_id, existing.unify(*reference)))
            })
            .collect();

        self
    }

    /// Remember that `result` is the result of dereferencing `address`. This is important to
    /// track aliasing when references are stored within other references.
    ///
    /// Nothing is recorded unless `result` is itself a reference.
    fn remember_dereference(&mut self, function: &Function, address: ValueId, result: ValueId) {
        if !function.dfg.value_is_reference(result) {
            return;
        }

        let expression = match self.get_known_value(address) {
            // The loaded reference is known, so `result` is just another name for it.
            Some(known_address) => Expression::Other(known_address),
            // No known aliases to insert for this expression... can we find an alias
            // even if we don't have a known address? If not we'll have to invalidate all
            // known references if this reference is ever stored to.
            None => Expression::Dereference(Box::new(Expression::Other(address))),
        };

        self.expressions.insert(result, expression);
    }

    /// Iterate through each known alias of the given address and apply the function `f` to each.
    ///
    /// The alias set is cloned up front so that `f` is free to mutate `self`;
    /// the values returned by `f` are discarded.
    fn for_each_alias_of<T>(
        &mut self,
        address: ValueId,
        mut f: impl FnMut(&mut Self, ValueId) -> T,
    ) {
        let aliases = self
            .expressions
            .get(&address)
            .and_then(|expression| self.aliases.get(expression))
            .cloned();

        for alias in aliases.into_iter().flatten() {
            f(self, alias);
        }
    }

    /// Marks the last store to `address`, and to each of its known aliases, as needed.
    fn keep_last_stores_for(&mut self, address: ValueId, function: &Function) {
        let resolved = function.dfg.resolve(address);
        self.keep_last_store(resolved, function);
        self.for_each_alias_of(resolved, |block, alias| block.keep_last_store(alias, function));
    }

    /// Drops `address` from the set of pending last stores, if it has one,
    /// and marks the value written by that store as used.
    fn keep_last_store(&mut self, address: ValueId, function: &Function) {
        let resolved = function.dfg.resolve(address);

        let Some(instruction) = self.last_stores.remove(&resolved) else {
            return;
        };

        // Whenever we decide we want to keep a store instruction, we also need
        // to go through its stored value and mark that used as well.
        match &function.dfg[instruction] {
            Instruction::Store { value, .. } => self.mark_value_used(*value, function),
            other => {
                unreachable!("last_store held an id of a non-store instruction: {other:?}")
            }
        }
    }

    /// Marks `value` as used, keeping the last stores it depends on.
    fn mark_value_used(&mut self, value: ValueId, function: &Function) {
        self.keep_last_stores_for(value, function);

        // We must do a recursive check for arrays since they're the only Values which may contain
        // other ValueIds.
        if let Some((elements, _)) = function.dfg.get_array_constant(value) {
            for element in elements {
                self.mark_value_used(element, function);
            }
        }
    }
}

impl ReferenceValue {
    /// Combines two views of the same reference: the result is the shared
    /// value when both sides agree, and `Unknown` otherwise.
    fn unify(self, other: Self) -> Self {
        match (self, other) {
            (lhs, rhs) if lhs == rhs => lhs,
            _ => ReferenceValue::Unknown,
        }
    }
}

#[cfg(test)]
mod tests {
use std::rc::Rc;
Expand Down
Loading

0 comments on commit 8af53bf

Please sign in to comment.