address review comments, move HashNode to datafusion_common::cse, shorter names for eliminator and controller, change `CSE::extract_common_nodes()` to return `Result<FoundCommonNodes<N>>` (instead of `Result<Transformed<FoundCommonNodes<N>>>`)
peter-toth committed Oct 20, 2024
1 parent 807d186 commit 1cae61c
Showing 5 changed files with 213 additions and 241 deletions.
116 changes: 64 additions & 52 deletions datafusion/common/src/cse.rs
@@ -15,7 +15,10 @@
// specific language governing permissions and limitations
// under the License.

use crate::hash_node::HashNode;
//! Common Subexpression Elimination logic implemented in [`CSE`] can be controlled with
//! a [`CSEController`], which defines how to eliminate common subtrees from a particular
//! [`TreeNode`] tree.
use crate::hash_utils::combine_hashes;
use crate::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
@@ -26,6 +29,26 @@ use indexmap::IndexMap;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash, Hasher, RandomState};
use std::marker::PhantomData;
use std::sync::Arc;

/// Hashes the direct content of a [`TreeNode`] without recursing into its children.
///
/// This method is useful to incrementally compute hashes, such as in [`CSE`], which builds
/// a deep hash of a node and its descendants during the bottom-up phase of the first
/// traversal and so avoids computing the hash of the node and then the hash of its
/// descendants separately.
///
/// If a node doesn't have any children then `hash_node()` is similar to `.hash()`, but
/// does not necessarily return the same value.
pub trait HashNode {
fn hash_node<H: Hasher>(&self, state: &mut H);
}

impl<T: HashNode + ?Sized> HashNode for Arc<T> {
fn hash_node<H: Hasher>(&self, state: &mut H) {
(**self).hash_node(state);
}
}
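
As a brief illustration (not part of this commit, and using a hypothetical node type), an implementor of `HashNode` hashes only the node's own data and leaves the children to the bottom-up traversal in [`CSE`]:

use std::hash::{Hash, Hasher};

// Hypothetical node type, only for illustration.
enum MyNode {
    Literal(i64),
    Add(Box<MyNode>, Box<MyNode>),
}

impl HashNode for MyNode {
    fn hash_node<H: Hasher>(&self, state: &mut H) {
        // Hash the variant and the node's own data, but not the children:
        // `CSE` combines child hashes during its bottom-up traversal.
        std::mem::discriminant(self).hash(state);
        match self {
            MyNode::Literal(v) => v.hash(state),
            MyNode::Add(_, _) => {}
        }
    }
}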

/// Identifier that represents a [`TreeNode`] tree.
///
@@ -72,8 +95,8 @@ impl<'n, N: HashNode> Identifier<'n, N> {
/// A cache that contains the postorder index and the identifier of [`TreeNode`]s by the
/// preorder index of the nodes.
///
/// This cache is filled by [`CommonSubTreeNodeVisitor`] during the first traversal and is
/// used by [`CommonSubTreeNodeRewriter`] during the second traversal.
/// This cache is filled by [`CSEVisitor`] during the first traversal and is
/// used by [`CSERewriter`] during the second traversal.
///
/// The purpose of this cache is to quickly find the identifier of a node during the
/// second traversal.
@@ -108,7 +131,8 @@ type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;

type ChildrenList<N> = (Vec<N>, Vec<N>);

pub trait SubTreeNodeEliminatorController {
/// The [`TreeNode`] specific definition of elimination.
pub trait CSEController {
/// The type of the tree nodes.
type Node;

@@ -135,11 +159,11 @@ pub trait SubTreeNodeEliminatorController {

// A helper method called on each node during the top-down phase of the second,
// rewriting traversal of CSE.
fn rewrite_f_down(&mut self, node: &Self::Node);
fn rewrite_f_down(&mut self, _node: &Self::Node) {}

// A helper method called on each node during the bottom-up phase of the second,
// rewriting traversal of CSE.
fn rewrite_f_up(&mut self, node: &Self::Node);
fn rewrite_f_up(&mut self, _node: &Self::Node) {}
}

/// The result of potentially rewriting a list of [`TreeNode`]s to eliminate common
@@ -181,7 +205,7 @@ pub enum FoundCommonNodes<N> {
///
/// A [`TreeNode`] without any children (column, literal etc.) will not have an identifier
/// because it should not be recognized as a common subtree.
struct CommonSubTreeNodeVisitor<'a, 'n, N, C: SubTreeNodeEliminatorController<Node = N>> {
struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
/// statistics of [`TreeNode`]s
node_stats: &'a mut NodeStats<'n, N>,

@@ -226,9 +250,7 @@ enum VisitRecord<'n, N> {
NodeItem(Identifier<'n, N>, bool),
}

impl<'n, N: TreeNode + HashNode, C: SubTreeNodeEliminatorController<Node = N>>
CommonSubTreeNodeVisitor<'_, 'n, N, C>
{
impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n, N, C> {
/// Finds the first `EnterMark` in the stack and accumulates every `NodeItem` before
/// it. Returns a tuple that contains:
/// - The pre-order index of the [`TreeNode`] we marked.
@@ -260,8 +282,8 @@ impl<'n, N: TreeNode + HashNode, C: SubTreeNodeEliminatorController<Node = N>>
}
}

impl<'n, N: TreeNode + HashNode + Eq, C: SubTreeNodeEliminatorController<Node = N>>
TreeNodeVisitor<'n> for CommonSubTreeNodeVisitor<'_, 'n, N, C>
impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisitor<'n>
for CSEVisitor<'_, 'n, N, C>
{
type Node = N;

@@ -331,8 +353,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: SubTreeNodeEliminatorController<Node =
/// Rewrites a [`TreeNode`] tree by replacing detected common subtrees with the
/// corresponding temporary [`TreeNode`], whose column contains the evaluated result of
/// the replaced [`TreeNode`] tree.
struct CommonSubTreeNodeRewriter<'a, 'n, N, C: SubTreeNodeEliminatorController<Node = N>>
{
struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
/// statistics of [`TreeNode`]s
node_stats: &'a NodeStats<'n, N>,

@@ -349,8 +370,8 @@ struct CommonSubTreeNodeRewriter<'a, 'n, N, C: SubTreeNodeEliminatorController<N
controller: &'a mut C,
}

impl<N: TreeNode + Eq, C: SubTreeNodeEliminatorController<Node = N>> TreeNodeRewriter
for CommonSubTreeNodeRewriter<'_, '_, N, C>
impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
for CSERewriter<'_, '_, N, C>
{
type Node = N;

@@ -393,17 +414,18 @@ impl<N: TreeNode + Eq, C: SubTreeNodeEliminatorController<Node = N>> TreeNodeRew
}
}

pub struct SubTreeNodeEliminator<N, C: SubTreeNodeEliminatorController<Node = N>> {
/// The main entry point of Common Subexpression Elimination.
///
/// [`CSE`] requires a [`CSEController`], which defines how common subtrees of a particular
/// [`TreeNode`] tree can be eliminated. The elimination process can be started with the
/// [`CSE::extract_common_nodes()`] method.
pub struct CSE<N, C: CSEController<Node = N>> {
random_state: RandomState,
phantom_data: PhantomData<N>,
controller: C,
}

impl<
N: TreeNode + HashNode + Clone + Eq,
C: SubTreeNodeEliminatorController<Node = N>,
> SubTreeNodeEliminator<N, C>
{
impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C> {
pub fn new(controller: C) -> Self {
Self {
random_state: RandomState::new(),
@@ -419,7 +441,7 @@ impl<
node_stats: &mut NodeStats<'n, N>,
id_array: &mut IdArray<'n, N>,
) -> Result<bool> {
let mut visitor = CommonSubTreeNodeVisitor {
let mut visitor = CSEVisitor {
node_stats,
id_array,
visit_stack: vec![],
@@ -440,7 +462,7 @@ impl<
///
/// Returns an array with 1 element for each input node in `nodes`
///
/// Each element is itself the result of [`SubTreeNodeEliminator::node_to_id_array`] for that node
/// Each element is itself the result of [`CSE::node_to_id_array`] for that node
/// (e.g. the identifiers for each node in the tree)
fn to_arrays<'n>(
&self,
@@ -475,7 +497,7 @@ impl<
if id_array.is_empty() {
Ok(Transformed::no(node))
} else {
node.rewrite(&mut CommonSubTreeNodeRewriter {
node.rewrite(&mut CSERewriter {
node_stats,
id_array,
common_nodes,
@@ -516,7 +538,7 @@ impl<
pub fn extract_common_nodes(
&mut self,
nodes_list: Vec<Vec<N>>,
) -> Result<Transformed<FoundCommonNodes<N>>> {
) -> Result<FoundCommonNodes<N>> {
let mut found_common = false;
let mut node_stats = NodeStats::new();
let id_arrays_list = nodes_list
@@ -542,27 +564,23 @@
)?;
assert!(!common_nodes.is_empty());

Ok(Transformed::yes(FoundCommonNodes::Yes {
Ok(FoundCommonNodes::Yes {
common_nodes: common_nodes.into_values().collect(),
new_nodes_list,
original_nodes_list: nodes_list,
}))
})
} else {
Ok(Transformed::no(FoundCommonNodes::No {
Ok(FoundCommonNodes::No {
original_nodes_list: nodes_list,
}))
})
}
}
}
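
A minimal caller-side sketch (illustrative only; the wrapper function, its name, and the single-list input are assumptions, not part of this commit) showing how the `Result<FoundCommonNodes<N>>` returned by `extract_common_nodes()` might be consumed:

fn run_cse<N, C>(nodes: Vec<N>, controller: C) -> Result<Vec<Vec<N>>>
where
    N: TreeNode + HashNode + Clone + Eq,
    C: CSEController<Node = N>,
{
    let mut cse = CSE::new(controller);
    // The result is a `FoundCommonNodes`, not a `Transformed`, so callers
    // match on its variants directly.
    match cse.extract_common_nodes(vec![nodes])? {
        FoundCommonNodes::Yes { new_nodes_list, .. } => Ok(new_nodes_list),
        FoundCommonNodes::No { original_nodes_list } => Ok(original_nodes_list),
    }
}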

#[cfg(test)]
mod test {
use crate::alias::AliasGenerator;
use crate::cse::{
IdArray, Identifier, NodeStats, SubTreeNodeEliminator,
SubTreeNodeEliminatorController,
};
use crate::hash_node::HashNode;
use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE};
use crate::tree_node::tests::TestTreeNode;
use crate::Result;
use std::collections::HashSet;
@@ -576,12 +594,12 @@ mod test {
NormalAndAggregates,
}

pub struct SubTestTreeNodeEliminatorController<'a> {
pub struct TestTreeNodeCSEController<'a> {
alias_generator: &'a AliasGenerator,
mask: TestTreeNodeMask,
}

impl<'a> SubTestTreeNodeEliminatorController<'a> {
impl<'a> TestTreeNodeCSEController<'a> {
fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self {
Self {
alias_generator,
@@ -590,7 +608,7 @@ mod test {
}
}

impl SubTreeNodeEliminatorController for SubTestTreeNodeEliminatorController<'_> {
impl CSEController for TestTreeNodeCSEController<'_> {
type Node = TestTreeNode<String>;

fn conditional_children(
@@ -620,10 +638,6 @@ mod test {
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
}

fn rewrite_f_down(&mut self, _node: &Self::Node) {}

fn rewrite_f_up(&mut self, _node: &Self::Node) {}
}

impl HashNode for TestTreeNode<String> {
@@ -635,11 +649,10 @@ mod test {
#[test]
fn id_array_visitor() -> Result<()> {
let alias_generator = AliasGenerator::new();
let eliminator =
SubTreeNodeEliminator::new(SubTestTreeNodeEliminatorController::new(
&alias_generator,
TestTreeNodeMask::Normal,
));
let eliminator = CSE::new(TestTreeNodeCSEController::new(
&alias_generator,
TestTreeNodeMask::Normal,
));

let a_plus_1 = TestTreeNode::new(
vec![
@@ -728,11 +741,10 @@ mod test {
assert_eq!(expected, id_array);

// include aggregates
let eliminator =
SubTreeNodeEliminator::new(SubTestTreeNodeEliminatorController::new(
&alias_generator,
TestTreeNodeMask::NormalAndAggregates,
));
let eliminator = CSE::new(TestTreeNodeCSEController::new(
&alias_generator,
TestTreeNodeMask::NormalAndAggregates,
));

let mut id_array = vec![];
eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?;
29 changes: 0 additions & 29 deletions datafusion/common/src/hash_node.rs

This file was deleted.

1 change: 0 additions & 1 deletion datafusion/common/src/lib.rs
@@ -36,7 +36,6 @@ pub mod display;
pub mod error;
pub mod file_options;
pub mod format;
pub mod hash_node;
pub mod hash_utils;
pub mod instant;
pub mod parsers;
12 changes: 1 addition & 11 deletions datafusion/expr/src/expr.rs
@@ -34,7 +34,7 @@ use crate::{
};

use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::hash_node::HashNode;
use datafusion_common::cse::HashNode;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
@@ -1656,16 +1656,6 @@
}

impl HashNode for Expr {
/// Hashes the direct content of an `Expr` without recursing into its children.
///
/// This method is useful to incrementally compute hashes, such as in
/// `CommonSubexprEliminate` which builds a deep hash of a node and its descendants
/// during the bottom-up phase of the first traversal and so avoid computing the hash
/// of the node and then the hash of its descendants separately.
///
/// If a node doesn't have any children then this method is similar to `.hash()`, but
/// not necessarily returns the same value.
///
/// As it is pretty easy to forget to change this method when `Expr` changes, the
/// implementation doesn't use wildcard patterns (`..`, `_`) so that changes are caught
/// at compile time.
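
A toy illustration (not DataFusion code; the enum and function are hypothetical) of the wildcard-free policy described above: because every variant is matched explicitly, adding a new variant makes this function fail to compile until it is updated.

use std::hash::{Hash, Hasher};

enum Toy {
    Literal(i64),
    Name(String),
}

fn hash_toy<H: Hasher>(toy: &Toy, state: &mut H) {
    std::mem::discriminant(toy).hash(state);
    // No `_` or `..` arms: a newly added `Toy` variant is a compile error here.
    match toy {
        Toy::Literal(v) => v.hash(state),
        Toy::Name(s) => s.hash(state),
    }
}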