Extract CSE logic to datafusion_common #13002

Merged
merged 3 commits into from
Oct 21, 2024
116 changes: 64 additions & 52 deletions datafusion/common/src/cse.rs
Expand Up @@ -15,7 +15,10 @@
// specific language governing permissions and limitations
// under the License.

Contributor

Would it be possible to add some documentation explaining what is in the module and what the main entry points are?

Something like

Suggested change
//! Common Subexpression Elimination logic: [`SubTreeNodeEliminator`]

Contributor Author

Sure, good idea. Added documentation in 1cae61c.

use crate::hash_node::HashNode;
//! Common Subexpression Elimination logic implemented in [`CSE`] can be controlled with
//! a [`CSEController`], that defines how to eliminate common subtrees from a particular
//! [`TreeNode`] tree.

use crate::hash_utils::combine_hashes;
use crate::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
Expand All @@ -26,6 +29,26 @@ use indexmap::IndexMap;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash, Hasher, RandomState};
use std::marker::PhantomData;
use std::sync::Arc;

/// Hashes the direct content of a [`TreeNode`] without recursing into its children.
///
/// This method is useful for computing hashes incrementally. For example, [`CSE`] builds
/// a deep hash of a node and its descendants during the bottom-up phase of the first
/// traversal, which avoids hashing the node and then its descendants separately.
///
/// If a node doesn't have any children then `hash_node()` is similar to `.hash()`, but
/// it does not necessarily produce the same hash value.
pub trait HashNode {
fn hash_node<H: Hasher>(&self, state: &mut H);
}

impl<T: HashNode + ?Sized> HashNode for Arc<T> {
fn hash_node<H: Hasher>(&self, state: &mut H) {
(**self).hash_node(state);
}
}
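
The shallow-versus-deep hashing idea described in the `HashNode` docs can be sketched outside DataFusion as follows. This is a minimal illustration, not the code from this PR: `ToyNode`, `combine`, `deep_hash`, `leaf` and `branch` are hypothetical names, the trait is a local mirror of the one in the diff, and `combine` only stands in for whatever hash-combining helper the real module uses.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Local mirror of the `HashNode` trait shown in the diff.
pub trait HashNode {
    fn hash_node<H: Hasher>(&self, state: &mut H);
}

/// Hypothetical toy tree node, used only for illustration.
pub struct ToyNode {
    pub data: String,
    pub children: Vec<ToyNode>,
}

impl HashNode for ToyNode {
    // Hash only this node's own content; the children are deliberately skipped.
    fn hash_node<H: Hasher>(&self, state: &mut H) {
        self.data.hash(state);
    }
}

/// Hypothetical stand-in for a hash-combining helper (an assumption, not the
/// real one); any deterministic, order-sensitive mix is enough for this sketch.
fn combine(l: u64, r: u64) -> u64 {
    l.wrapping_mul(31).wrapping_add(r)
}

/// Bottom-up deep hash: every node is hashed shallowly exactly once, and the
/// deep hashes of its children are folded into the result.
pub fn deep_hash(node: &ToyNode) -> u64 {
    let mut hasher = DefaultHasher::new();
    node.hash_node(&mut hasher);
    node.children
        .iter()
        .fold(hasher.finish(), |acc, child| combine(acc, deep_hash(child)))
}

/// Convenience constructors for the sketch.
pub fn leaf(data: &str) -> ToyNode {
    ToyNode { data: data.to_string(), children: vec![] }
}

pub fn branch(data: &str, children: Vec<ToyNode>) -> ToyNode {
    ToyNode { data: data.to_string(), children }
}
```

Two structurally equal trees such as `branch("+", vec![leaf("a"), leaf("1")])` built twice yield the same `deep_hash`, which is the property a common-subtree detector relies on.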

/// Identifier that represents a [`TreeNode`] tree.
///
Expand Down Expand Up @@ -72,8 +95,8 @@ impl<'n, N: HashNode> Identifier<'n, N> {
/// A cache that contains the postorder index and the identifier of [`TreeNode`]s by the
/// preorder index of the nodes.
///
/// This cache is filled by [`CommonSubTreeNodeVisitor`] during the first traversal and is
/// used by [`CommonSubTreeNodeRewriter`] during the second traversal.
/// This cache is filled by [`CSEVisitor`] during the first traversal and is
/// used by [`CSERewriter`] during the second traversal.
///
/// The purpose of this cache is to quickly find the identifier of a node during the
/// second traversal.
Expand Down Expand Up @@ -108,7 +131,8 @@ type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;

type ChildrenList<N> = (Vec<N>, Vec<N>);
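
The cache described above (post-order index looked up by pre-order index) can be illustrated with a small standalone sketch. It assumes the cache is a plain `Vec` indexed by pre-order position; the real `IdArray` element type is not visible in this diff, so `ToyNode`, `visit` and `to_id_array` are hypothetical, and a node label stands in for the identifier.

```rust
/// Hypothetical toy tree node, used only for illustration.
pub struct ToyNode {
    pub data: &'static str,
    pub children: Vec<ToyNode>,
}

/// Fills `id_array` so that `id_array[preorder_index] = (postorder_index, label)`.
fn visit(
    node: &ToyNode,
    id_array: &mut Vec<(usize, &'static str)>,
    next_post: &mut usize,
) {
    // On the way down, reserve this node's slot: the current vector length is
    // the node's pre-order index.
    let pre = id_array.len();
    id_array.push((0, node.data));

    for child in &node.children {
        visit(child, id_array, next_post);
    }

    // On the way up, the post-order index is known and can be filled in.
    id_array[pre].0 = *next_post;
    *next_post += 1;
}

/// Single traversal that builds the pre-order indexed cache for a tree.
pub fn to_id_array(root: &ToyNode) -> Vec<(usize, &'static str)> {
    let mut id_array = Vec::new();
    let mut next_post = 0;
    visit(root, &mut id_array, &mut next_post);
    id_array
}
```

A second, top-down traversal can then walk the tree in pre-order and read each node's entry by position, without recomputing anything.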

pub trait SubTreeNodeEliminatorController {
/// The [`TreeNode`] specific definition of elimination.
pub trait CSEController {
Contributor

👍 for the new name

/// The type of the tree nodes.
type Node;

Expand All @@ -135,11 +159,11 @@ pub trait SubTreeNodeEliminatorController {

// A helper method called on each node during top-down traversal during the second,
// rewriting traversal of CSE.
fn rewrite_f_down(&mut self, node: &Self::Node);
fn rewrite_f_down(&mut self, _node: &Self::Node) {}

// A helper method called on each node during bottom-up traversal during the second,
// rewriting traversal of CSE.
fn rewrite_f_up(&mut self, node: &Self::Node);
fn rewrite_f_up(&mut self, _node: &Self::Node) {}
}
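
Giving the two hooks empty default bodies means a controller only has to implement them when it needs them. The sketch below shows that pattern on a cut-down trait; `ToyController`, `PlainController` and `DepthTrackingController` are hypothetical names, and only the two hook signatures are taken from the diff.

```rust
/// Hypothetical, cut-down controller trait: only the two hooks from the diff
/// are mirrored, both with empty default bodies.
pub trait ToyController {
    type Node;

    // Called on each node on the way down; does nothing unless overridden.
    fn rewrite_f_down(&mut self, _node: &Self::Node) {}

    // Called on each node on the way up; does nothing unless overridden.
    fn rewrite_f_up(&mut self, _node: &Self::Node) {}
}

/// Relies on the defaults, so no hook boilerplate is needed.
pub struct PlainController;

impl ToyController for PlainController {
    type Node = String;
}

/// Overrides both hooks to track the current and maximum nesting depth.
#[derive(Default)]
pub struct DepthTrackingController {
    pub depth: usize,
    pub max_depth: usize,
}

impl ToyController for DepthTrackingController {
    type Node = String;

    fn rewrite_f_down(&mut self, _node: &Self::Node) {
        self.depth += 1;
        self.max_depth = self.max_depth.max(self.depth);
    }

    fn rewrite_f_up(&mut self, _node: &Self::Node) {
        self.depth -= 1;
    }
}
```

This is why the test controller further down in the diff can drop its two empty method implementations.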

/// The result of potentially rewriting a list of [`TreeNode`]s to eliminate common
Expand Down Expand Up @@ -181,7 +205,7 @@ pub enum FoundCommonNodes<N> {
///
/// A [`TreeNode`] without any children (column, literal etc.) will not have an identifier
/// because it should not be recognized as a common subtree.
struct CommonSubTreeNodeVisitor<'a, 'n, N, C: SubTreeNodeEliminatorController<Node = N>> {
struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
/// statistics of [`TreeNode`]s
node_stats: &'a mut NodeStats<'n, N>,

Expand Down Expand Up @@ -226,9 +250,7 @@ enum VisitRecord<'n, N> {
NodeItem(Identifier<'n, N>, bool),
}

impl<'n, N: TreeNode + HashNode, C: SubTreeNodeEliminatorController<Node = N>>
CommonSubTreeNodeVisitor<'_, 'n, N, C>
{
impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n, N, C> {
/// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before
/// it. Returns a tuple that contains:
/// - The pre-order index of the [`TreeNode`] we marked.
Expand Down Expand Up @@ -260,8 +282,8 @@ impl<'n, N: TreeNode + HashNode, C: SubTreeNodeEliminatorController<Node = N>>
}
}

impl<'n, N: TreeNode + HashNode + Eq, C: SubTreeNodeEliminatorController<Node = N>>
TreeNodeVisitor<'n> for CommonSubTreeNodeVisitor<'_, 'n, N, C>
impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisitor<'n>
for CSEVisitor<'_, 'n, N, C>
{
type Node = N;

Expand Down Expand Up @@ -331,8 +353,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: SubTreeNodeEliminatorController<Node =
/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
/// corresponding temporary [`TreeNode`], which holds the evaluated result of the
/// replaced [`TreeNode`] tree.
struct CommonSubTreeNodeRewriter<'a, 'n, N, C: SubTreeNodeEliminatorController<Node = N>>
{
struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
/// statistics of [`TreeNode`]s
node_stats: &'a NodeStats<'n, N>,

Expand All @@ -349,8 +370,8 @@ struct CommonSubTreeNodeRewriter<'a, 'n, N, C: SubTreeNodeEliminatorController<N
controller: &'a mut C,
}

impl<N: TreeNode + Eq, C: SubTreeNodeEliminatorController<Node = N>> TreeNodeRewriter
for CommonSubTreeNodeRewriter<'_, '_, N, C>
impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
for CSERewriter<'_, '_, N, C>
{
type Node = N;

Expand Down Expand Up @@ -393,17 +414,18 @@ impl<N: TreeNode + Eq, C: SubTreeNodeEliminatorController<Node = N>> TreeNodeRew
}
}

pub struct SubTreeNodeEliminator<N, C: SubTreeNodeEliminatorController<Node = N>> {
/// The main entry point of Common Subexpression Elimination.
///
/// [`CSE`] requires a [`CSEController`], that defines how common subtrees of a particular
/// [`TreeNode`] tree can be eliminated. The elimination process can be started with the
/// [`CSE::extract_common_nodes()`] method.
pub struct CSE<N, C: CSEController<Node = N>> {
random_state: RandomState,
phantom_data: PhantomData<N>,
controller: C,
}
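
The two-traversal idea behind `CSE` (collect statistics first, rewrite second) can be sketched on a toy tree. This is a simplified illustration under stated assumptions, not the algorithm in this file: it keys statistics on whole cloned subtrees instead of identifiers, ignores conditional children, and all names (`ToyNode`, `count`, `rewrite`, `eliminate`, the `__common_N` alias format) are hypothetical.

```rust
use std::collections::HashMap;

/// Hypothetical toy tree node, used only for illustration.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ToyNode {
    pub data: String,
    pub children: Vec<ToyNode>,
}

pub fn leaf(data: &str) -> ToyNode {
    ToyNode { data: data.to_string(), children: vec![] }
}

pub fn branch(data: &str, children: Vec<ToyNode>) -> ToyNode {
    ToyNode { data: data.to_string(), children }
}

/// First traversal: count every subtree that has children. Leaves are skipped
/// because they are not worth extracting.
fn count(node: &ToyNode, stats: &mut HashMap<ToyNode, usize>) {
    for child in &node.children {
        count(child, stats);
    }
    if !node.children.is_empty() {
        *stats.entry(node.clone()).or_insert(0) += 1;
    }
}

/// Second traversal: top-down, replace each subtree seen more than once with
/// an alias leaf and remember the extracted subtree.
fn rewrite(
    node: ToyNode,
    stats: &HashMap<ToyNode, usize>,
    common: &mut Vec<(ToyNode, String)>,
) -> ToyNode {
    if stats.get(&node).copied().unwrap_or(0) > 1 {
        // Reuse the alias if this subtree was already extracted.
        let existing = common
            .iter()
            .find(|(n, _)| *n == node)
            .map(|(_, alias)| alias.clone());
        let alias = match existing {
            Some(alias) => alias,
            None => {
                let alias = format!("__common_{}", common.len() + 1);
                common.push((node, alias.clone()));
                alias
            }
        };
        return ToyNode { data: alias, children: vec![] };
    }
    let children = node
        .children
        .into_iter()
        .map(|child| rewrite(child, stats, common))
        .collect();
    ToyNode { data: node.data, children }
}

/// Runs both traversals over a list of trees and returns the rewritten trees
/// together with the extracted common subtrees and their aliases.
pub fn eliminate(nodes: Vec<ToyNode>) -> (Vec<ToyNode>, Vec<(ToyNode, String)>) {
    let mut stats = HashMap::new();
    for node in &nodes {
        count(node, &mut stats);
    }
    let mut common = Vec::new();
    let new_nodes: Vec<ToyNode> = nodes
        .into_iter()
        .map(|node| rewrite(node, &stats, &mut common))
        .collect();
    (new_nodes, common)
}
```

For the inputs `(a + 1) * 2` and `(a + 1) - b`, the shared `a + 1` subtree is extracted once and both trees end up referring to the same alias leaf.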

impl<
N: TreeNode + HashNode + Clone + Eq,
C: SubTreeNodeEliminatorController<Node = N>,
> SubTreeNodeEliminator<N, C>
{
impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C> {
pub fn new(controller: C) -> Self {
Self {
random_state: RandomState::new(),
Expand All @@ -419,7 +441,7 @@ impl<
node_stats: &mut NodeStats<'n, N>,
id_array: &mut IdArray<'n, N>,
) -> Result<bool> {
let mut visitor = CommonSubTreeNodeVisitor {
let mut visitor = CSEVisitor {
node_stats,
id_array,
visit_stack: vec![],
Expand All @@ -440,7 +462,7 @@ impl<
///
/// Returns an array with 1 element for each input node in `nodes`
///
/// Each element is itself the result of [`SubTreeNodeEliminator::node_to_id_array`] for that node
/// Each element is itself the result of [`CSE::node_to_id_array`] for that node
/// (e.g. the identifiers for each node in the tree)
fn to_arrays<'n>(
&self,
Expand Down Expand Up @@ -475,7 +497,7 @@ impl<
if id_array.is_empty() {
Ok(Transformed::no(node))
} else {
node.rewrite(&mut CommonSubTreeNodeRewriter {
node.rewrite(&mut CSERewriter {
node_stats,
id_array,
common_nodes,
Expand Down Expand Up @@ -516,7 +538,7 @@ impl<
pub fn extract_common_nodes(
&mut self,
nodes_list: Vec<Vec<N>>,
) -> Result<Transformed<FoundCommonNodes<N>>> {
) -> Result<FoundCommonNodes<N>> {
let mut found_common = false;
let mut node_stats = NodeStats::new();
let id_arrays_list = nodes_list
Expand All @@ -542,27 +564,23 @@ impl<
)?;
assert!(!common_nodes.is_empty());

Ok(Transformed::yes(FoundCommonNodes::Yes {
Ok(FoundCommonNodes::Yes {
common_nodes: common_nodes.into_values().collect(),
new_nodes_list,
original_nodes_list: nodes_list,
}))
})
} else {
Ok(Transformed::no(FoundCommonNodes::No {
Ok(FoundCommonNodes::No {
original_nodes_list: nodes_list,
}))
})
}
}
}
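
With the `Transformed` wrapper removed from the return type, a caller can match on the enum directly. The sketch below mirrors the enum locally; the field types are inferred from the diff (`common_nodes` is collected from `(N, String)` map values) and should be treated as assumptions, and `pick_nodes` is a hypothetical caller.

```rust
/// Local mirror of the enum's shape as it appears in the diff; the field
/// types are assumptions inferred from how the values are built.
pub enum FoundCommonNodes<N> {
    No {
        original_nodes_list: Vec<Vec<N>>,
    },
    Yes {
        common_nodes: Vec<(N, String)>,
        new_nodes_list: Vec<Vec<N>>,
        original_nodes_list: Vec<Vec<N>>,
    },
}

/// Hypothetical caller: returns the node lists to continue with, plus the
/// extracted common nodes and their aliases (empty if nothing was found).
pub fn pick_nodes<N>(found: FoundCommonNodes<N>) -> (Vec<Vec<N>>, Vec<(N, String)>) {
    match found {
        // Common subtrees were found: continue with the rewritten lists.
        FoundCommonNodes::Yes { common_nodes, new_nodes_list, .. } => {
            (new_nodes_list, common_nodes)
        }
        // Nothing to eliminate: the original lists are handed back unchanged.
        FoundCommonNodes::No { original_nodes_list } => (original_nodes_list, vec![]),
    }
}
```

The `Yes`/`No` variants already say whether anything changed, so no separate "transformed" flag is needed in this sketch.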

#[cfg(test)]
mod test {
use crate::alias::AliasGenerator;
use crate::cse::{
IdArray, Identifier, NodeStats, SubTreeNodeEliminator,
SubTreeNodeEliminatorController,
};
use crate::hash_node::HashNode;
use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE};
use crate::tree_node::tests::TestTreeNode;
use crate::Result;
use std::collections::HashSet;
Expand All @@ -576,12 +594,12 @@ mod test {
NormalAndAggregates,
}

pub struct SubTestTreeNodeEliminatorController<'a> {
pub struct TestTreeNodeCSEController<'a> {
alias_generator: &'a AliasGenerator,
mask: TestTreeNodeMask,
}

impl<'a> SubTestTreeNodeEliminatorController<'a> {
impl<'a> TestTreeNodeCSEController<'a> {
fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self {
Self {
alias_generator,
Expand All @@ -590,7 +608,7 @@ mod test {
}
}

impl SubTreeNodeEliminatorController for SubTestTreeNodeEliminatorController<'_> {
impl CSEController for TestTreeNodeCSEController<'_> {
type Node = TestTreeNode<String>;

fn conditional_children(
Expand Down Expand Up @@ -620,10 +638,6 @@ mod test {
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
}

fn rewrite_f_down(&mut self, _node: &Self::Node) {}

fn rewrite_f_up(&mut self, _node: &Self::Node) {}
}

impl HashNode for TestTreeNode<String> {
Expand All @@ -635,11 +649,10 @@ mod test {
#[test]
fn id_array_visitor() -> Result<()> {
let alias_generator = AliasGenerator::new();
let eliminator =
SubTreeNodeEliminator::new(SubTestTreeNodeEliminatorController::new(
&alias_generator,
TestTreeNodeMask::Normal,
));
let eliminator = CSE::new(TestTreeNodeCSEController::new(
&alias_generator,
TestTreeNodeMask::Normal,
));

let a_plus_1 = TestTreeNode::new(
vec![
Expand Down Expand Up @@ -728,11 +741,10 @@ mod test {
assert_eq!(expected, id_array);

// include aggregates
let eliminator =
SubTreeNodeEliminator::new(SubTestTreeNodeEliminatorController::new(
&alias_generator,
TestTreeNodeMask::NormalAndAggregates,
));
let eliminator = CSE::new(TestTreeNodeCSEController::new(
&alias_generator,
TestTreeNodeMask::NormalAndAggregates,
));

let mut id_array = vec![];
eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?;
Expand Down
29 changes: 0 additions & 29 deletions datafusion/common/src/hash_node.rs

This file was deleted.

1 change: 0 additions & 1 deletion datafusion/common/src/lib.rs
Expand Up @@ -36,7 +36,6 @@ pub mod display;
pub mod error;
pub mod file_options;
pub mod format;
pub mod hash_node;
pub mod hash_utils;
pub mod instant;
pub mod parsers;
Expand Down
12 changes: 1 addition & 11 deletions datafusion/expr/src/expr.rs
Expand Up @@ -34,7 +34,7 @@ use crate::{
};

use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::hash_node::HashNode;
use datafusion_common::cse::HashNode;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
Expand Down Expand Up @@ -1656,16 +1656,6 @@ impl Expr {
}

impl HashNode for Expr {
Contributor

👍

/// Hashes the direct content of an `Expr` without recursing into its children.
///
/// This method is useful to incrementally compute hashes, such as in
/// `CommonSubexprEliminate` which builds a deep hash of a node and its descendants
/// during the bottom-up phase of the first traversal and so avoid computing the hash
/// of the node and then the hash of its descendants separately.
///
/// If a node doesn't have any children then this method is similar to `.hash()`, but
/// not necessarily returns the same value.
///
/// As it is pretty easy to forget to update this method when `Expr` changes, the
/// implementation doesn't use wildcard patterns (`..`, `_`), so that such changes are
/// caught at compile time.
Expand Down