Skip to content

Commit

Permalink
Improve CSE stats
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Oct 23, 2024
1 parent 211e76e commit cc68012
Showing 1 changed file with 31 additions and 15 deletions.
46 changes: 31 additions & 15 deletions datafusion/common/src/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,17 @@ impl<'n, N: HashNode> Identifier<'n, N> {
/// ```
type IdArray<'n, N> = Vec<(usize, Option<Identifier<'n, N>>)>;

/// A map that contains the number of normal and conditional occurrences of [`TreeNode`]s
/// by their identifiers.
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (usize, usize)>;
/// How many times a node is evaluated.
///
/// A node is considered common if it is surely evaluated at least twice, or if it is
/// surely evaluated once and also evaluated conditionally at least once.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NodeEvaluation {
    /// The node has been seen exactly once so far, on a surely evaluated path.
    SurelyOnce,
    /// The node has been seen one or more times so far, but only on conditionally
    /// evaluated paths.
    ConditionallyAtLeastOnce,
    /// The node qualifies as a common node: it was surely evaluated once and then seen
    /// again (surely or conditionally), or it was seen conditionally and then on a
    /// surely evaluated path.
    Common,
}

/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers.
///
/// Each value is the [`NodeEvaluation`] recorded for the nodes that share the identifier
/// used as the key.
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, NodeEvaluation>;

/// A map that contains the common [`TreeNode`]s and their alias by their identifiers,
/// extracted during the second, rewriting traversal.
Expand Down Expand Up @@ -331,16 +339,24 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
self.id_array[down_index].0 = self.up_index;
if is_valid && !self.controller.is_ignored(node) {
self.id_array[down_index].1 = Some(node_id);
let (count, conditional_count) =
self.node_stats.entry(node_id).or_insert((0, 0));
if self.conditional {
*conditional_count += 1;
} else {
*count += 1;
}
if *count > 1 || (*count == 1 && *conditional_count > 0) {
self.found_common = true;
}
self.node_stats
.entry(node_id)
.and_modify(|evaluation| {
if *evaluation == NodeEvaluation::SurelyOnce
|| *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
&& !self.conditional
{
*evaluation = NodeEvaluation::Common;
self.found_common = true;
}
})
.or_insert_with(|| {
if self.conditional {
NodeEvaluation::ConditionallyAtLeastOnce
} else {
NodeEvaluation::SurelyOnce
}
});
}
self.visit_stack
.push(VisitRecord::NodeItem(node_id, is_valid));
Expand Down Expand Up @@ -383,8 +399,8 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter

// Handle nodes with identifiers only
if let Some(node_id) = node_id {
let (count, conditional_count) = self.node_stats.get(&node_id).unwrap();
if *count > 1 || *count == 1 && *conditional_count > 0 {
let evaluation = self.node_stats.get(&node_id).unwrap();
if *evaluation == NodeEvaluation::Common {
// Step the index to skip all sub-nodes (which have smaller series numbers).
while self.down_index < self.id_array.len()
&& self.id_array[self.down_index].0 < up_index
Expand Down

0 comments on commit cc68012

Please sign in to comment.