Extract CSE logic to datafusion_common
#13002
Merged
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,10 @@ | |
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use crate::hash_node::HashNode; | ||
//! Common Subexpression Elimination logic implemented in [`CSE`] can be controlled with | ||
//! a [`CSEController`], that defines how to eliminate common subtrees from a particular | ||
//! [`TreeNode`] tree. | ||
|
||
use crate::hash_utils::combine_hashes; | ||
use crate::tree_node::{ | ||
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, | ||
|
@@ -26,6 +29,26 @@ use indexmap::IndexMap; | |
use std::collections::HashMap; | ||
use std::hash::{BuildHasher, Hash, Hasher, RandomState}; | ||
use std::marker::PhantomData; | ||
use std::sync::Arc; | ||
|
||
/// Hashes the direct content of a [`TreeNode`] without recursing into its children. | ||
/// | ||
/// This method is useful to incrementally compute hashes, such as in [`CSE`] which builds | ||
/// a deep hash of a node and its descendants during the bottom-up phase of the first | ||
/// traversal and so avoid computing the hash of the node and then the hash of its | ||
/// descendants separately. | ||
/// | ||
/// If a node doesn't have any children then the value returned by `hash_node()` is | ||
/// similar to `.hash()`, but does not necessarily return the same value. | ||
pub trait HashNode { | ||
fn hash_node<H: Hasher>(&self, state: &mut H); | ||
} | ||
|
||
impl<T: HashNode + ?Sized> HashNode for Arc<T> { | ||
fn hash_node<H: Hasher>(&self, state: &mut H) { | ||
(**self).hash_node(state); | ||
} | ||
} | ||
|
||
/// Identifier that represents a [`TreeNode`] tree. | ||
/// | ||
|
@@ -72,8 +95,8 @@ impl<'n, N: HashNode> Identifier<'n, N> { | |
/// A cache that contains the postorder index and the identifier of [`TreeNode`]s by the | ||
/// preorder index of the nodes. | ||
/// | ||
/// This cache is filled by [`CommonSubTreeNodeVisitor`] during the first traversal and is | ||
/// used by [`CommonSubTreeNodeRewriter`] during the second traversal. | ||
/// This cache is filled by [`CSEVisitor`] during the first traversal and is | ||
/// used by [`CSERewriter`] during the second traversal. | ||
/// | ||
/// The purpose of this cache is to quickly find the identifier of a node during the | ||
/// second traversal. | ||
|
@@ -108,7 +131,8 @@ type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>; | |
|
||
type ChildrenList<N> = (Vec<N>, Vec<N>); | ||
|
||
pub trait SubTreeNodeEliminatorController { | ||
/// The [`TreeNode`] specific definition of elimination. | ||
pub trait CSEController { | ||
Review comment: 👍 for the new name |
||
/// The type of the tree nodes. | ||
type Node; | ||
|
||
|
@@ -135,11 +159,11 @@ pub trait SubTreeNodeEliminatorController { | |
|
||
// A helper method called on each node during top-down traversal during the second, | ||
// rewriting traversal of CSE. | ||
fn rewrite_f_down(&mut self, node: &Self::Node); | ||
fn rewrite_f_down(&mut self, _node: &Self::Node) {} | ||
|
||
// A helper method called on each node during bottom-up traversal during the second, | ||
// rewriting traversal of CSE. | ||
fn rewrite_f_up(&mut self, node: &Self::Node); | ||
fn rewrite_f_up(&mut self, _node: &Self::Node) {} | ||
} | ||
|
||
/// The result of potentially rewriting a list of [`TreeNode`]s to eliminate common | ||
|
@@ -181,7 +205,7 @@ pub enum FoundCommonNodes<N> { | |
/// | ||
/// A [`TreeNode`] without any children (column, literal etc.) will not have identifier | ||
/// because they should not be recognized as common subtree. | ||
struct CommonSubTreeNodeVisitor<'a, 'n, N, C: SubTreeNodeEliminatorController<Node = N>> { | ||
struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> { | ||
/// statistics of [`TreeNode`]s | ||
node_stats: &'a mut NodeStats<'n, N>, | ||
|
||
|
@@ -226,9 +250,7 @@ enum VisitRecord<'n, N> { | |
NodeItem(Identifier<'n, N>, bool), | ||
} | ||
|
||
impl<'n, N: TreeNode + HashNode, C: SubTreeNodeEliminatorController<Node = N>> | ||
CommonSubTreeNodeVisitor<'_, 'n, N, C> | ||
{ | ||
impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n, N, C> { | ||
/// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before | ||
/// it. Returns a tuple that contains: | ||
/// - The pre-order index of the [`TreeNode`] we marked. | ||
|
@@ -260,8 +282,8 @@ impl<'n, N: TreeNode + HashNode, C: SubTreeNodeEliminatorController<Node = N>> | |
} | ||
} | ||
|
||
impl<'n, N: TreeNode + HashNode + Eq, C: SubTreeNodeEliminatorController<Node = N>> | ||
TreeNodeVisitor<'n> for CommonSubTreeNodeVisitor<'_, 'n, N, C> | ||
impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisitor<'n> | ||
for CSEVisitor<'_, 'n, N, C> | ||
{ | ||
type Node = N; | ||
|
||
|
@@ -331,8 +353,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: SubTreeNodeEliminatorController<Node = | |
/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the | ||
/// corresponding temporary [`TreeNode`], which contains the evaluated result of the | ||
/// replaced [`TreeNode`] tree. | ||
struct CommonSubTreeNodeRewriter<'a, 'n, N, C: SubTreeNodeEliminatorController<Node = N>> | ||
{ | ||
struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> { | ||
/// statistics of [`TreeNode`]s | ||
node_stats: &'a NodeStats<'n, N>, | ||
|
||
|
@@ -349,8 +370,8 @@ struct CommonSubTreeNodeRewriter<'a, 'n, N, C: SubTreeNodeEliminatorController<N | |
controller: &'a mut C, | ||
} | ||
|
||
impl<N: TreeNode + Eq, C: SubTreeNodeEliminatorController<Node = N>> TreeNodeRewriter | ||
for CommonSubTreeNodeRewriter<'_, '_, N, C> | ||
impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter | ||
for CSERewriter<'_, '_, N, C> | ||
{ | ||
type Node = N; | ||
|
||
|
@@ -393,17 +414,18 @@ impl<N: TreeNode + Eq, C: SubTreeNodeEliminatorController<Node = N>> TreeNodeRew | |
} | ||
} | ||
|
||
pub struct SubTreeNodeEliminator<N, C: SubTreeNodeEliminatorController<Node = N>> { | ||
/// The main entry point of Common Subexpression Elimination. | ||
/// | ||
/// [`CSE`] requires a [`CSEController`], that defines how common subtrees of a particular | ||
/// [`TreeNode`] tree can be eliminated. The elimination process can be started with the | ||
/// [`CSE::extract_common_nodes()`] method. | ||
pub struct CSE<N, C: CSEController<Node = N>> { | ||
random_state: RandomState, | ||
phantom_data: PhantomData<N>, | ||
controller: C, | ||
} | ||
|
||
impl< | ||
N: TreeNode + HashNode + Clone + Eq, | ||
C: SubTreeNodeEliminatorController<Node = N>, | ||
> SubTreeNodeEliminator<N, C> | ||
{ | ||
impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C> { | ||
pub fn new(controller: C) -> Self { | ||
Self { | ||
random_state: RandomState::new(), | ||
|
@@ -419,7 +441,7 @@ impl< | |
node_stats: &mut NodeStats<'n, N>, | ||
id_array: &mut IdArray<'n, N>, | ||
) -> Result<bool> { | ||
let mut visitor = CommonSubTreeNodeVisitor { | ||
let mut visitor = CSEVisitor { | ||
node_stats, | ||
id_array, | ||
visit_stack: vec![], | ||
|
@@ -440,7 +462,7 @@ impl< | |
/// | ||
/// Returns an array with 1 element for each input node in `nodes` | ||
/// | ||
/// Each element is itself the result of [`SubTreeNodeEliminator::node_to_id_array`] for that node | ||
/// Each element is itself the result of [`CSE::node_to_id_array`] for that node | ||
/// (e.g. the identifiers for each node in the tree) | ||
fn to_arrays<'n>( | ||
&self, | ||
|
@@ -475,7 +497,7 @@ impl< | |
if id_array.is_empty() { | ||
Ok(Transformed::no(node)) | ||
} else { | ||
node.rewrite(&mut CommonSubTreeNodeRewriter { | ||
node.rewrite(&mut CSERewriter { | ||
node_stats, | ||
id_array, | ||
common_nodes, | ||
|
@@ -516,7 +538,7 @@ impl< | |
pub fn extract_common_nodes( | ||
&mut self, | ||
nodes_list: Vec<Vec<N>>, | ||
) -> Result<Transformed<FoundCommonNodes<N>>> { | ||
) -> Result<FoundCommonNodes<N>> { | ||
let mut found_common = false; | ||
let mut node_stats = NodeStats::new(); | ||
let id_arrays_list = nodes_list | ||
|
@@ -542,27 +564,23 @@ impl< | |
)?; | ||
assert!(!common_nodes.is_empty()); | ||
|
||
Ok(Transformed::yes(FoundCommonNodes::Yes { | ||
Ok(FoundCommonNodes::Yes { | ||
common_nodes: common_nodes.into_values().collect(), | ||
new_nodes_list, | ||
original_nodes_list: nodes_list, | ||
})) | ||
}) | ||
} else { | ||
Ok(Transformed::no(FoundCommonNodes::No { | ||
Ok(FoundCommonNodes::No { | ||
original_nodes_list: nodes_list, | ||
})) | ||
}) | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use crate::alias::AliasGenerator; | ||
use crate::cse::{ | ||
IdArray, Identifier, NodeStats, SubTreeNodeEliminator, | ||
SubTreeNodeEliminatorController, | ||
}; | ||
use crate::hash_node::HashNode; | ||
use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE}; | ||
use crate::tree_node::tests::TestTreeNode; | ||
use crate::Result; | ||
use std::collections::HashSet; | ||
|
@@ -576,12 +594,12 @@ mod test { | |
NormalAndAggregates, | ||
} | ||
|
||
pub struct SubTestTreeNodeEliminatorController<'a> { | ||
pub struct TestTreeNodeCSEController<'a> { | ||
alias_generator: &'a AliasGenerator, | ||
mask: TestTreeNodeMask, | ||
} | ||
|
||
impl<'a> SubTestTreeNodeEliminatorController<'a> { | ||
impl<'a> TestTreeNodeCSEController<'a> { | ||
fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self { | ||
Self { | ||
alias_generator, | ||
|
@@ -590,7 +608,7 @@ mod test { | |
} | ||
} | ||
|
||
impl SubTreeNodeEliminatorController for SubTestTreeNodeEliminatorController<'_> { | ||
impl CSEController for TestTreeNodeCSEController<'_> { | ||
type Node = TestTreeNode<String>; | ||
|
||
fn conditional_children( | ||
|
@@ -620,10 +638,6 @@ mod test { | |
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { | ||
TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias)) | ||
} | ||
|
||
fn rewrite_f_down(&mut self, _node: &Self::Node) {} | ||
|
||
fn rewrite_f_up(&mut self, _node: &Self::Node) {} | ||
} | ||
|
||
impl HashNode for TestTreeNode<String> { | ||
|
@@ -635,11 +649,10 @@ mod test { | |
#[test] | ||
fn id_array_visitor() -> Result<()> { | ||
let alias_generator = AliasGenerator::new(); | ||
let eliminator = | ||
SubTreeNodeEliminator::new(SubTestTreeNodeEliminatorController::new( | ||
&alias_generator, | ||
TestTreeNodeMask::Normal, | ||
)); | ||
let eliminator = CSE::new(TestTreeNodeCSEController::new( | ||
&alias_generator, | ||
TestTreeNodeMask::Normal, | ||
)); | ||
|
||
let a_plus_1 = TestTreeNode::new( | ||
vec![ | ||
|
@@ -728,11 +741,10 @@ mod test { | |
assert_eq!(expected, id_array); | ||
|
||
// include aggregates | ||
let eliminator = | ||
SubTreeNodeEliminator::new(SubTestTreeNodeEliminatorController::new( | ||
&alias_generator, | ||
TestTreeNodeMask::NormalAndAggregates, | ||
)); | ||
let eliminator = CSE::new(TestTreeNodeCSEController::new( | ||
&alias_generator, | ||
TestTreeNodeMask::NormalAndAggregates, | ||
)); | ||
|
||
let mut id_array = vec![]; | ||
eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?; | ||
|
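The `HashNode` doc comment in this diff describes hashing a node's own content once and combining it with the hashes of its descendants during a bottom-up pass. The following is a minimal sketch of that idea, not the code from this PR: `ToyNode`, `leaf`, `node`, `deep_hash` and the `combine` mixing function are hypothetical names introduced for illustration, and `combine` is only a stand-in for `combine_hashes`.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Same shape as the trait in the diff: hash only the node's own content.
pub trait HashNode {
    fn hash_node<H: Hasher>(&self, state: &mut H);
}

/// Hypothetical toy tree, used only for this illustration.
pub struct ToyNode {
    pub op: String,
    pub children: Vec<ToyNode>,
}

impl HashNode for ToyNode {
    fn hash_node<H: Hasher>(&self, state: &mut H) {
        // Only the operator is hashed; the children are deliberately skipped.
        self.op.hash(state);
    }
}

/// Assumed mixing function (a stand-in, not DataFusion's `combine_hashes`).
fn combine(l: u64, r: u64) -> u64 {
    l.wrapping_mul(37).wrapping_add(r)
}

/// Bottom-up deep hash: each node's own content is hashed exactly once and
/// then merged with the already computed deep hashes of its children.
pub fn deep_hash(node: &ToyNode) -> u64 {
    let mut hasher = DefaultHasher::new();
    node.hash_node(&mut hasher);
    node.children
        .iter()
        .map(deep_hash)
        .fold(hasher.finish(), combine)
}

/// Convenience constructor for a node without children.
pub fn leaf(op: &str) -> ToyNode {
    ToyNode { op: op.to_string(), children: vec![] }
}

/// Convenience constructor for a node with children.
pub fn node(op: &str, children: Vec<ToyNode>) -> ToyNode {
    ToyNode { op: op.to_string(), children }
}
```

Two structurally equal trees, such as two separately built `a + 1` nodes, get the same deep hash here, which is the property a CSE pass relies on when it groups candidate subtrees before confirming them with an equality check.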
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,7 @@ use crate::{ | |
}; | ||
|
||
use arrow::datatypes::{DataType, FieldRef}; | ||
use datafusion_common::hash_node::HashNode; | ||
use datafusion_common::cse::HashNode; | ||
use datafusion_common::tree_node::{ | ||
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, | ||
}; | ||
|
@@ -1656,16 +1656,6 @@ impl Expr { | |
} | ||
|
||
impl HashNode for Expr { | ||
Review comment: 👍 |
||
/// Hashes the direct content of an `Expr` without recursing into its children. | ||
/// | ||
/// This method is useful to incrementally compute hashes, such as in | ||
/// `CommonSubexprEliminate` which builds a deep hash of a node and its descendants | ||
/// during the bottom-up phase of the first traversal and so avoid computing the hash | ||
/// of the node and then the hash of its descendants separately. | ||
/// | ||
/// If a node doesn't have any children then this method is similar to `.hash()`, but | ||
/// not necessarily returns the same value. | ||
/// | ||
/// As it is pretty easy to forget changing this method when `Expr` changes the | ||
/// implementation doesn't use wildcard patterns (`..`, `_`) to catch changes | ||
/// compile time. | ||
|
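The doc comment on `impl HashNode for Expr` notes that the implementation avoids wildcard patterns (`..`, `_`) so that changes to `Expr` are caught at compile time. A small sketch of that technique follows, assuming a hypothetical miniature `ToyExpr` enum rather than the real `Expr`; `shallow_hash` is likewise an illustrative helper.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Same shape as the trait in the diff.
pub trait HashNode {
    fn hash_node<H: Hasher>(&self, state: &mut H);
}

/// Hypothetical miniature expression type (not DataFusion's `Expr`).
pub enum ToyExpr {
    Column(String),
    Literal(i64),
    Negative(Box<ToyExpr>),
    Binary { left: Box<ToyExpr>, op: char, right: Box<ToyExpr> },
}

impl HashNode for ToyExpr {
    fn hash_node<H: Hasher>(&self, state: &mut H) {
        // Hash the variant tag so different variants with equal payloads differ.
        std::mem::discriminant(self).hash(state);
        // Every variant and every field is named explicitly. Adding a variant
        // or a field makes this match fail to compile until it is updated.
        match self {
            ToyExpr::Column(name) => name.hash(state),
            ToyExpr::Literal(value) => value.hash(state),
            // Children are bound to `_`-prefixed names instead of being
            // skipped with `..` or `_`; they are intentionally not hashed.
            ToyExpr::Negative(_child) => {}
            ToyExpr::Binary { left: _left, op, right: _right } => op.hash(state),
        }
    }
}

/// Hashes only the direct content of `expr`, ignoring its children.
pub fn shallow_hash(expr: &ToyExpr) -> u64 {
    let mut hasher = DefaultHasher::new();
    expr.hash_node(&mut hasher);
    hasher.finish()
}
```

In this sketch, two `Binary` nodes with the same operator but different children produce the same shallow hash, because the children are meant to be hashed separately by whoever walks the tree.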
Would it be possible to add some documentation explaining what is in the module and what the main entry points are?
Something like
Sure, good idea. Added documentation in 1cae61c.
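The documentation added in this PR names `CSE::extract_common_nodes()` as the entry point, and the diff changes its return type from `Result<Transformed<FoundCommonNodes<N>>>` to `Result<FoundCommonNodes<N>>`. A caller might consume that result roughly as sketched below. The enum here is a stand-in that reuses the variant and field names visible in the diff; the field types, the variant order and the `describe` function are assumptions made for illustration.

```rust
use std::fmt::Debug;

/// Stand-in for the enum in the diff. Field names follow the diff; the field
/// types are assumed for this sketch.
pub enum FoundCommonNodes<N> {
    No {
        original_nodes_list: Vec<Vec<N>>,
    },
    Yes {
        common_nodes: Vec<(N, String)>,
        new_nodes_list: Vec<Vec<N>>,
        original_nodes_list: Vec<Vec<N>>,
    },
}

/// Hypothetical caller: the enum itself says whether anything was rewritten,
/// so no extra `Transformed` wrapper is needed to carry that information.
pub fn describe<N: Debug>(found: FoundCommonNodes<N>) -> String {
    match found {
        FoundCommonNodes::No { original_nodes_list } => {
            format!("unchanged: {:?}", original_nodes_list)
        }
        FoundCommonNodes::Yes {
            common_nodes,
            new_nodes_list,
            original_nodes_list: _,
        } => format!(
            "extracted {} common node(s); rewritten: {:?}",
            common_nodes.len(),
            new_nodes_list
        ),
    }
}
```

Matching on the two variants keeps the "nothing found" and "something extracted" paths explicit at the call site.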