Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssa refactor): Speedup find-branch-ends #1786

Merged
merged 2 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ pub(crate) struct DominatorTree {
/// After dominator tree computation has completed, this will contain a node for every
/// reachable block, and no nodes for unreachable blocks.
nodes: HashMap<BasicBlockId, DominatorTreeNode>,

/// Results of `dominates` queries are cached to speed up subsequent calls
cache: HashMap<(BasicBlockId, BasicBlockId), bool>,
}

/// Methods for querying the dominator tree.
Expand Down Expand Up @@ -83,7 +86,21 @@ impl DominatorTree {
/// This function panics if either of the blocks are unreachable.
///
/// A block is considered to dominate itself.
pub(crate) fn dominates(&self, block_a_id: BasicBlockId, mut block_b_id: BasicBlockId) -> bool {
/// Memoizing front-end for `dominates_helper`.
///
/// The answer for each `(block_a_id, block_b_id)` pair is looked up in `self.cache` first;
/// on a miss it is computed by `dominates_helper`, stored in the cache, and returned.
pub(crate) fn dominates(&mut self, block_a_id: BasicBlockId, block_b_id: BasicBlockId) -> bool {
    let key = (block_a_id, block_b_id);

    match self.cache.get(&key).copied() {
        // This exact query was answered before: reuse the stored answer.
        Some(cached) => cached,
        // First time this pair is queried: compute it, then remember it.
        None => {
            let computed = self.dominates_helper(block_a_id, block_b_id);
            self.cache.insert(key, computed);
            computed
        }
    }
}

pub(crate) fn dominates_helper(
&self,
block_a_id: BasicBlockId,
mut block_b_id: BasicBlockId,
) -> bool {
// Walk up the dominator tree from "b" until we encounter or pass "a". Doing the
// comparison on the reverse post-order allows us to test whether we have passed "a"
// without waiting until we reach the root of the tree.
Expand All @@ -104,7 +121,7 @@ impl DominatorTree {
/// Allocate and compute a dominator tree from a pre-computed control flow graph and
/// post-order counterpart.
pub(crate) fn with_cfg_and_post_order(cfg: &ControlFlowGraph, post_order: &PostOrder) -> Self {
let mut dom_tree = DominatorTree { nodes: HashMap::new() };
let mut dom_tree = DominatorTree { nodes: HashMap::new(), cache: HashMap::new() };
dom_tree.compute_dominator_tree(cfg, post_order);
dom_tree
}
Expand Down Expand Up @@ -249,7 +266,7 @@ mod tests {
block0_id,
TerminatorInstruction::Return { return_values: vec![] },
);
let dom_tree = DominatorTree::with_function(&func);
let mut dom_tree = DominatorTree::with_function(&func);
assert!(dom_tree.dominates(block0_id, block0_id));
}

Expand Down Expand Up @@ -308,7 +325,7 @@ mod tests {
// unreachable, performing this query indicates an internal compiler error.
#[test]
fn unreachable_node_asserts() {
let (dt, b0, _b1, b2, b3) = unreachable_node_setup();
let (mut dt, b0, _b1, b2, b3) = unreachable_node_setup();

assert!(dt.dominates(b0, b0));
assert!(dt.dominates(b0, b2));
Expand All @@ -326,42 +343,42 @@ mod tests {
#[test]
#[should_panic]
fn unreachable_node_panic_b0_b1() {
let (dt, b0, b1, _b2, _b3) = unreachable_node_setup();
let (mut dt, b0, b1, _b2, _b3) = unreachable_node_setup();
dt.dominates(b0, b1);
}

#[test]
#[should_panic]
fn unreachable_node_panic_b1_b0() {
let (dt, b0, b1, _b2, _b3) = unreachable_node_setup();
let (mut dt, b0, b1, _b2, _b3) = unreachable_node_setup();
dt.dominates(b1, b0);
}

#[test]
#[should_panic]
fn unreachable_node_panic_b1_b1() {
let (dt, _b0, b1, _b2, _b3) = unreachable_node_setup();
let (mut dt, _b0, b1, _b2, _b3) = unreachable_node_setup();
dt.dominates(b1, b1);
}

#[test]
#[should_panic]
fn unreachable_node_panic_b1_b2() {
let (dt, _b0, b1, b2, _b3) = unreachable_node_setup();
let (mut dt, _b0, b1, b2, _b3) = unreachable_node_setup();
dt.dominates(b1, b2);
}

#[test]
#[should_panic]
fn unreachable_node_panic_b1_b3() {
let (dt, _b0, b1, _b2, b3) = unreachable_node_setup();
let (mut dt, _b0, b1, _b2, b3) = unreachable_node_setup();
dt.dominates(b1, b3);
}

#[test]
#[should_panic]
fn unreachable_node_panic_b3_b1() {
let (dt, _b0, b1, b2, _b3) = unreachable_node_setup();
let (mut dt, _b0, b1, b2, _b3) = unreachable_node_setup();
dt.dominates(b2, b1);
}

Expand Down Expand Up @@ -390,7 +407,7 @@ mod tests {
let func = ssa.main();
let block0_id = func.entry_block();

let dt = DominatorTree::with_function(func);
let mut dt = DominatorTree::with_function(func);

// Expected dominance tree:
// block0 {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
//! This is an algorithm for identifying branch starts and ends.
use std::collections::{HashMap, HashSet};
//!
//! The algorithm is split into two parts:
//! 1. The outer part:
//! A. An (unrolled) CFG can be thought of as a linear sequence of blocks where some nodes split
//! off, but eventually rejoin to a new node and continue the linear sequence.
//! B. Follow this sequence in order, and whenever a split is found call
//! `find_join_point_of_branches` and then recur from the join point it returns until the
//! return instruction is found.
//!
//! 2. The inner part defined by `find_join_point_of_branches`:
//! A. For each of the two branches in a jmpif block:
//! - Check if either has multiple predecessors. If so, it is a join point.
//! - If not, continue to search the linear sequence of successor blocks from that block.
//! - If another split point is found, recur in `find_join_point_of_branches`
//! - If a block with multiple predecessors is found, return it.
//! - After, we should have identified a join point for both branches. This is expected to be
//! the same block for both and can be returned from here to continue iteration.
//!
//! This algorithm will remember each join point found in `find_join_point_of_branches` and
//! the resulting map from each split block to each join block is returned.
use std::collections::HashMap;

use crate::ssa_refactor::ir::{
basic_block::BasicBlockId, cfg::ControlFlowGraph, dom::DominatorTree, function::Function,
post_order::PostOrder,
basic_block::BasicBlockId, cfg::ControlFlowGraph, function::Function,
};

/// Returns a `HashMap` mapping blocks that start a branch (i.e. blocks terminated with jmpif) to
Expand All @@ -16,121 +35,78 @@ pub(super) fn find_branch_ends(
function: &Function,
cfg: &ControlFlowGraph,
) -> HashMap<BasicBlockId, BasicBlockId> {
let post_order = PostOrder::with_function(function);
let dom_tree = DominatorTree::with_cfg_and_post_order(cfg, &post_order);
let mut stepper = Stepper::new(function.entry_block());
// This outer `visited` set is inconsequential, and simply here to satisfy the recursive
// stepper interface.
let mut visited = HashSet::new();
let mut branch_ends = HashMap::new();
while !stepper.finished {
stepper.step(cfg, &dom_tree, &mut visited, &mut branch_ends);
}
branch_ends
}
let mut block = function.entry_block();
let mut context = Context::new(cfg);

/// Returns the block at which `left` and `right` converge, at the same time identifying branch
/// ends in any sub branches.
///
/// This function is called by `Stepper::step` and is thus recursive.
fn step_until_rejoin(
cfg: &ControlFlowGraph,
dom_tree: &DominatorTree,
branch_ends: &mut HashMap<BasicBlockId, BasicBlockId>,
left: BasicBlockId,
right: BasicBlockId,
) -> BasicBlockId {
let mut visited = HashSet::new();
let mut left_stepper = Stepper::new(left);
let mut right_stepper = Stepper::new(right);
loop {
let mut successors = cfg.successors(block);

while !left_stepper.finished || !right_stepper.finished {
left_stepper.step(cfg, dom_tree, &mut visited, branch_ends);
right_stepper.step(cfg, dom_tree, &mut visited, branch_ends);
if successors.len() == 2 {
block = context.find_join_point_of_branches(block, successors);
} else if successors.len() == 1 {
block = successors.next().unwrap();
} else if successors.len() == 0 {
// return encountered. We have nothing to join, so we're done
break;
} else {
unreachable!("A block can only have 0, 1, or 2 successors");
}
}
let collision = match (left_stepper.collision, right_stepper.collision) {
(Some(collision), None) | (None, Some(collision)) => collision,
(Some(_),Some(_))=> unreachable!("A collision on both branches indicates a loop"),
_ => unreachable!(
"Until we support multiple returns, branches always re-converge. Once supported this case should return `None`"
),
};
collision

context.branch_ends
}

/// Tracks traversal along the arm of a branch. Steppers are progressed in pairs, such that the
/// re-convergence point of two arms is discovered as soon as possible. The exceptional case is
/// that of the top level stepper, which conveniently steps the whole CFG as if it were a single
/// arm.
struct Stepper {
/// The block that will be interrogated when calling `step`
current_block: BasicBlockId,
/// Indicates that the stepper has no more block successors to process, either because it has
/// reached the end of the CFG, or because it encountered a block already visited by its
/// sibling stepper.
finished: bool,
/// Once finished this option indicates whether a collision was encountered before reaching
/// the end of the CFG.
collision: Option<BasicBlockId>,
struct Context<'cfg> {
branch_ends: HashMap<BasicBlockId, BasicBlockId>,
cfg: &'cfg ControlFlowGraph,
}

impl Stepper {
/// Creates a fresh stepper instance
fn new(current_block: BasicBlockId) -> Self {
Stepper { current_block, finished: false, collision: None }
impl<'cfg> Context<'cfg> {
fn new(cfg: &'cfg ControlFlowGraph) -> Self {
Self { cfg, branch_ends: HashMap::new() }
}

/// Checks the current block to see if it has already been visited and if so marks it as a
/// collision. If a sub-branch is encountered `step_until_rejoin` is called to start a pair
/// of child steppers stepping along its arms.
///
/// It is safe to call this even when the stepper has reached its end.
fn step(
fn find_join_point_of_branches(
&mut self,
cfg: &ControlFlowGraph,
dom_tree: &DominatorTree,
visited: &mut HashSet<BasicBlockId>,
branch_ends: &mut HashMap<BasicBlockId, BasicBlockId>,
) {
if self.finished {
// The caller still needs to progress the other stepper, while this one sits idle.
return;
}
if visited.contains(&self.current_block) {
// The other stepper has already visited this block - thus this block is the
// re-convergence point.
self.collision = Some(self.current_block);
self.finished = true;
start: BasicBlockId,
mut successors: impl Iterator<Item = BasicBlockId>,
) -> BasicBlockId {
let left = successors.next().unwrap();
let right = successors.next().unwrap();

let left_join = self.find_join_point(left);
let right_join = self.find_join_point(right);

assert_eq!(left_join, right_join, "Expected two blocks to join to the same block");
self.branch_ends.insert(start, left_join);

left_join
}

fn find_join_point(&mut self, block: BasicBlockId) -> BasicBlockId {
let predecessors = self.cfg.predecessors(block);
if predecessors.len() > 1 {
return block;
}
visited.insert(self.current_block);
// The join point is not this block, so continue on
self.skip_then_find_join_point(block)
}

let mut successors = cfg.successors(self.current_block);
match successors.len() {
0 => {
// Reached the end of the CFG without a collision - this will happen in the other
// stepper assuming the CFG contains no early returns.
self.finished = true;
}
1 => {
// This block doesn't describe any branch starts or ends - move on.
self.current_block = successors.next().unwrap();
}
2 => {
// Sub-branch start encountered - recurse to find the end of the sub branch
let left = successors.next().unwrap();
let right = successors.next().unwrap();
let sub_branch_end = step_until_rejoin(cfg, dom_tree, branch_ends, left, right);
for collision_predecessor in cfg.predecessors(sub_branch_end) {
assert!(dom_tree.dominates(self.current_block, collision_predecessor));
}
branch_ends.insert(self.current_block, sub_branch_end);
fn skip_then_find_join_point(&mut self, block: BasicBlockId) -> BasicBlockId {
let mut successors = self.cfg.successors(block);

// Resume stepping through the current arm from where the sub-branch left off
self.current_block = sub_branch_end;
}
_ => {
unreachable!("Basic blocks never have more than 2 successors")
}
if successors.len() == 2 {
let join = self.find_join_point_of_branches(block, successors);
// Note that we call skip_then_find_join_point here instead of find_join_point.
// We already know this `join` is a join point, but it cannot be for the current block
// since we already know it is the join point of the successors of the current block.
self.skip_then_find_join_point(join)
} else if successors.len() == 1 {
self.find_join_point(successors.next().unwrap())
} else if successors.len() == 0 {
unreachable!("return encountered before a join point was found. This can only happen if early-return was added to the language without implementing it by jmping to a join block first")
} else {
unreachable!("A block can only have 0, 1, or 2 successors");
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct Loops {
fn find_all_loops(function: &Function) -> Loops {
let cfg = ControlFlowGraph::with_function(function);
let post_order = PostOrder::with_function(function);
let dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order);
let mut dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order);

let mut loops = vec![];

Expand Down