diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index a34f3478f0cb..453ae26e7333 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::hash_node::HashNode; +//! Common Subexpression Elimination logic implemented in [`CSE`] can be controlled with +//! a [`CSEController`], that defines how to eliminate common subtrees from a particular +//! [`TreeNode`] tree. + use crate::hash_utils::combine_hashes; use crate::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, @@ -26,6 +29,26 @@ use indexmap::IndexMap; use std::collections::HashMap; use std::hash::{BuildHasher, Hash, Hasher, RandomState}; use std::marker::PhantomData; +use std::sync::Arc; + +/// Hashes the direct content of an [`TreeNode`] without recursing into its children. +/// +/// This method is useful to incrementally compute hashes, such as in [`CSE`] which builds +/// a deep hash of a node and its descendants during the bottom-up phase of the first +/// traversal and so avoid computing the hash of the node and then the hash of its +/// descendants separately. +/// +/// If a node doesn't have any children then the value returned by `hash_node()` is +/// similar to '.hash()`, but not necessarily returns the same value. +pub trait HashNode { + fn hash_node(&self, state: &mut H); +} + +impl HashNode for Arc { + fn hash_node(&self, state: &mut H) { + (**self).hash_node(state); + } +} /// Identifier that represents a [`TreeNode`] tree. /// @@ -72,8 +95,8 @@ impl<'n, N: HashNode> Identifier<'n, N> { /// A cache that contains the postorder index and the identifier of [`TreeNode`]s by the /// preorder index of the nodes. /// -/// This cache is filled by [`CommonSubTreeNodeVisitor`] during the first traversal and is -/// used by [`CommonSubTreeNodeRewriter`] during the second traversal. +/// This cache is filled by [`CSEVisitor`] during the first traversal and is +/// used by [`CSERewriter`] during the second traversal. /// /// The purpose of this cache is to quickly find the identifier of a node during the /// second traversal. @@ -108,7 +131,8 @@ type CommonNodes<'n, N> = IndexMap, (N, String)>; type ChildrenList = (Vec, Vec); -pub trait SubTreeNodeEliminatorController { +/// The [`TreeNode`] specific definition of elimination. +pub trait CSEController { /// The type of the tree nodes. type Node; @@ -135,11 +159,11 @@ pub trait SubTreeNodeEliminatorController { // A helper method called on each node during top-down traversal during the second, // rewriting traversal of CSE. - fn rewrite_f_down(&mut self, node: &Self::Node); + fn rewrite_f_down(&mut self, _node: &Self::Node) {} // A helper method called on each node during bottom-up traversal during the second, // rewriting traversal of CSE. - fn rewrite_f_up(&mut self, node: &Self::Node); + fn rewrite_f_up(&mut self, _node: &Self::Node) {} } /// The result of potentially rewriting a list of [`TreeNode`]s to eliminate common @@ -181,7 +205,7 @@ pub enum FoundCommonNodes { /// /// A [`TreeNode`] without any children (column, literal etc.) will not have identifier /// because they should not be recognized as common subtree. -struct CommonSubTreeNodeVisitor<'a, 'n, N, C: SubTreeNodeEliminatorController> { +struct CSEVisitor<'a, 'n, N, C: CSEController> { /// statistics of [`TreeNode`]s node_stats: &'a mut NodeStats<'n, N>, @@ -226,9 +250,7 @@ enum VisitRecord<'n, N> { NodeItem(Identifier<'n, N>, bool), } -impl<'n, N: TreeNode + HashNode, C: SubTreeNodeEliminatorController> - CommonSubTreeNodeVisitor<'_, 'n, N, C> -{ +impl<'n, N: TreeNode + HashNode, C: CSEController> CSEVisitor<'_, 'n, N, C> { /// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before /// it. Returns a tuple that contains: /// - The pre-order index of the [`TreeNode`] we marked. @@ -260,8 +282,8 @@ impl<'n, N: TreeNode + HashNode, C: SubTreeNodeEliminatorController> } } -impl<'n, N: TreeNode + HashNode + Eq, C: SubTreeNodeEliminatorController> - TreeNodeVisitor<'n> for CommonSubTreeNodeVisitor<'_, 'n, N, C> +impl<'n, N: TreeNode + HashNode + Eq, C: CSEController> TreeNodeVisitor<'n> + for CSEVisitor<'_, 'n, N, C> { type Node = N; @@ -331,8 +353,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: SubTreeNodeEliminatorController> -{ +struct CSERewriter<'a, 'n, N, C: CSEController> { /// statistics of [`TreeNode`]s node_stats: &'a NodeStats<'n, N>, @@ -349,8 +370,8 @@ struct CommonSubTreeNodeRewriter<'a, 'n, N, C: SubTreeNodeEliminatorController> TreeNodeRewriter - for CommonSubTreeNodeRewriter<'_, '_, N, C> +impl> TreeNodeRewriter + for CSERewriter<'_, '_, N, C> { type Node = N; @@ -393,17 +414,18 @@ impl> TreeNodeRew } } -pub struct SubTreeNodeEliminator> { +/// The main entry point of Common Subexpression Elimination. +/// +/// [`CSE`] requires a [`CSEController`], that defines how common subtrees of a particular +/// [`TreeNode`] tree can be eliminated. The elimination process can be started with the +/// [`CSE::extract_common_nodes()`] method. +pub struct CSE> { random_state: RandomState, phantom_data: PhantomData, controller: C, } -impl< - N: TreeNode + HashNode + Clone + Eq, - C: SubTreeNodeEliminatorController, - > SubTreeNodeEliminator -{ +impl> CSE { pub fn new(controller: C) -> Self { Self { random_state: RandomState::new(), @@ -419,7 +441,7 @@ impl< node_stats: &mut NodeStats<'n, N>, id_array: &mut IdArray<'n, N>, ) -> Result { - let mut visitor = CommonSubTreeNodeVisitor { + let mut visitor = CSEVisitor { node_stats, id_array, visit_stack: vec![], @@ -440,7 +462,7 @@ impl< /// /// Returns and array with 1 element for each input node in `nodes` /// - /// Each element is itself the result of [`SubTreeNodeEliminator::node_to_id_array`] for that node + /// Each element is itself the result of [`CSE::node_to_id_array`] for that node /// (e.g. the identifiers for each node in the tree) fn to_arrays<'n>( &self, @@ -475,7 +497,7 @@ impl< if id_array.is_empty() { Ok(Transformed::no(node)) } else { - node.rewrite(&mut CommonSubTreeNodeRewriter { + node.rewrite(&mut CSERewriter { node_stats, id_array, common_nodes, @@ -516,7 +538,7 @@ impl< pub fn extract_common_nodes( &mut self, nodes_list: Vec>, - ) -> Result>> { + ) -> Result> { let mut found_common = false; let mut node_stats = NodeStats::new(); let id_arrays_list = nodes_list @@ -542,15 +564,15 @@ impl< )?; assert!(!common_nodes.is_empty()); - Ok(Transformed::yes(FoundCommonNodes::Yes { + Ok(FoundCommonNodes::Yes { common_nodes: common_nodes.into_values().collect(), new_nodes_list, original_nodes_list: nodes_list, - })) + }) } else { - Ok(Transformed::no(FoundCommonNodes::No { + Ok(FoundCommonNodes::No { original_nodes_list: nodes_list, - })) + }) } } } @@ -558,11 +580,7 @@ impl< #[cfg(test)] mod test { use crate::alias::AliasGenerator; - use crate::cse::{ - IdArray, Identifier, NodeStats, SubTreeNodeEliminator, - SubTreeNodeEliminatorController, - }; - use crate::hash_node::HashNode; + use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE}; use crate::tree_node::tests::TestTreeNode; use crate::Result; use std::collections::HashSet; @@ -576,12 +594,12 @@ mod test { NormalAndAggregates, } - pub struct SubTestTreeNodeEliminatorController<'a> { + pub struct TestTreeNodeCSEController<'a> { alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask, } - impl<'a> SubTestTreeNodeEliminatorController<'a> { + impl<'a> TestTreeNodeCSEController<'a> { fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self { Self { alias_generator, @@ -590,7 +608,7 @@ mod test { } } - impl SubTreeNodeEliminatorController for SubTestTreeNodeEliminatorController<'_> { + impl CSEController for TestTreeNodeCSEController<'_> { type Node = TestTreeNode; fn conditional_children( @@ -620,10 +638,6 @@ mod test { fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias)) } - - fn rewrite_f_down(&mut self, _node: &Self::Node) {} - - fn rewrite_f_up(&mut self, _node: &Self::Node) {} } impl HashNode for TestTreeNode { @@ -635,11 +649,10 @@ mod test { #[test] fn id_array_visitor() -> Result<()> { let alias_generator = AliasGenerator::new(); - let eliminator = - SubTreeNodeEliminator::new(SubTestTreeNodeEliminatorController::new( - &alias_generator, - TestTreeNodeMask::Normal, - )); + let eliminator = CSE::new(TestTreeNodeCSEController::new( + &alias_generator, + TestTreeNodeMask::Normal, + )); let a_plus_1 = TestTreeNode::new( vec![ @@ -728,11 +741,10 @@ mod test { assert_eq!(expected, id_array); // include aggregates - let eliminator = - SubTreeNodeEliminator::new(SubTestTreeNodeEliminatorController::new( - &alias_generator, - TestTreeNodeMask::NormalAndAggregates, - )); + let eliminator = CSE::new(TestTreeNodeCSEController::new( + &alias_generator, + TestTreeNodeMask::NormalAndAggregates, + )); let mut id_array = vec![]; eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?; diff --git a/datafusion/common/src/hash_node.rs b/datafusion/common/src/hash_node.rs deleted file mode 100644 index 996f4b032c77..000000000000 --- a/datafusion/common/src/hash_node.rs +++ /dev/null @@ -1,29 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::hash::Hasher; -use std::sync::Arc; - -pub trait HashNode { - fn hash_node(&self, _state: &mut H); -} - -impl HashNode for Arc { - fn hash_node(&self, state: &mut H) { - (**self).hash_node(state); - } -} diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 485d0799c9ce..da74244b184d 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -36,7 +36,6 @@ pub mod display; pub mod error; pub mod file_options; pub mod format; -pub mod hash_node; pub mod hash_utils; pub mod instant; pub mod parsers; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index d85605f0192b..41d4b360149b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -34,7 +34,7 @@ use crate::{ }; use arrow::datatypes::{DataType, FieldRef}; -use datafusion_common::hash_node::HashNode; +use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -1656,16 +1656,6 @@ impl Expr { } impl HashNode for Expr { - /// Hashes the direct content of an `Expr` without recursing into its children. - /// - /// This method is useful to incrementally compute hashes, such as in - /// `CommonSubexprEliminate` which builds a deep hash of a node and its descendants - /// during the bottom-up phase of the first traversal and so avoid computing the hash - /// of the node and then the hash of its descendants separately. - /// - /// If a node doesn't have any children then this method is similar to `.hash()`, but - /// not necessarily returns the same value. - /// /// As it is pretty easy to forget changing this method when `Expr` changes the /// implementation doesn't use wildcard patterns (`..`, `_`) to catch changes /// compile time. diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 7021d33129fc..921011d33fc4 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -27,9 +27,7 @@ use crate::optimizer::ApplyOrder; use crate::utils::NamePreserver; use datafusion_common::alias::AliasGenerator; -use datafusion_common::cse::{ - FoundCommonNodes, SubTreeNodeEliminator, SubTreeNodeEliminatorController, -}; +use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::{Alias, ScalarFunction}; @@ -144,12 +142,12 @@ impl CommonSubexprEliminate { // Extract common sub-expressions from the list. - SubTreeNodeEliminator::new(SubExprEliminatorController::new( + match CSE::new(ExprCSEController::new( config.alias_generator().as_ref(), ExprMask::Normal, )) .extract_common_nodes(window_expr_list)? - .map_data(|common| match common { + { // If there are common sub-expressions, then the insert a projection node // with the common expressions between the new window nodes and the // original input. @@ -157,12 +155,13 @@ impl CommonSubexprEliminate { common_nodes: common_exprs, new_nodes_list: new_exprs_list, original_nodes_list: original_exprs_list, - } => build_common_expr_project_plan(input, common_exprs) - .map(|new_input| (new_exprs_list, new_input, Some(original_exprs_list))), + } => build_common_expr_project_plan(input, common_exprs).map(|new_input| { + Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list))) + }), FoundCommonNodes::No { original_nodes_list: original_exprs_list, - } => Ok((original_exprs_list, input, None)), - })? + } => Ok(Transformed::no((original_exprs_list, input, None))), + }? // Recurse into the new input. // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) .transform_data(|(new_window_expr_list, new_input, window_expr_list)| { @@ -235,40 +234,48 @@ impl CommonSubexprEliminate { } = aggregate; let input = Arc::unwrap_or_clone(input); // Extract common sub-expressions from the aggregate and grouping expressions. - SubTreeNodeEliminator::new(SubExprEliminatorController::new( + match CSE::new(ExprCSEController::new( config.alias_generator().as_ref(), ExprMask::Normal, )) .extract_common_nodes(vec![group_expr, aggr_expr])? - .map_data(|common| { - match common { - // If there are common sub-expressions, then insert a projection node - // with the common expressions between the new aggregate node and the - // original input. - FoundCommonNodes::Yes { - common_nodes: common_exprs, - new_nodes_list: mut new_exprs_list, - original_nodes_list: mut original_exprs_list, - } => { - let new_aggr_expr = new_exprs_list.pop().unwrap(); - let new_group_expr = new_exprs_list.pop().unwrap(); - - build_common_expr_project_plan(input, common_exprs).map(|new_input| { - let aggr_expr = original_exprs_list.pop().unwrap(); - (new_aggr_expr, new_group_expr, new_input, Some(aggr_expr)) - }) - } + { + // If there are common sub-expressions, then insert a projection node + // with the common expressions between the new aggregate node and the + // original input. + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: mut original_exprs_list, + } => { + let new_aggr_expr = new_exprs_list.pop().unwrap(); + let new_group_expr = new_exprs_list.pop().unwrap(); + + build_common_expr_project_plan(input, common_exprs).map(|new_input| { + let aggr_expr = original_exprs_list.pop().unwrap(); + Transformed::yes(( + new_aggr_expr, + new_group_expr, + new_input, + Some(aggr_expr), + )) + }) + } - FoundCommonNodes::No { - original_nodes_list: mut original_exprs_list, - } => { - let new_aggr_expr = original_exprs_list.pop().unwrap(); - let new_group_expr = original_exprs_list.pop().unwrap(); + FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let new_aggr_expr = original_exprs_list.pop().unwrap(); + let new_group_expr = original_exprs_list.pop().unwrap(); - Ok((new_aggr_expr, new_group_expr, input, None)) - } + Ok(Transformed::no(( + new_aggr_expr, + new_group_expr, + input, + None, + ))) } - })? + }? // Recurse into the new input. // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| { @@ -285,121 +292,115 @@ impl CommonSubexprEliminate { .transform_data( |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| { // Extract common aggregate sub-expressions from the aggregate expressions. - SubTreeNodeEliminator::new(SubExprEliminatorController::new( + match CSE::new(ExprCSEController::new( config.alias_generator().as_ref(), ExprMask::NormalAndAggregates, )) .extract_common_nodes(vec![new_aggr_expr])? - .map_data(|common| { - match common { - FoundCommonNodes::Yes { - common_nodes: common_exprs, - new_nodes_list: mut new_exprs_list, - original_nodes_list: mut original_exprs_list, - } => { - let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); - let new_aggr_expr = original_exprs_list.pop().unwrap(); - - let mut agg_exprs = common_exprs - .into_iter() - .map(|(expr, expr_alias)| expr.alias(expr_alias)) - .collect::>(); + { + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: mut original_exprs_list, + } => { + let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); + let new_aggr_expr = original_exprs_list.pop().unwrap(); + + let mut agg_exprs = common_exprs + .into_iter() + .map(|(expr, expr_alias)| expr.alias(expr_alias)) + .collect::>(); - let mut proj_exprs = vec![]; - for expr in &new_group_expr { - extract_expressions(expr, &mut proj_exprs) - } - for (expr_rewritten, expr_orig) in - rewritten_aggr_expr.into_iter().zip(new_aggr_expr) - { - if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. }) = - expr_rewritten - { - agg_exprs.push(expr.alias(&name)); - proj_exprs - .push(Expr::Column(Column::from_name(name))); - } else { - let expr_alias = - config.alias_generator().next(CSE_PREFIX); - let (qualifier, field_name) = - expr_rewritten.qualified_name(); - let out_name = qualified_name( - qualifier.as_ref(), - &field_name, - ); - - agg_exprs.push(expr_rewritten.alias(&expr_alias)); - proj_exprs.push( - Expr::Column(Column::from_name(expr_alias)) - .alias(out_name), - ); - } + let mut proj_exprs = vec![]; + for expr in &new_group_expr { + extract_expressions(expr, &mut proj_exprs) + } + for (expr_rewritten, expr_orig) in + rewritten_aggr_expr.into_iter().zip(new_aggr_expr) + { + if expr_rewritten == expr_orig { + if let Expr::Alias(Alias { expr, name, .. }) = + expr_rewritten + { + agg_exprs.push(expr.alias(&name)); + proj_exprs + .push(Expr::Column(Column::from_name(name))); } else { - proj_exprs.push(expr_rewritten); + let expr_alias = + config.alias_generator().next(CSE_PREFIX); + let (qualifier, field_name) = + expr_rewritten.qualified_name(); + let out_name = + qualified_name(qualifier.as_ref(), &field_name); + + agg_exprs.push(expr_rewritten.alias(&expr_alias)); + proj_exprs.push( + Expr::Column(Column::from_name(expr_alias)) + .alias(out_name), + ); } + } else { + proj_exprs.push(expr_rewritten); } - - let agg = LogicalPlan::Aggregate(Aggregate::try_new( - new_input, - new_group_expr, - agg_exprs, - )?); - Projection::try_new(proj_exprs, Arc::new(agg)) - .map(LogicalPlan::Projection) } - // If there aren't any common aggregate sub-expressions, then just - // rebuild the aggregate node. - FoundCommonNodes::No { - original_nodes_list: mut original_exprs_list, - } => { - let rewritten_aggr_expr = original_exprs_list.pop().unwrap(); - - // If there were common expressions extracted, then we need to - // make sure we restore the original column names. - // TODO: Although `find_common_exprs()` inserts aliases around - // extracted common expressions this doesn't mean that the - // original column names (schema) are preserved due to the - // inserted aliases are not always at the top of the - // expression. - // Let's consider improving `find_common_exprs()` to always - // keep column names and get rid of additional name - // preserving logic here. - if let Some(aggr_expr) = aggr_expr { - let name_perserver = NamePreserver::new_for_projection(); - let saved_names = aggr_expr - .iter() - .map(|expr| name_perserver.save(expr)) - .collect::>(); - let new_aggr_expr = rewritten_aggr_expr - .into_iter() - .zip(saved_names) - .map(|(new_expr, saved_name)| { - saved_name.restore(new_expr) - }) - .collect::>(); - - // Since `group_expr` may have changed, schema may also. - // Use `try_new()` method. - Aggregate::try_new( - new_input, - new_group_expr, - new_aggr_expr, - ) - .map(LogicalPlan::Aggregate) - } else { - Aggregate::try_new_with_schema( - new_input, - new_group_expr, - rewritten_aggr_expr, - schema, - ) + let agg = LogicalPlan::Aggregate(Aggregate::try_new( + new_input, + new_group_expr, + agg_exprs, + )?); + Projection::try_new(proj_exprs, Arc::new(agg)) + .map(|p| Transformed::yes(LogicalPlan::Projection(p))) + } + + // If there aren't any common aggregate sub-expressions, then just + // rebuild the aggregate node. + FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let rewritten_aggr_expr = original_exprs_list.pop().unwrap(); + + // If there were common expressions extracted, then we need to + // make sure we restore the original column names. + // TODO: Although `find_common_exprs()` inserts aliases around + // extracted common expressions this doesn't mean that the + // original column names (schema) are preserved due to the + // inserted aliases are not always at the top of the + // expression. + // Let's consider improving `find_common_exprs()` to always + // keep column names and get rid of additional name + // preserving logic here. + if let Some(aggr_expr) = aggr_expr { + let name_perserver = NamePreserver::new_for_projection(); + let saved_names = aggr_expr + .iter() + .map(|expr| name_perserver.save(expr)) + .collect::>(); + let new_aggr_expr = rewritten_aggr_expr + .into_iter() + .zip(saved_names) + .map(|(new_expr, saved_name)| { + saved_name.restore(new_expr) + }) + .collect::>(); + + // Since `group_expr` may have changed, schema may also. + // Use `try_new()` method. + Aggregate::try_new(new_input, new_group_expr, new_aggr_expr) .map(LogicalPlan::Aggregate) - } + .map(Transformed::no) + } else { + Aggregate::try_new_with_schema( + new_input, + new_group_expr, + rewritten_aggr_expr, + schema, + ) + .map(LogicalPlan::Aggregate) + .map(Transformed::no) } } - }) + } }, ) } @@ -425,12 +426,12 @@ impl CommonSubexprEliminate { config: &dyn OptimizerConfig, ) -> Result, LogicalPlan)>> { // Extract common sub-expressions from the expressions. - SubTreeNodeEliminator::new(SubExprEliminatorController::new( + match CSE::new(ExprCSEController::new( config.alias_generator().as_ref(), ExprMask::Normal, )) .extract_common_nodes(vec![exprs])? - .map_data(|common| match common { + { FoundCommonNodes::Yes { common_nodes: common_exprs, new_nodes_list: mut new_exprs_list, @@ -438,15 +439,15 @@ impl CommonSubexprEliminate { } => { let new_exprs = new_exprs_list.pop().unwrap(); build_common_expr_project_plan(input, common_exprs) - .map(|new_input| (new_exprs, new_input)) + .map(|new_input| Transformed::yes((new_exprs, new_input))) } FoundCommonNodes::No { original_nodes_list: mut original_exprs_list, } => { let new_exprs = original_exprs_list.pop().unwrap(); - Ok((new_exprs, input)) + Ok(Transformed::no((new_exprs, input))) } - })? + }? // Recurse into the new input. // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) .transform_data(|(new_exprs, new_input)| { @@ -593,7 +594,7 @@ enum ExprMask { NormalAndAggregates, } -struct SubExprEliminatorController<'a> { +struct ExprCSEController<'a> { alias_generator: &'a AliasGenerator, mask: ExprMask, @@ -601,7 +602,7 @@ struct SubExprEliminatorController<'a> { alias_counter: usize, } -impl<'a> SubExprEliminatorController<'a> { +impl<'a> ExprCSEController<'a> { fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self { Self { alias_generator, @@ -611,7 +612,7 @@ impl<'a> SubExprEliminatorController<'a> { } } -impl SubTreeNodeEliminatorController for SubExprEliminatorController<'_> { +impl CSEController for ExprCSEController<'_> { type Node = Expr; fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> { @@ -689,13 +690,12 @@ impl SubTreeNodeEliminatorController for SubExprEliminatorController<'_> { } fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { - let expr_name = node.schema_name().to_string(); // alias the expressions without an `Alias` ancestor node if self.alias_counter > 0 { col(alias) } else { self.alias_counter += 1; - col(alias).alias(expr_name) + col(alias).alias(node.schema_name().to_string()) } }