diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs
index abddb53ff8583..1d8745fa8cb47 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -196,6 +196,27 @@ pub trait TreeNode: Sized + Clone {
         Ok(new_node)
     }
 
+    fn rewrite_old<R: TreeNodeRewriterOld<N = Self>>(
+        self,
+        rewriter: &mut R,
+    ) -> Result<Self> {
+        let need_mutate = match rewriter.pre_visit(&self)? {
+            RewriteRecursion::Mutate => return rewriter.mutate(self),
+            RewriteRecursion::Stop => return Ok(self),
+            RewriteRecursion::Continue => true,
+            RewriteRecursion::Skip => false,
+        };
+
+        let after_op_children = self.map_children(|node| node.rewrite_old(rewriter))?;
+
+        // now rewrite this node itself
+        if need_mutate {
+            rewriter.mutate(after_op_children)
+        } else {
+            Ok(after_op_children)
+        }
+    }
+
     /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
     /// recursively transforming [`TreeNode`]s.
     ///
@@ -280,6 +301,21 @@ pub trait TreeNodeVisitor: Sized {
     }
 }
 
+pub trait TreeNodeRewriterOld: Sized {
+    /// The node type which is rewritable.
+    type N: TreeNode;
+
+    /// Invoked before (Preorder) any children of `node` are rewritten /
+    /// visited. Default implementation returns `Ok(Recursion::Continue)`
+    fn pre_visit(&mut self, _node: &Self::N) -> Result<RewriteRecursion> {
+        Ok(RewriteRecursion::Continue)
+    }
+
+    /// Invoked after (Postorder) all children of `node` have been mutated and
+    /// returns a potentially modified node.
+    fn mutate(&mut self, node: Self::N) -> Result<Self::N>;
+}
+
 /// Trait for potentially recursively transform a [`TreeNode`] node tree.
 pub trait TreeNodeRewriter: Sized {
     /// The node type which is rewritable.
@@ -300,6 +336,17 @@ pub trait TreeNodeRewriter: Sized {
     }
 }
 
+pub enum RewriteRecursion {
+    /// Continue rewrite this node tree.
+    Continue,
+    /// Call 'op' immediately and return.
+    Mutate,
+    /// Do not rewrite the children of this node.
+    Stop,
+    /// Keep recursive but skip apply op on this node
+    Skip,
+}
+
 /// Controls how [`TreeNode`] recursions should proceed.
 #[derive(Debug)]
 pub enum TreeNodeRecursion {
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index f9281fa8e14ab..45fb6e67861a4 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -24,7 +24,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule};
 
 use arrow::datatypes::DataType;
 use datafusion_common::tree_node::{
-    TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
+    RewriteRecursion, TreeNode, TreeNodeRecursion, TreeNodeRewriterOld, TreeNodeVisitor,
 };
 use datafusion_common::{
     internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
@@ -688,66 +688,69 @@ struct CommonSubexprRewriter<'a> {
     curr_index: usize,
 }
 
-impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
-    type Node = Expr;
+impl TreeNodeRewriterOld for CommonSubexprRewriter<'_> {
+    type N = Expr;
 
-    fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> {
+    fn pre_visit(&mut self, _: &Expr) -> Result<RewriteRecursion> {
         if self.curr_index >= self.id_array.len()
             || self.max_series_number > self.id_array[self.curr_index].0
         {
-            return Ok((expr, TreeNodeRecursion::Skip));
+            return Ok(RewriteRecursion::Stop);
         }
 
         let curr_id = &self.id_array[self.curr_index].1;
         // skip `Expr`s without identifier (empty identifier).
         if curr_id.is_empty() {
             self.curr_index += 1;
-            return Ok((expr, TreeNodeRecursion::Continue));
+            return Ok(RewriteRecursion::Skip);
         }
         match self.expr_set.get(curr_id) {
             Some((_, counter, _)) => {
                 if *counter > 1 {
                     self.affected_id.insert(curr_id.clone());
-
-                    // This expr tree is finished.
-                    if self.curr_index >= self.id_array.len() {
-                        return Ok((expr, TreeNodeRecursion::Skip));
-                    }
-
-                    let (series_number, id) = &self.id_array[self.curr_index];
-                    self.curr_index += 1;
-                    // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr.
-                    let expr_set_item = self.expr_set.get(id).ok_or_else(|| {
-                        DataFusionError::Internal("expr_set invalid state".to_string())
-                    })?;
-                    if *series_number < self.max_series_number
-                        || id.is_empty()
-                        || expr_set_item.1 <= 1
-                    {
-                        return Ok((expr, TreeNodeRecursion::Skip));
-                    }
-
-                    self.max_series_number = *series_number;
-                    // step index to skip all sub-node (which has smaller series number).
-                    while self.curr_index < self.id_array.len()
-                        && *series_number > self.id_array[self.curr_index].0
-                    {
-                        self.curr_index += 1;
-                    }
-
-                    let expr_name = expr.display_name()?;
-                    // Alias this `Column` expr to it original "expr name",
-                    // `projection_push_down` optimizer use "expr name" to eliminate useless
-                    // projections.
-                    Ok((col(id).alias(expr_name), TreeNodeRecursion::Skip))
+                    Ok(RewriteRecursion::Mutate)
                 } else {
                     self.curr_index += 1;
-                    Ok((expr, TreeNodeRecursion::Continue))
+                    Ok(RewriteRecursion::Skip)
                 }
             }
             _ => internal_err!("expr_set invalid state"),
         }
     }
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        // This expr tree is finished.
+        if self.curr_index >= self.id_array.len() {
+            return Ok(expr);
+        }
+
+        let (series_number, id) = &self.id_array[self.curr_index];
+        self.curr_index += 1;
+        // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr.
+        let expr_set_item = self.expr_set.get(id).ok_or_else(|| {
+            DataFusionError::Internal("expr_set invalid state".to_string())
+        })?;
+        if *series_number < self.max_series_number
+            || id.is_empty()
+            || expr_set_item.1 <= 1
+        {
+            return Ok(expr);
+        }
+
+        self.max_series_number = *series_number;
+        // step index to skip all sub-node (which has smaller series number).
+        while self.curr_index < self.id_array.len()
+            && *series_number > self.id_array[self.curr_index].0
+        {
+            self.curr_index += 1;
+        }
+
+        let expr_name = expr.display_name()?;
+        // Alias this `Column` expr to it original "expr name",
+        // `projection_push_down` optimizer use "expr name" to eliminate useless
+        // projections.
+        Ok(col(id).alias(expr_name))
+    }
 }
 
 fn replace_common_expr(
@@ -756,7 +759,7 @@ fn replace_common_expr(
     expr_set: &ExprSet,
     affected_id: &mut BTreeSet<Identifier>,
 ) -> Result<Expr> {
-    expr.rewrite(&mut CommonSubexprRewriter {
+    expr.rewrite_old(&mut CommonSubexprRewriter {
         expr_set,
         id_array,
         affected_id,