From a84e5f89bd52d59c78f11fffbab89ed1d418538f Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 4 Mar 2024 23:33:44 +0100 Subject: [PATCH] Consolidate `TreeNode` transform and rewrite APIs (#8891) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor `TreeNode::rewrite()` * use handle_tree_recursion in `Expr` * use macro for transform recursions * fix api * minor fixes * fix * don't trust `t.transformed` coming from transformation closures, keep the old way of detecting if changes were made * rephrase todo comment, always propagate up `t.transformed` from the transformation closure, fix projection pushdown closure * Fix `TreeNodeRecursion` docs * extend Skip (Prune) functionality to Jump as it is defined in https://synnada.notion.site/synnada/TreeNode-Design-Proposal-bceac27d18504a2085145550e267c4c1 * fix Jump and add tests * jump test fixes * fix clippy * unify "transform" traversals using macros, fix "visit" traversal jumps, add visit jump tests, ensure consistent naming `f` instead of `op`, `f_down` instead of `pre_visit` and `f_up` instead of `post_visit` * fix macro rewrite * minor fixes * minor fix * refactor tests * add transform tests * add apply, transform_down and transform_up tests * refactor tests * test jump on both a and e nodes in both top-down and bottom-up traversals * better transform/rewrite tests * minor fix * simplify tests * add stop tests, reorganize tests * fix previous merges and remove leftover file * Review TreeNode Refactor (#1) * Minor changes * Jump doesn't ignore f_up * update test * Update rewriter * LogicalPlan visit update and propagate from children flags * Update tree_node.rs * Update map_children's --------- Co-authored-by: Mustafa Akur * fix * minor fixes * fix f_up call when f_down returns jump * simplify code * minor fix * revert unnecessary changes * fix `DynTreeNode` and `ConcreteTreeNode` `transformed` and `tnr` propagation * introduce TransformedResult helper * fix docs * restore transform as alias to trassform_up * restore transform as alias to trassform_up 2 * Simplifications and comment improvements (#2) --------- Co-authored-by: Berkay Şahin <124376117+berkaysynnada@users.noreply.github.com> Co-authored-by: Mustafa Akur Co-authored-by: Mehmet Ozan Kabak --- datafusion-examples/examples/rewrite_expr.rs | 19 +- datafusion/common/src/tree_node.rs | 1783 ++++++++++++++--- .../core/src/datasource/listing/helpers.rs | 42 +- .../physical_plan/parquet/row_filter.rs | 43 +- datafusion/core/src/execution/context/mod.rs | 8 +- .../aggregate_statistics.rs | 10 +- .../physical_optimizer/coalesce_batches.rs | 13 +- .../combine_partial_final_agg.rs | 12 +- .../enforce_distribution.rs | 43 +- .../src/physical_optimizer/enforce_sorting.rs | 49 +- .../src/physical_optimizer/join_selection.rs | 33 +- .../limited_distinct_aggregation.rs | 26 +- .../physical_optimizer/output_requirements.rs | 22 +- .../physical_optimizer/pipeline_checker.rs | 5 +- .../physical_optimizer/projection_pushdown.rs | 36 +- .../core/src/physical_optimizer/pruning.rs | 15 +- .../replace_with_order_preserving_variants.rs | 10 +- .../src/physical_optimizer/sort_pushdown.rs | 4 +- .../core/src/physical_optimizer/test_utils.rs | 28 +- .../physical_optimizer/topk_aggregation.rs | 24 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 23 +- datafusion/expr/src/expr.rs | 32 +- datafusion/expr/src/expr_rewriter/mod.rs | 89 +- datafusion/expr/src/expr_rewriter/order_by.rs | 12 +- datafusion/expr/src/logical_plan/display.rs | 34 +- datafusion/expr/src/logical_plan/plan.rs | 94 +- datafusion/expr/src/tree_node/expr.rs | 419 ++-- datafusion/expr/src/tree_node/plan.rs | 80 +- datafusion/expr/src/utils.rs | 10 +- .../src/analyzer/count_wildcard_rule.rs | 125 +- .../src/analyzer/inline_table_scan.rs | 49 +- datafusion/optimizer/src/analyzer/mod.rs | 4 +- .../optimizer/src/analyzer/rewrite_expr.rs | 73 +- datafusion/optimizer/src/analyzer/subquery.rs | 25 +- .../optimizer/src/analyzer/type_coercion.rs | 173 +- .../optimizer/src/common_subexpr_eliminate.rs | 125 +- datafusion/optimizer/src/decorrelate.rs | 130 +- .../src/decorrelate_predicate_subquery.rs | 17 +- datafusion/optimizer/src/plan_signature.rs | 4 +- datafusion/optimizer/src/push_down_filter.rs | 23 +- .../optimizer/src/scalar_subquery_to_join.rs | 78 +- .../simplify_expressions/expr_simplifier.rs | 329 +-- .../src/simplify_expressions/guarantees.rs | 66 +- .../simplify_expressions/inlist_simplifier.rs | 52 +- .../src/unwrap_cast_in_comparison.rs | 55 +- datafusion/optimizer/src/utils.rs | 14 +- .../physical-expr/src/equivalence/class.rs | 29 +- .../physical-expr/src/equivalence/mod.rs | 22 +- .../src/equivalence/projection.rs | 12 +- .../src/equivalence/properties.rs | 22 +- .../physical-expr/src/expressions/case.rs | 28 +- datafusion/physical-expr/src/physical_expr.rs | 2 +- datafusion/physical-expr/src/utils/mod.rs | 31 +- datafusion/physical-plan/src/empty.rs | 2 +- .../src/joins/stream_join_utils.rs | 18 +- datafusion/physical-plan/src/joins/utils.rs | 20 +- datafusion/physical-plan/src/lib.rs | 7 +- .../physical-plan/src/placeholder_row.rs | 3 +- .../physical-plan/src/recursive_query.rs | 7 +- datafusion/physical-plan/src/tree_node.rs | 6 +- datafusion/sql/src/utils.rs | 76 +- .../sqllogictest/test_files/group_by.slt | 2 +- .../library-user-guide/working-with-exprs.md | 6 +- 63 files changed, 3074 insertions(+), 1579 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 8d13d1201881..cc1396f770e4 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -17,7 +17,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_err, Result, ScalarValue}; use datafusion_expr::{ AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF, @@ -95,14 +95,15 @@ impl MyAnalyzerRule { Ok(match plan { LogicalPlan::Filter(filter) => { let predicate = Self::analyze_expr(filter.predicate.clone())?; - Transformed::Yes(LogicalPlan::Filter(Filter::try_new( + Transformed::yes(LogicalPlan::Filter(Filter::try_new( predicate, filter.input, )?)) } - _ => Transformed::No(plan), + _ => Transformed::no(plan), }) }) + .data() } fn analyze_expr(expr: Expr) -> Result { @@ -111,13 +112,14 @@ impl MyAnalyzerRule { Ok(match expr { Expr::Literal(ScalarValue::Int64(i)) => { // transform to UInt64 - Transformed::Yes(Expr::Literal(ScalarValue::UInt64( + Transformed::yes(Expr::Literal(ScalarValue::UInt64( i.map(|i| i as u64), ))) } - _ => Transformed::No(expr), + _ => Transformed::no(expr), }) }) + .data() } } @@ -175,14 +177,15 @@ fn my_rewrite(expr: Expr) -> Result { let low: Expr = *low; let high: Expr = *high; if negated { - Transformed::Yes(expr.clone().lt(low).or(expr.gt(high))) + Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) } else { - Transformed::Yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) + Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) } } - _ => Transformed::No(expr), + _ => Transformed::no(expr), }) }) + .data() } #[derive(Default)] diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index c5c4ee824d61..2d653a27c47b 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -22,29 +22,74 @@ use std::sync::Arc; use crate::Result; -/// If the function returns [`VisitRecursion::Continue`], the normal execution of the -/// function continues. If it returns [`VisitRecursion::Skip`], the function returns -/// with [`VisitRecursion::Continue`] to jump next recursion step, bypassing further -/// exploration of the current step. In case of [`VisitRecursion::Stop`], the function -/// return with [`VisitRecursion::Stop`] and recursion halts. +/// This macro is used to control continuation behaviors during tree traversals +/// based on the specified direction. Depending on `$DIRECTION` and the value of +/// the given expression (`$EXPR`), which should be a variant of [`TreeNodeRecursion`], +/// the macro results in the following behavior: +/// +/// - If the expression returns [`TreeNodeRecursion::Continue`], normal execution +/// continues. +/// - If it returns [`TreeNodeRecursion::Stop`], recursion halts and propagates +/// [`TreeNodeRecursion::Stop`]. +/// - If it returns [`TreeNodeRecursion::Jump`], the continuation behavior depends +/// on the traversal direction: +/// - For `UP` direction, the function returns with [`TreeNodeRecursion::Jump`], +/// bypassing further bottom-up closures until the next top-down closure. +/// - For `DOWN` direction, the function returns with [`TreeNodeRecursion::Continue`], +/// skipping further exploration. +/// - If no direction is specified, `Jump` is treated like `Continue`. #[macro_export] -macro_rules! handle_tree_recursion { - ($EXPR:expr) => { +macro_rules! handle_visit_recursion { + // Internal helper macro for handling the `Jump` case based on the direction: + (@handle_jump UP) => { + return Ok(TreeNodeRecursion::Jump) + }; + (@handle_jump DOWN) => { + return Ok(TreeNodeRecursion::Continue) + }; + (@handle_jump) => { + {} // Treat `Jump` like `Continue`, do nothing and continue execution. + }; + + // Main macro logic with variables to handle directionality. + ($EXPR:expr $(, $DIRECTION:ident)?) => { match $EXPR { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children, let - // the recursion continue: - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children: - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + TreeNodeRecursion::Continue => {} + TreeNodeRecursion::Jump => handle_visit_recursion!(@handle_jump $($DIRECTION)?), + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), } }; } -/// Defines a visitable and rewriteable a tree node. This trait is -/// implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as -/// well as expression trees ([`PhysicalExpr`], [`Expr`]) in -/// DataFusion +/// This macro is used to determine continuation during combined transforming +/// traversals. +/// +/// Depending on the [`TreeNodeRecursion`] the bottom-up closure returns, +/// [`Transformed::try_transform_node_with()`] decides recursion continuation +/// and if state propagation is necessary. Then, the same procedure recursively +/// applies to the children of the node in question. +macro_rules! handle_transform_recursion { + ($F_DOWN:expr, $F_SELF:expr, $F_UP:expr) => {{ + let pre_visited = $F_DOWN?; + match pre_visited.tnr { + TreeNodeRecursion::Continue => pre_visited + .data + .map_children($F_SELF)? + .try_transform_node_with($F_UP, TreeNodeRecursion::Jump), + #[allow(clippy::redundant_closure_call)] + TreeNodeRecursion::Jump => $F_UP(pre_visited.data), + TreeNodeRecursion::Stop => return Ok(pre_visited), + } + .map(|mut post_visited| { + post_visited.transformed |= pre_visited.transformed; + post_visited + }) + }}; +} + +/// Defines a visitable and rewriteable tree node. This trait is implemented +/// for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as well as expression +/// trees ([`PhysicalExpr`], [`Expr`]) in DataFusion. /// /// /// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html @@ -52,283 +97,507 @@ macro_rules! handle_tree_recursion { /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html pub trait TreeNode: Sized { - /// Applies `op` to the node and its children. `op` is applied in a preoder way, - /// and it is controlled by [`VisitRecursion`], which means result of the `op` - /// on the self node can cause an early return. + /// Visit the tree node using the given [`TreeNodeVisitor`], performing a + /// depth-first walk of the node and its children. + /// + /// Consider the following tree structure: + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` /// - /// The `op` closure can be used to collect some info from the - /// tree node or do some checking for the tree node. - fn apply Result>( + /// Here, the nodes would be visited using the following order: + /// ```text + /// TreeNodeVisitor::f_down(ParentNode) + /// TreeNodeVisitor::f_down(ChildNode1) + /// TreeNodeVisitor::f_up(ChildNode1) + /// TreeNodeVisitor::f_down(ChildNode2) + /// TreeNodeVisitor::f_up(ChildNode2) + /// TreeNodeVisitor::f_up(ParentNode) + /// ``` + /// + /// See [`TreeNodeRecursion`] for more details on controlling the traversal. + /// + /// If [`TreeNodeVisitor::f_down()`] or [`TreeNodeVisitor::f_up()`] returns [`Err`], + /// the recursion stops immediately. + /// + /// If using the default [`TreeNodeVisitor::f_up`] that does nothing, consider using + /// [`Self::apply`]. + fn visit>( &self, - op: &mut F, - ) -> Result { - handle_tree_recursion!(op(self)?); - self.apply_children(&mut |node| node.apply(op)) + visitor: &mut V, + ) -> Result { + match visitor.f_down(self)? { + TreeNodeRecursion::Continue => { + handle_visit_recursion!( + self.apply_children(&mut |n| n.visit(visitor))?, + UP + ); + visitor.f_up(self) + } + TreeNodeRecursion::Jump => visitor.f_up(self), + TreeNodeRecursion::Stop => Ok(TreeNodeRecursion::Stop), + } } - /// Visit the tree node using the given [TreeNodeVisitor] - /// It performs a depth first walk of an node and its children. + /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for + /// recursively transforming [`TreeNode`]s. /// - /// For an node tree such as + /// Consider the following tree structure: /// ```text /// ParentNode /// left: ChildNode1 /// right: ChildNode2 /// ``` /// - /// The nodes are visited using the following order + /// Here, the nodes would be visited using the following order: /// ```text - /// pre_visit(ParentNode) - /// pre_visit(ChildNode1) - /// post_visit(ChildNode1) - /// pre_visit(ChildNode2) - /// post_visit(ChildNode2) - /// post_visit(ParentNode) + /// TreeNodeRewriter::f_down(ParentNode) + /// TreeNodeRewriter::f_down(ChildNode1) + /// TreeNodeRewriter::f_up(ChildNode1) + /// TreeNodeRewriter::f_down(ChildNode2) + /// TreeNodeRewriter::f_up(ChildNode2) + /// TreeNodeRewriter::f_up(ParentNode) /// ``` /// - /// If an Err result is returned, recursion is stopped immediately + /// See [`TreeNodeRecursion`] for more details on controlling the traversal. /// - /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is post_visit - /// called on that node. Details see [`TreeNodeVisitor`] + /// If [`TreeNodeVisitor::f_down()`] or [`TreeNodeVisitor::f_up()`] returns [`Err`], + /// the recursion stops immediately. + fn rewrite>( + self, + rewriter: &mut R, + ) -> Result> { + handle_transform_recursion!(rewriter.f_down(self), |c| c.rewrite(rewriter), |n| { + rewriter.f_up(n) + }) + } + + /// Applies `f` to the node and its children. `f` is applied in a pre-order + /// way, and it is controlled by [`TreeNodeRecursion`], which means result + /// of the `f` on a node can cause an early return. /// - /// If using the default [`TreeNodeVisitor::post_visit`] that does - /// nothing, [`Self::apply`] should be preferred. - fn visit>( + /// The `f` closure can be used to collect some information from tree nodes + /// or run a check on the tree. + fn apply Result>( &self, - visitor: &mut V, - ) -> Result { - handle_tree_recursion!(visitor.pre_visit(self)?); - handle_tree_recursion!(self.apply_children(&mut |node| node.visit(visitor))?); - visitor.post_visit(self) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. - /// When `op` does not apply to a given node, it is left unchanged. - /// The default tree traversal direction is transform_up(Postorder Traversal). - fn transform(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - self.transform_up(op) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its - /// children(Preorder Traversal). - /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_down(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - let after_op = op(self)?.into(); - after_op.map_children(|node| node.transform_down(op)) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its - /// children(Preorder Traversal) using a mutable function, `F`. - /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_down_mut(self, op: &mut F) -> Result - where - F: FnMut(Self) -> Result>, - { - let after_op = op(self)?.into(); - after_op.map_children(|node| node.transform_down_mut(op)) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its - /// children and then itself(Postorder Traversal). - /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_up(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - let after_op_children = self.map_children(|node| node.transform_up(op))?; - let new_node = op(after_op_children)?.into(); - Ok(new_node) - } - - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its - /// children and then itself(Postorder Traversal) using a mutable function, `F`. - /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_up_mut(self, op: &mut F) -> Result - where - F: FnMut(Self) -> Result>, - { - let after_op_children = self.map_children(|node| node.transform_up_mut(op))?; - let new_node = op(after_op_children)?.into(); - Ok(new_node) - } - - /// Transform the tree node using the given [TreeNodeRewriter] - /// It performs a depth first walk of an node and its children. + f: &mut F, + ) -> Result { + handle_visit_recursion!(f(self)?, DOWN); + self.apply_children(&mut |n| n.apply(f)) + } + + /// Convenience utility for writing optimizer rules: Recursively apply the + /// given function `f` to the tree in a bottom-up (post-order) fashion. When + /// `f` does not apply to a given node, it is left unchanged. + fn transform Result>>( + self, + f: &F, + ) -> Result> { + self.transform_up(f) + } + + /// Convenience utility for writing optimizer rules: Recursively apply the + /// given function `f` to a node and then to its children (pre-order traversal). + /// When `f` does not apply to a given node, it is left unchanged. + fn transform_down Result>>( + self, + f: &F, + ) -> Result> { + f(self)?.try_transform_node_with( + |n| n.map_children(|c| c.transform_down(f)), + TreeNodeRecursion::Continue, + ) + } + + /// Convenience utility for writing optimizer rules: Recursively apply the + /// given mutable function `f` to a node and then to its children (pre-order + /// traversal). When `f` does not apply to a given node, it is left unchanged. + fn transform_down_mut Result>>( + self, + f: &mut F, + ) -> Result> { + f(self)?.try_transform_node_with( + |n| n.map_children(|c| c.transform_down_mut(f)), + TreeNodeRecursion::Continue, + ) + } + + /// Convenience utility for writing optimizer rules: Recursively apply the + /// given function `f` to all children of a node, and then to the node itself + /// (post-order traversal). When `f` does not apply to a given node, it is + /// left unchanged. + fn transform_up Result>>( + self, + f: &F, + ) -> Result> { + self.map_children(|c| c.transform_up(f))? + .try_transform_node_with(f, TreeNodeRecursion::Jump) + } + + /// Convenience utility for writing optimizer rules: Recursively apply the + /// given mutable function `f` to all children of a node, and then to the + /// node itself (post-order traversal). When `f` does not apply to a given + /// node, it is left unchanged. + fn transform_up_mut Result>>( + self, + f: &mut F, + ) -> Result> { + self.map_children(|c| c.transform_up_mut(f))? + .try_transform_node_with(f, TreeNodeRecursion::Jump) + } + + /// Transforms the tree using `f_down` while traversing the tree top-down + /// (pre-order), and using `f_up` while traversing the tree bottom-up + /// (post-order). + /// + /// Use this method if you want to start the `f_up` process right where `f_down` jumps. + /// This can make the whole process faster by reducing the number of `f_up` steps. + /// If you don't need this, it's just like using `transform_down_mut` followed by + /// `transform_up_mut` on the same tree. /// - /// For an node tree such as + /// Consider the following tree structure: /// ```text /// ParentNode /// left: ChildNode1 /// right: ChildNode2 /// ``` /// - /// The nodes are visited using the following order + /// The nodes are visited using the following order: /// ```text - /// pre_visit(ParentNode) - /// pre_visit(ChildNode1) - /// mutate(ChildNode1) - /// pre_visit(ChildNode2) - /// mutate(ChildNode2) - /// mutate(ParentNode) + /// f_down(ParentNode) + /// f_down(ChildNode1) + /// f_up(ChildNode1) + /// f_down(ChildNode2) + /// f_up(ChildNode2) + /// f_up(ParentNode) /// ``` /// - /// If an Err result is returned, recursion is stopped immediately + /// See [`TreeNodeRecursion`] for more details on controlling the traversal. /// - /// If [`false`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is mutate - /// called on that node + /// If `f_down` or `f_up` returns [`Err`], the recursion stops immediately. /// - /// If using the default [`TreeNodeRewriter::pre_visit`] which - /// returns `true`, [`Self::transform`] should be preferred. - fn rewrite>(self, rewriter: &mut R) -> Result { - let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, - }; - - let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; - - // now rewrite this node itself - if need_mutate { - rewriter.mutate(after_op_children) - } else { - Ok(after_op_children) - } + /// Example: + /// ```text + /// | +---+ + /// | | J | + /// | +---+ + /// | | + /// | +---+ + /// TreeNodeRecursion::Continue | | I | + /// | +---+ + /// | | + /// | +---+ + /// \|/ | F | + /// ' +---+ + /// / \ ___________________ + /// When `f_down` is +---+ \ ---+ + /// applied on node "E", | E | | G | + /// it returns with "Jump". +---+ +---+ + /// | | + /// +---+ +---+ + /// | C | | H | + /// +---+ +---+ + /// / \ + /// +---+ +---+ + /// | B | | D | + /// +---+ +---+ + /// | + /// +---+ + /// | A | + /// +---+ + /// + /// Instead of starting from leaf nodes, `f_up` starts from the node "E". + /// +---+ + /// | | J | + /// | +---+ + /// | | + /// | +---+ + /// | | I | + /// | +---+ + /// | | + /// / +---+ + /// / | F | + /// / +---+ + /// / / \ ______________________ + /// | +---+ . \ ---+ + /// | | E | /|\ After `f_down` jumps | G | + /// | +---+ | on node E, `f_up` +---+ + /// \------| ---/ if applied on node E. | + /// +---+ +---+ + /// | C | | H | + /// +---+ +---+ + /// / \ + /// +---+ +---+ + /// | B | | D | + /// +---+ +---+ + /// | + /// +---+ + /// | A | + /// +---+ + /// ``` + fn transform_down_up< + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, + >( + self, + f_down: &mut FD, + f_up: &mut FU, + ) -> Result> { + handle_transform_recursion!( + f_down(self), + |c| c.transform_down_up(f_down, f_up), + f_up + ) } - /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result; + /// Apply the closure `F` to the node's children. + fn apply_children Result>( + &self, + f: &mut F, + ) -> Result; - /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result; + /// Apply transform `F` to the node's children. Note that the transform `F` + /// might have a direction (pre-order or post-order). + fn map_children Result>>( + self, + f: F, + ) -> Result>; } -/// Implements the [visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s. -/// -/// [`TreeNodeVisitor`] allows keeping the algorithms -/// separate from the code to traverse the structure of the `TreeNode` -/// tree and makes it easier to add new types of tree node and -/// algorithms. -/// -/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::pre_visit`] -/// and [`TreeNodeVisitor::post_visit`] are invoked recursively -/// on an node tree. +/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) +/// for recursively walking [`TreeNode`]s. /// -/// If an [`Err`] result is returned, recursion is stopped -/// immediately. +/// A [`TreeNodeVisitor`] allows one to express algorithms separately from the +/// code traversing the structure of the `TreeNode` tree, making it easier to +/// add new types of tree nodes and algorithms. /// -/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no -/// children of that tree node are visited, nor is post_visit -/// called on that tree node -/// -/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no -/// siblings of that tree node are visited, nor is post_visit -/// called on its parent tree node -/// -/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no -/// children of that tree node are visited. +/// When passed to [`TreeNode::visit`], [`TreeNodeVisitor::f_down`] and +/// [`TreeNodeVisitor::f_up`] are invoked recursively on the tree. +/// See [`TreeNodeRecursion`] for more details on controlling the traversal. pub trait TreeNodeVisitor: Sized { /// The node type which is visitable. - type N: TreeNode; + type Node: TreeNode; /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; + /// Default implementation simply continues the recursion. + fn f_down(&mut self, _node: &Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } - /// Invoked after all children of `node` are visited. Default - /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result { - Ok(VisitRecursion::Continue) + /// Invoked after all children of `node` are visited. + /// Default implementation simply continues the recursion. + fn f_up(&mut self, _node: &Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) } } -/// Trait for potentially recursively transform an [`TreeNode`] node -/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is -/// invoked recursively on all nodes of a tree. +/// Trait for potentially recursively transforming a tree of [`TreeNode`]s. pub trait TreeNodeRewriter: Sized { /// The node type which is rewritable. - type N: TreeNode; + type Node: TreeNode; - /// Invoked before (Preorder) any children of `node` are rewritten / - /// visited. Default implementation returns `Ok(Recursion::Continue)` - fn pre_visit(&mut self, _node: &Self::N) -> Result { - Ok(RewriteRecursion::Continue) + /// Invoked while traversing down the tree before any children are rewritten. + /// Default implementation returns the node as is and continues recursion. + fn f_down(&mut self, node: Self::Node) -> Result> { + Ok(Transformed::no(node)) } - /// Invoked after (Postorder) all children of `node` have been mutated and - /// returns a potentially modified node. - fn mutate(&mut self, node: Self::N) -> Result; -} - -/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::rewrite`]. -#[derive(Debug)] -pub enum RewriteRecursion { - /// Continue rewrite this node tree. - Continue, - /// Call 'op' immediately and return. - Mutate, - /// Do not rewrite the children of this node. - Stop, - /// Keep recursive but skip apply op on this node - Skip, + /// Invoked while traversing up the tree after all children have been rewritten. + /// Default implementation returns the node as is and continues recursion. + fn f_up(&mut self, node: Self::Node) -> Result> { + Ok(Transformed::no(node)) + } } -/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit`]. -#[derive(Debug)] -pub enum VisitRecursion { - /// Continue the visit to this node tree. +/// Controls how [`TreeNode`] recursions should proceed. +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum TreeNodeRecursion { + /// Continue recursion with the next node. Continue, - /// Keep recursive but skip applying op on the children - Skip, - /// Stop the visit to this node tree. + /// In top-down traversals, skip recursing into children but continue with + /// the next node, which actually means pruning of the subtree. + /// + /// In bottom-up traversals, bypass calling bottom-up closures till the next + /// leaf node. + /// + /// In combined traversals, if it is the `f_down` (pre-order) phase, execution + /// "jumps" to the next `f_up` (post-order) phase by shortcutting its children. + /// If it is the `f_up` (post-order) phase, execution "jumps" to the next `f_down` + /// (pre-order) phase by shortcutting its parent nodes until the first parent node + /// having unvisited children path. + Jump, + /// Stop recursion. Stop, } -pub enum Transformed { - /// The item was transformed / rewritten somehow - Yes(T), - /// The item was not transformed - No(T), +/// This struct is used by tree transformation APIs such as +/// - [`TreeNode::rewrite`], +/// - [`TreeNode::transform_down`], +/// - [`TreeNode::transform_down_mut`], +/// - [`TreeNode::transform_up`], +/// - [`TreeNode::transform_up_mut`], +/// - [`TreeNode::transform_down_up`] +/// +/// to control the transformation and return the transformed result. +/// +/// Specifically, API users can provide transformation closures or [`TreeNodeRewriter`] +/// implementations to control the transformation by returning: +/// - The resulting (possibly transformed) node, +/// - A flag indicating whether any change was made to the node, and +/// - A flag specifying how to proceed with the recursion. +/// +/// At the end of the transformation, the return value will contain: +/// - The final (possibly transformed) tree, +/// - A flag indicating whether any change was made to the tree, and +/// - A flag specifying how the recursion ended. +#[derive(PartialEq, Debug)] +pub struct Transformed { + pub data: T, + pub transformed: bool, + pub tnr: TreeNodeRecursion, } impl Transformed { - pub fn into(self) -> T { - match self { - Transformed::Yes(t) => t, - Transformed::No(t) => t, + /// Create a new `Transformed` object with the given information. + pub fn new(data: T, transformed: bool, tnr: TreeNodeRecursion) -> Self { + Self { + data, + transformed, + tnr, } } - pub fn into_pair(self) -> (T, bool) { - match self { - Transformed::Yes(t) => (t, true), - Transformed::No(t) => (t, false), + /// Wrapper for transformed data with [`TreeNodeRecursion::Continue`] statement. + pub fn yes(data: T) -> Self { + Self::new(data, true, TreeNodeRecursion::Continue) + } + + /// Wrapper for unchanged data with [`TreeNodeRecursion::Continue`] statement. + pub fn no(data: T) -> Self { + Self::new(data, false, TreeNodeRecursion::Continue) + } + + /// Applies the given `f` to the data of this [`Transformed`] object. + pub fn update_data U>(self, f: F) -> Transformed { + Transformed::new(f(self.data), self.transformed, self.tnr) + } + + /// Maps the data of [`Transformed`] object to the result of the given `f`. + pub fn map_data Result>(self, f: F) -> Result> { + f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr)) + } + + /// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`] + /// is straightforward, but [`TreeNodeRecursion::Jump`] can behave differently + /// when we are traversing down or up on a tree. If [`TreeNodeRecursion`] of + /// the node is [`TreeNodeRecursion::Jump`], recursion stops with the given + /// `return_if_jump` value. + fn try_transform_node_with Result>>( + mut self, + f: F, + return_if_jump: TreeNodeRecursion, + ) -> Result> { + match self.tnr { + TreeNodeRecursion::Continue => { + return f(self.data).map(|mut t| { + t.transformed |= self.transformed; + t + }); + } + TreeNodeRecursion::Jump => { + self.tnr = return_if_jump; + } + TreeNodeRecursion::Stop => {} + } + Ok(self) + } + + /// If [`TreeNodeRecursion`] of the node is [`TreeNodeRecursion::Continue`] or + /// [`TreeNodeRecursion::Jump`], transformation is applied to the node. + /// Otherwise, it remains as it is. + pub fn try_transform_node Result>>( + self, + f: F, + ) -> Result> { + if self.tnr == TreeNodeRecursion::Stop { + Ok(self) + } else { + f(self.data).map(|mut t| { + t.transformed |= self.transformed; + t + }) } } } -/// Helper trait for implementing [`TreeNode`] that have children stored as Arc's -/// -/// If some trait object, such as `dyn T`, implements this trait, -/// its related `Arc` will automatically implement [`TreeNode`] +/// Transformation helper to process tree nodes that are siblings. +pub trait TransformedIterator: Iterator { + fn map_until_stop_and_collect< + F: FnMut(Self::Item) -> Result>, + >( + self, + f: F, + ) -> Result>>; +} + +impl TransformedIterator for I { + fn map_until_stop_and_collect< + F: FnMut(Self::Item) -> Result>, + >( + self, + mut f: F, + ) -> Result>> { + let mut tnr = TreeNodeRecursion::Continue; + let mut transformed = false; + let data = self + .map(|item| match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + f(item).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + result.data + }) + } + TreeNodeRecursion::Stop => Ok(item), + }) + .collect::>>()?; + Ok(Transformed::new(data, transformed, tnr)) + } +} + +/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. +pub trait TransformedResult { + fn data(self) -> Result; + + fn transformed(self) -> Result; + + fn tnr(self) -> Result; +} + +impl TransformedResult for Result> { + fn data(self) -> Result { + self.map(|t| t.data) + } + + fn transformed(self) -> Result { + self.map(|t| t.transformed) + } + + fn tnr(self) -> Result { + self.map(|t| t.tnr) + } +} + +/// Helper trait for implementing [`TreeNode`] that have children stored as +/// `Arc`s. If some trait object, such as `dyn T`, implements this trait, +/// its related `Arc` will automatically implement [`TreeNode`]. pub trait DynTreeNode { - /// Returns all children of the specified TreeNode + /// Returns all children of the specified `TreeNode`. fn arc_children(&self) -> Vec>; - /// construct a new self with the specified children + /// Constructs a new node with the specified children. fn with_new_arc_children( &self, arc_self: Arc, @@ -336,32 +605,40 @@ pub trait DynTreeNode { ) -> Result>; } -/// Blanket implementation for Arc for any tye that implements -/// [`DynTreeNode`] (such as [`Arc`]) +/// Blanket implementation for any `Arc` where `T` implements [`DynTreeNode`] +/// (such as [`Arc`]). impl TreeNode for Arc { - /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { + fn apply_children Result>( + &self, + f: &mut F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; for child in self.arc_children() { - handle_tree_recursion!(op(&child)?) + tnr = f(&child)?; + handle_visit_recursion!(tnr) } - Ok(VisitRecursion::Continue) + Ok(tnr) } - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { + fn map_children Result>>( + self, + f: F, + ) -> Result> { let children = self.arc_children(); if !children.is_empty() { - let new_children = - children.into_iter().map(transform).collect::>()?; - let arc_self = Arc::clone(&self); - self.with_new_arc_children(arc_self, new_children) + let new_children = children.into_iter().map_until_stop_and_collect(f)?; + // Propagate up `new_children.transformed` and `new_children.tnr` + // along with the node containing transformed children. + if new_children.transformed { + let arc_self = Arc::clone(&self); + new_children.map_data(|new_children| { + self.with_new_arc_children(arc_self, new_children) + }) + } else { + Ok(Transformed::new(self, false, new_children.tnr)) + } } else { - Ok(self) + Ok(Transformed::no(self)) } } } @@ -381,28 +658,1016 @@ pub trait ConcreteTreeNode: Sized { } impl TreeNode for T { - /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { + fn apply_children Result>( + &self, + f: &mut F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; for child in self.children() { - handle_tree_recursion!(op(child)?) + tnr = f(child)?; + handle_visit_recursion!(tnr) } - Ok(VisitRecursion::Continue) + Ok(tnr) } - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { + fn map_children Result>>( + self, + f: F, + ) -> Result> { let (new_self, children) = self.take_children(); if !children.is_empty() { - let new_children = - children.into_iter().map(transform).collect::>()?; - new_self.with_new_children(new_children) + let new_children = children.into_iter().map_until_stop_and_collect(f)?; + // Propagate up `new_children.transformed` and `new_children.tnr` along with + // the node containing transformed children. + new_children.map_data(|new_children| new_self.with_new_children(new_children)) } else { - Ok(new_self) + Ok(Transformed::no(new_self)) } } } + +#[cfg(test)] +mod tests { + use std::fmt::Display; + + use crate::tree_node::{ + Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + TreeNodeVisitor, + }; + use crate::Result; + + #[derive(PartialEq, Debug)] + struct TestTreeNode { + children: Vec>, + data: T, + } + + impl TestTreeNode { + fn new(children: Vec>, data: T) -> Self { + Self { children, data } + } + } + + impl TreeNode for TestTreeNode { + fn apply_children(&self, f: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + let mut tnr = TreeNodeRecursion::Continue; + for child in &self.children { + tnr = f(child)?; + handle_visit_recursion!(tnr); + } + Ok(tnr) + } + + fn map_children(self, f: F) -> Result> + where + F: FnMut(Self) -> Result>, + { + Ok(self + .children + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|new_children| Self { + children: new_children, + ..self + })) + } + } + + // J + // | + // I + // | + // F + // / \ + // E G + // | | + // C H + // / \ + // B D + // | + // A + fn test_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); + TestTreeNode::new(vec![node_i], "j".to_string()) + } + + // Continue on all nodes + // Expected visits in a combined traversal + fn all_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_up(d)", + "f_up(c)", + "f_up(e)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + // Expected transformed tree after a combined traversal + fn transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); + let node_c = + TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); + let node_f = + TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string()); + TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string()) + } + + // Expected transformed tree after a top-down traversal + fn transformed_down_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + // Expected transformed tree after a bottom-up traversal + fn transformed_up_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_up(j)".to_string()) + } + + // f_down Jump on A node + fn f_down_jump_on_a_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_up(d)", + "f_up(c)", + "f_up(e)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + // f_down Jump on E node + fn f_down_jump_on_e_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_up(e)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_down_jump_on_e_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); + let node_f = + TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string()); + TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string()) + } + + fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + // f_up Jump on A node + fn f_up_jump_on_a_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_up_jump_on_a_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); + let node_f = + TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string()); + TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string()) + } + + fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_up(j)".to_string()) + } + + // f_up Jump on E node + fn f_up_jump_on_e_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_up(d)", + "f_up(c)", + "f_up(e)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_up_jump_on_e_transformed_tree() -> TestTreeNode { + transformed_tree() + } + + fn f_up_jump_on_e_transformed_up_tree() -> TestTreeNode { + transformed_up_tree() + } + + // f_down Stop on A node + + fn f_down_stop_on_a_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_down_stop_on_a_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + // f_down Stop on E node + fn f_down_stop_on_e_visits() -> Vec { + vec!["f_down(j)", "f_down(i)", "f_down(f)", "f_down(e)"] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_down_stop_on_e_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + // f_up Stop on A node + fn f_up_stop_on_a_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_up_stop_on_a_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); + TestTreeNode::new(vec![node_i], "j".to_string()) + } + + // f_up Stop on E node + fn f_up_stop_on_e_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_up(d)", + "f_up(c)", + "f_up(e)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_up_stop_on_e_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); + let node_c = + TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); + TestTreeNode::new(vec![node_i], "j".to_string()) + } + + fn down_visits(visits: Vec) -> Vec { + visits + .into_iter() + .filter(|v| v.starts_with("f_down")) + .collect() + } + + type TestVisitorF = Box) -> Result>; + + struct TestVisitor { + visits: Vec, + f_down: TestVisitorF, + f_up: TestVisitorF, + } + + impl TestVisitor { + fn new(f_down: TestVisitorF, f_up: TestVisitorF) -> Self { + Self { + visits: vec![], + f_down, + f_up, + } + } + } + + impl TreeNodeVisitor for TestVisitor { + type Node = TestTreeNode; + + fn f_down(&mut self, node: &Self::Node) -> Result { + self.visits.push(format!("f_down({})", node.data)); + (*self.f_down)(node) + } + + fn f_up(&mut self, node: &Self::Node) -> Result { + self.visits.push(format!("f_up({})", node.data)); + (*self.f_up)(node) + } + } + + fn visit_continue(_: &TestTreeNode) -> Result { + Ok(TreeNodeRecursion::Continue) + } + + fn visit_event_on>( + data: D, + event: TreeNodeRecursion, + ) -> impl FnMut(&TestTreeNode) -> Result { + let d = data.into(); + move |node| { + Ok(if node.data == d { + event + } else { + TreeNodeRecursion::Continue + }) + } + } + + macro_rules! visit_test { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_VISITS:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + let mut visitor = TestVisitor::new(Box::new($F_DOWN), Box::new($F_UP)); + tree.visit(&mut visitor)?; + assert_eq!(visitor.visits, $EXPECTED_VISITS); + + Ok(()) + } + }; + } + + macro_rules! test_apply { + ($NAME:ident, $F:expr, $EXPECTED_VISITS:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + let mut visits = vec![]; + tree.apply(&mut |node| { + visits.push(format!("f_down({})", node.data)); + $F(node) + })?; + assert_eq!(visits, $EXPECTED_VISITS); + + Ok(()) + } + }; + } + + type TestRewriterF = + Box) -> Result>>>; + + struct TestRewriter { + f_down: TestRewriterF, + f_up: TestRewriterF, + } + + impl TestRewriter { + fn new(f_down: TestRewriterF, f_up: TestRewriterF) -> Self { + Self { f_down, f_up } + } + } + + impl TreeNodeRewriter for TestRewriter { + type Node = TestTreeNode; + + fn f_down(&mut self, node: Self::Node) -> Result> { + (*self.f_down)(node) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + (*self.f_up)(node) + } + } + + fn transform_yes>( + transformation_name: N, + ) -> impl FnMut(TestTreeNode) -> Result>> { + move |node| { + Ok(Transformed::yes(TestTreeNode::new( + node.children, + format!("{}({})", transformation_name, node.data).into(), + ))) + } + } + + fn transform_and_event_on< + N: Display, + T: PartialEq + Display + From, + D: Into, + >( + transformation_name: N, + data: D, + event: TreeNodeRecursion, + ) -> impl FnMut(TestTreeNode) -> Result>> { + let d = data.into(); + move |node| { + let new_node = TestTreeNode::new( + node.children, + format!("{}({})", transformation_name, node.data).into(), + ); + Ok(if node.data == d { + Transformed::new(new_node, true, event) + } else { + Transformed::yes(new_node) + }) + } + } + + macro_rules! rewrite_test { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + let mut rewriter = TestRewriter::new(Box::new($F_DOWN), Box::new($F_UP)); + assert_eq!(tree.rewrite(&mut rewriter)?, $EXPECTED_TREE); + + Ok(()) + } + }; + } + + macro_rules! transform_test { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + assert_eq!( + tree.transform_down_up(&mut $F_DOWN, &mut $F_UP,)?, + $EXPECTED_TREE + ); + + Ok(()) + } + }; + } + + macro_rules! transform_down_test { + ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + assert_eq!(tree.transform_down_mut(&mut $F)?, $EXPECTED_TREE); + + Ok(()) + } + }; + } + + macro_rules! transform_up_test { + ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + assert_eq!(tree.transform_up_mut(&mut $F)?, $EXPECTED_TREE); + + Ok(()) + } + }; + } + + visit_test!(test_visit, visit_continue, visit_continue, all_visits()); + visit_test!( + test_visit_f_down_jump_on_a, + visit_event_on("a", TreeNodeRecursion::Jump), + visit_continue, + f_down_jump_on_a_visits() + ); + visit_test!( + test_visit_f_down_jump_on_e, + visit_event_on("e", TreeNodeRecursion::Jump), + visit_continue, + f_down_jump_on_e_visits() + ); + visit_test!( + test_visit_f_up_jump_on_a, + visit_continue, + visit_event_on("a", TreeNodeRecursion::Jump), + f_up_jump_on_a_visits() + ); + visit_test!( + test_visit_f_up_jump_on_e, + visit_continue, + visit_event_on("e", TreeNodeRecursion::Jump), + f_up_jump_on_e_visits() + ); + visit_test!( + test_visit_f_down_stop_on_a, + visit_event_on("a", TreeNodeRecursion::Stop), + visit_continue, + f_down_stop_on_a_visits() + ); + visit_test!( + test_visit_f_down_stop_on_e, + visit_event_on("e", TreeNodeRecursion::Stop), + visit_continue, + f_down_stop_on_e_visits() + ); + visit_test!( + test_visit_f_up_stop_on_a, + visit_continue, + visit_event_on("a", TreeNodeRecursion::Stop), + f_up_stop_on_a_visits() + ); + visit_test!( + test_visit_f_up_stop_on_e, + visit_continue, + visit_event_on("e", TreeNodeRecursion::Stop), + f_up_stop_on_e_visits() + ); + + test_apply!(test_apply, visit_continue, down_visits(all_visits())); + test_apply!( + test_apply_f_down_jump_on_a, + visit_event_on("a", TreeNodeRecursion::Jump), + down_visits(f_down_jump_on_a_visits()) + ); + test_apply!( + test_apply_f_down_jump_on_e, + visit_event_on("e", TreeNodeRecursion::Jump), + down_visits(f_down_jump_on_e_visits()) + ); + test_apply!( + test_apply_f_down_stop_on_a, + visit_event_on("a", TreeNodeRecursion::Stop), + down_visits(f_down_stop_on_a_visits()) + ); + test_apply!( + test_apply_f_down_stop_on_e, + visit_event_on("e", TreeNodeRecursion::Stop), + down_visits(f_down_stop_on_e_visits()) + ); + + rewrite_test!( + test_rewrite, + transform_yes("f_down"), + transform_yes("f_up"), + Transformed::yes(transformed_tree()) + ); + rewrite_test!( + test_rewrite_f_down_jump_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump), + transform_yes("f_up"), + Transformed::yes(transformed_tree()) + ); + rewrite_test!( + test_rewrite_f_down_jump_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump), + transform_yes("f_up"), + Transformed::yes(f_down_jump_on_e_transformed_tree()) + ); + rewrite_test!( + test_rewrite_f_up_jump_on_a, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_a_transformed_tree()) + ); + rewrite_test!( + test_rewrite_f_up_jump_on_e, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_e_transformed_tree()) + ); + rewrite_test!( + test_rewrite_f_down_stop_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + rewrite_test!( + test_rewrite_f_down_stop_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + rewrite_test!( + test_rewrite_f_up_stop_on_a, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + rewrite_test!( + test_rewrite_f_up_stop_on_e, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + + transform_test!( + test_transform, + transform_yes("f_down"), + transform_yes("f_up"), + Transformed::yes(transformed_tree()) + ); + transform_test!( + test_transform_f_down_jump_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump), + transform_yes("f_up"), + Transformed::yes(transformed_tree()) + ); + transform_test!( + test_transform_f_down_jump_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump), + transform_yes("f_up"), + Transformed::yes(f_down_jump_on_e_transformed_tree()) + ); + transform_test!( + test_transform_f_up_jump_on_a, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_a_transformed_tree()) + ); + transform_test!( + test_transform_f_up_jump_on_e, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_e_transformed_tree()) + ); + transform_test!( + test_transform_f_down_stop_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_test!( + test_transform_f_down_stop_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_test!( + test_transform_f_up_stop_on_a, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_test!( + test_transform_f_up_stop_on_e, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + + transform_down_test!( + test_transform_down, + transform_yes("f_down"), + Transformed::yes(transformed_down_tree()) + ); + transform_down_test!( + test_transform_down_f_down_jump_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump), + Transformed::yes(f_down_jump_on_a_transformed_down_tree()) + ); + transform_down_test!( + test_transform_down_f_down_jump_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump), + Transformed::yes(f_down_jump_on_e_transformed_down_tree()) + ); + transform_down_test!( + test_transform_down_f_down_stop_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop), + Transformed::new( + f_down_stop_on_a_transformed_down_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_down_test!( + test_transform_down_f_down_stop_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop), + Transformed::new( + f_down_stop_on_e_transformed_down_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + + transform_up_test!( + test_transform_up, + transform_yes("f_up"), + Transformed::yes(transformed_up_tree()) + ); + transform_up_test!( + test_transform_up_f_up_jump_on_a, + transform_and_event_on("f_up", "a", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_a_transformed_up_tree()) + ); + transform_up_test!( + test_transform_up_f_up_jump_on_e, + transform_and_event_on("f_up", "e", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_e_transformed_up_tree()) + ); + transform_up_test!( + test_transform_up_f_up_stop_on_a, + transform_and_event_on("f_up", "a", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_a_transformed_up_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_up_test!( + test_transform_up_f_up_stop_on_e, + transform_and_event_on("f_up", "e", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_e_transformed_up_tree(), + true, + TreeNodeRecursion::Stop + ) + ); +} diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 077356b716b0..eef25792d00a 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -19,29 +19,27 @@ use std::sync::Arc; -use arrow::compute::{and, cast, prep_null_mask_filter}; +use super::PartitionedFile; +use crate::datasource::listing::ListingTableUrl; +use crate::execution::context::SessionState; +use crate::{error::Result, scalar::ScalarValue}; + use arrow::{ - array::{ArrayRef, StringBuilder}, + array::{Array, ArrayRef, AsArray, StringBuilder}, + compute::{and, cast, prep_null_mask_filter}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use arrow_array::cast::AsArray; -use arrow_array::Array; use arrow_schema::Fields; -use futures::stream::FuturesUnordered; -use futures::{stream::BoxStream, StreamExt, TryStreamExt}; -use log::{debug, trace}; - -use crate::{error::Result, scalar::ScalarValue}; - -use super::PartitionedFile; -use crate::datasource::listing::ListingTableUrl; -use crate::execution::context::SessionState; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; + +use futures::stream::{BoxStream, FuturesUnordered}; +use futures::{StreamExt, TryStreamExt}; +use log::{debug, trace}; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; @@ -57,9 +55,9 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - Ok(VisitRecursion::Skip) + Ok(TreeNodeRecursion::Jump) } else { - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } Expr::Literal(_) @@ -88,27 +86,27 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => Ok(VisitRecursion::Continue), + | Expr::Case { .. } => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match &scalar_function.func_def { ScalarFunctionDefinition::BuiltIn(fun) => { match fun.volatility() { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } ScalarFunctionDefinition::UDF(fun) => { match fun.signature().volatility { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } @@ -129,7 +127,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::Unnest { .. } | Expr::Placeholder(_) => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } }) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 3c40509a86d2..c0e37a7150d9 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -15,26 +15,28 @@ // specific language governing permissions and limitations // under the License. +use std::collections::BTreeSet; +use std::sync::Arc; + +use super::ParquetFileMetrics; +use crate::physical_plan::metrics; + use arrow::array::BooleanArray; use arrow::datatypes::{DataType, Schema}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; -use std::collections::BTreeSet; - use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; + use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; use parquet::arrow::ProjectionMask; use parquet::file::metadata::ParquetMetaData; -use std::sync::Arc; - -use crate::physical_plan::metrics; - -use super::ParquetFileMetrics; /// This module contains utilities for enabling the pushdown of DataFusion filter predicates (which /// can be any DataFusion `Expr` that evaluates to a `BooleanArray`) to the parquet decoder level in `arrow-rs`. @@ -188,8 +190,7 @@ impl<'a> FilterCandidateBuilder<'a> { mut self, metadata: &ParquetMetaData, ) -> Result> { - let expr = self.expr.clone(); - let expr = expr.rewrite(&mut self)?; + let expr = self.expr.clone().rewrite(&mut self).data()?; if self.non_primitive_columns || self.projected_columns { Ok(None) @@ -209,29 +210,35 @@ impl<'a> FilterCandidateBuilder<'a> { } impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { - type N = Arc; + type Node = Arc; - fn pre_visit(&mut self, node: &Arc) -> Result { + fn f_down( + &mut self, + node: Arc, + ) -> Result>> { if let Some(column) = node.as_any().downcast_ref::() { if let Ok(idx) = self.file_schema.index_of(column.name()) { self.required_column_indices.insert(idx); if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; - return Ok(RewriteRecursion::Stop); + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); } } else if self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. self.projected_columns = true; - return Ok(RewriteRecursion::Stop); + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); } } - Ok(RewriteRecursion::Continue) + Ok(Transformed::no(node)) } - fn mutate(&mut self, expr: Arc) -> Result> { + fn f_up( + &mut self, + expr: Arc, + ) -> Result>> { if let Some(column) = expr.as_any().downcast_ref::() { if self.file_schema.field_with_name(column.name()).is_err() { // the column expr must be in the table schema @@ -239,7 +246,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { Ok(field) => { // return the null value corresponding to the data type let null_value = ScalarValue::try_from(field.data_type())?; - Ok(Arc::new(Literal::new(null_value))) + Ok(Transformed::yes(Arc::new(Literal::new(null_value)))) } Err(e) => { // If the column is not in the table schema, should throw the error @@ -249,7 +256,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { } } - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 2d964d29688c..2144cd3c7736 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -40,7 +40,7 @@ use arrow_schema::Schema; use datafusion_common::{ alias::AliasGenerator, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, + tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ @@ -2189,9 +2189,9 @@ impl<'a> BadPlanVisitor<'a> { } impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, node: &Self::N) -> Result { + fn f_down(&mut self, node: &Self::Node) -> Result { match node { LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { plan_err!("DDL not supported: {}", ddl.name()) @@ -2205,7 +2205,7 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { LogicalPlan::Statement(stmt) if !self.options.allow_statements => { plan_err!("Statement not supported: {}", stmt.name()) } - _ => Ok(VisitRecursion::Continue), + _ => Ok(TreeNodeRecursion::Continue), } } } diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 4fe11c14a758..df54222270ce 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -27,7 +27,7 @@ use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics use crate::scalar::ScalarValue; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -85,10 +85,14 @@ impl PhysicalOptimizerRule for AggregateStatistics { Arc::new(PlaceholderRowExec::new(plan.schema())), )?)) } else { - plan.map_children(|child| self.optimize(child, _config)) + plan.map_children(|child| { + self.optimize(child, _config).map(Transformed::yes) + }) + .data() } } else { - plan.map_children(|child| self.optimize(child, _config)) + plan.map_children(|child| self.optimize(child, _config).map(Transformed::yes)) + .data() } } diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 7b66ca529094..7c0082037da0 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -18,8 +18,10 @@ //! CoalesceBatches optimizer that groups batches together rows //! in bigger batches to avoid overhead with small batches -use crate::config::ConfigOptions; +use std::sync::Arc; + use crate::{ + config::ConfigOptions, error::Result, physical_optimizer::PhysicalOptimizerRule, physical_plan::{ @@ -27,8 +29,8 @@ use crate::{ repartition::RepartitionExec, Partitioning, }, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use std::sync::Arc; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; /// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that /// are produced by highly selective filters @@ -71,14 +73,15 @@ impl PhysicalOptimizerRule for CoalesceBatches { }) .unwrap_or(false); if wrap_in_coalesce { - Ok(Transformed::Yes(Arc::new(CoalesceBatchesExec::new( + Ok(Transformed::yes(Arc::new(CoalesceBatchesExec::new( plan, target_batch_size, )))) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } }) + .data() } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 61eb2381c63b..c45e14100e82 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -26,7 +26,7 @@ use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGro use crate::physical_plan::ExecutionPlan; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; @@ -109,11 +109,12 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { }); Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(plan) + Transformed::no(plan) }) }) + .data() } fn name(&self) -> &str { @@ -185,11 +186,12 @@ fn discard_column_index(group_expr: Arc) -> Arc None, }; Ok(if let Some(normalized_form) = normalized_form { - Transformed::Yes(normalized_form) + Transformed::yes(normalized_form) } else { - Transformed::No(expr) + Transformed::no(expr) }) }) + .data() .unwrap_or(group_expr) } diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index eb221a28e2cf..822cd0541ae2 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -45,7 +45,7 @@ use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning}; use arrow::compute::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; @@ -197,22 +197,25 @@ impl PhysicalOptimizerRule for EnforceDistribution { let adjusted = if top_down_join_key_reordering { // Run a top-down process to adjust input key ordering recursively let plan_requirements = PlanWithKeyRequirements::new_default(plan); - let adjusted = - plan_requirements.transform_down(&adjust_input_keys_ordering)?; + let adjusted = plan_requirements + .transform_down(&adjust_input_keys_ordering) + .data()?; adjusted.plan } else { // Run a bottom-up process plan.transform_up(&|plan| { - Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) - })? + Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?)) + }) + .data()? }; let distribution_context = DistributionContext::new_default(adjusted); // Distribution enforcement needs to be applied bottom-up. - let distribution_context = - distribution_context.transform_up(&|distribution_context| { + let distribution_context = distribution_context + .transform_up(&|distribution_context| { ensure_distribution(distribution_context, config) - })?; + }) + .data()?; Ok(distribution_context.plan) } @@ -306,7 +309,7 @@ fn adjust_input_keys_ordering( vec![], &join_constructor, ) - .map(Transformed::Yes); + .map(Transformed::yes); } PartitionMode::CollectLeft => { // Push down requirements to the right side @@ -370,18 +373,18 @@ fn adjust_input_keys_ordering( sort_options.clone(), &join_constructor, ) - .map(Transformed::Yes); + .map(Transformed::yes); } else if let Some(aggregate_exec) = plan.as_any().downcast_ref::() { if !requirements.data.is_empty() { if aggregate_exec.mode() == &AggregateMode::FinalPartitioned { return reorder_aggregate_keys(requirements, aggregate_exec) - .map(Transformed::Yes); + .map(Transformed::yes); } else { requirements.data.clear(); } } else { // Keep everything unchanged - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } } else if let Some(proj) = plan.as_any().downcast_ref::() { let expr = proj.expr(); @@ -409,7 +412,7 @@ fn adjust_input_keys_ordering( child.data = requirements.data.clone(); } } - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } fn reorder_partitioned_join_keys( @@ -1057,7 +1060,7 @@ fn ensure_distribution( let dist_context = update_children(dist_context)?; if dist_context.plan.children().is_empty() { - return Ok(Transformed::No(dist_context)); + return Ok(Transformed::no(dist_context)); } let target_partitions = config.execution.target_partitions; @@ -1237,7 +1240,7 @@ fn ensure_distribution( plan.with_new_children(children_plans)? }; - Ok(Transformed::Yes(DistributionContext::new( + Ok(Transformed::yes(DistributionContext::new( plan, data, children, ))) } @@ -1323,6 +1326,7 @@ pub(crate) mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::tree_node::TransformedResult; use datafusion_common::ScalarValue; use datafusion_expr::logical_plan::JoinType; use datafusion_expr::Operator; @@ -1716,7 +1720,7 @@ pub(crate) mod tests { config.optimizer.repartition_file_scans = false; config.optimizer.repartition_file_min_size = 1024; config.optimizer.prefer_existing_sort = prefer_existing_sort; - ensure_distribution(distribution_context, &config).map(|item| item.into().plan) + ensure_distribution(distribution_context, &config).map(|item| item.data.plan) } /// Test whether plan matches with expected plan @@ -1785,14 +1789,16 @@ pub(crate) mod tests { PlanWithKeyRequirements::new_default($PLAN.clone()); let adjusted = plan_requirements .transform_down(&adjust_input_keys_ordering) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. adjusted.plan } else { // Run reorder_join_keys_to_inputs rule $PLAN.clone().transform_up(&|plan| { - Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) - })? + Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?)) + }) + .data()? }; // Then run ensure_distribution rule @@ -1800,6 +1806,7 @@ pub(crate) mod tests { .transform_up(&|distribution_context| { ensure_distribution(distribution_context, &config) }) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. } diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 5fac1397e023..79dd5758cc2f 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -60,7 +60,7 @@ use crate::physical_plan::windows::{ use crate::physical_plan::{Distribution, ExecutionPlan, InputOrderMode}; use datafusion_common::plan_err; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::partial_sort::PartialSortExec; @@ -160,37 +160,40 @@ impl PhysicalOptimizerRule for EnforceSorting { let plan_requirements = PlanWithCorrespondingSort::new_default(plan); // Execute a bottom-up traversal to enforce sorting requirements, // remove unnecessary sorts, and optimize sort-sensitive operators: - let adjusted = plan_requirements.transform_up(&ensure_sorting)?; + let adjusted = plan_requirements.transform_up(&ensure_sorting)?.data; let new_plan = if config.optimizer.repartition_sorts { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); - let parallel = - plan_with_coalesce_partitions.transform_up(¶llelize_sorts)?; + let parallel = plan_with_coalesce_partitions + .transform_up(¶llelize_sorts) + .data()?; parallel.plan } else { adjusted.plan }; let plan_with_pipeline_fixer = OrderPreservationContext::new_default(new_plan); - let updated_plan = - plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { + let updated_plan = plan_with_pipeline_fixer + .transform_up(&|plan_with_pipeline_fixer| { replace_with_order_preserving_variants( plan_with_pipeline_fixer, false, true, config, ) - })?; + }) + .data()?; // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); - let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; + let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?.data; adjusted .plan - .transform_up(&|plan| Ok(Transformed::Yes(replace_with_partial_sort(plan)?))) + .transform_up(&|plan| Ok(Transformed::yes(replace_with_partial_sort(plan)?))) + .data() } fn name(&self) -> &str { @@ -262,7 +265,7 @@ fn parallelize_sorts( // `SortPreservingMergeExec` or a `CoalescePartitionsExec`, and they // all have a single child. Therefore, if the first child has no // connection, we can return immediately. - Ok(Transformed::No(requirements)) + Ok(Transformed::no(requirements)) } else if (is_sort(&requirements.plan) || is_sort_preserving_merge(&requirements.plan)) && requirements.plan.output_partitioning().partition_count() <= 1 @@ -291,7 +294,7 @@ fn parallelize_sorts( } let spm = SortPreservingMergeExec::new(sort_exprs, requirements.plan.clone()); - Ok(Transformed::Yes( + Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( Arc::new(spm.with_fetch(fetch)), false, @@ -305,7 +308,7 @@ fn parallelize_sorts( // For the removal of self node which is also a `CoalescePartitionsExec`. requirements = requirements.children.swap_remove(0); - Ok(Transformed::Yes( + Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( Arc::new(CoalescePartitionsExec::new(requirements.plan.clone())), false, @@ -313,7 +316,7 @@ fn parallelize_sorts( ), )) } else { - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } } @@ -326,10 +329,12 @@ fn ensure_sorting( // Perform naive analysis at the beginning -- remove already-satisfied sorts: if requirements.children.is_empty() { - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } let maybe_requirements = analyze_immediate_sort_removal(requirements); - let Transformed::No(mut requirements) = maybe_requirements else { + requirements = if !maybe_requirements.transformed { + maybe_requirements.data + } else { return Ok(maybe_requirements); }; @@ -368,17 +373,17 @@ fn ensure_sorting( // calculate the result in reverse: let child_node = &requirements.children[0]; if is_window(plan) && child_node.data { - return adjust_window_sort_removal(requirements).map(Transformed::Yes); + return adjust_window_sort_removal(requirements).map(Transformed::yes); } else if is_sort_preserving_merge(plan) && child_node.plan.output_partitioning().partition_count() <= 1 { // This `SortPreservingMergeExec` is unnecessary, input already has a // single partition. let child_node = requirements.children.swap_remove(0); - return Ok(Transformed::Yes(child_node)); + return Ok(Transformed::yes(child_node)); } - update_sort_ctx_children(requirements, false).map(Transformed::Yes) + update_sort_ctx_children(requirements, false).map(Transformed::yes) } /// Analyzes a given [`SortExec`] (`plan`) to determine whether its input @@ -408,10 +413,10 @@ fn analyze_immediate_sort_removal( child.data = false; } node.data = false; - return Transformed::Yes(node); + return Transformed::yes(node); } } - Transformed::No(node) + Transformed::no(node) } /// Adjusts a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine @@ -683,6 +688,7 @@ mod tests { let plan_requirements = PlanWithCorrespondingSort::new_default($PLAN.clone()); let adjusted = plan_requirements .transform_up(&ensure_sorting) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. @@ -691,6 +697,7 @@ mod tests { PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); let parallel = plan_with_coalesce_partitions .transform_up(¶llelize_sorts) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. parallel.plan @@ -708,6 +715,7 @@ mod tests { state.config_options(), ) }) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. @@ -715,6 +723,7 @@ mod tests { assign_initial_requirements(&mut sort_pushdown); sort_pushdown .transform_down(&pushdown_sorts) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. } diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index cc629df73120..20104285e44a 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -37,7 +37,7 @@ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use arrow_schema::Schema; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, JoinSide, JoinType}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::sort_properties::SortProperties; @@ -236,7 +236,9 @@ impl PhysicalOptimizerRule for JoinSelection { Box::new(hash_join_convert_symmetric_subrule), Box::new(hash_join_swap_subrule), ]; - let new_plan = plan.transform_up(&|p| apply_subrules(p, &subrules, config))?; + let new_plan = plan + .transform_up(&|p| apply_subrules(p, &subrules, config)) + .data()?; // Next, we apply another subrule that tries to optimize joins using any // statistics their inputs might have. // - For a hash join with partition mode [`PartitionMode::Auto`], we will @@ -251,13 +253,15 @@ impl PhysicalOptimizerRule for JoinSelection { let config = &config.optimizer; let collect_threshold_byte_size = config.hash_join_single_partition_threshold; let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows; - new_plan.transform_up(&|plan| { - statistical_join_selection_subrule( - plan, - collect_threshold_byte_size, - collect_threshold_num_rows, - ) - }) + new_plan + .transform_up(&|plan| { + statistical_join_selection_subrule( + plan, + collect_threshold_byte_size, + collect_threshold_num_rows, + ) + }) + .data() } fn name(&self) -> &str { @@ -433,9 +437,9 @@ fn statistical_join_selection_subrule( }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(plan) + Transformed::no(plan) }) } @@ -647,7 +651,7 @@ fn apply_subrules( for subrule in subrules { input = subrule(input, config_options)?; } - Ok(Transformed::Yes(input)) + Ok(Transformed::yes(input)) } #[cfg(test)] @@ -808,8 +812,9 @@ mod tests_statistical { Box::new(hash_join_convert_symmetric_subrule), Box::new(hash_join_swap_subrule), ]; - let new_plan = - plan.transform_up(&|p| apply_subrules(p, &subrules, &ConfigOptions::new()))?; + let new_plan = plan + .transform_up(&|p| apply_subrules(p, &subrules, &ConfigOptions::new())) + .data()?; // TODO: End state payloads will be checked here. let config = ConfigOptions::new().optimizer; let collect_left_threshold = config.hash_join_single_partition_threshold; diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 7be9acec5092..9509d4e4c828 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -26,7 +26,7 @@ use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use itertools::Itertools; @@ -109,7 +109,7 @@ impl LimitedDistinctAggregation { let mut rewrite_applicable = true; let mut closure = |plan: Arc| { if !rewrite_applicable { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } if let Some(aggr) = plan.as_any().downcast_ref::() { if found_match_aggr { @@ -120,7 +120,7 @@ impl LimitedDistinctAggregation { // a partial and final aggregation with different groupings disqualifies // rewriting the child aggregation rewrite_applicable = false; - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } } } @@ -131,14 +131,14 @@ impl LimitedDistinctAggregation { Some(new_aggr) => { match_aggr = plan; found_match_aggr = true; - return Ok(Transformed::Yes(new_aggr)); + return Ok(Transformed::yes(new_aggr)); } } } rewrite_applicable = false; - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; - let child = child.clone().transform_down_mut(&mut closure).ok()?; + let child = child.clone().transform_down_mut(&mut closure).data().ok()?; if is_global_limit { return Some(Arc::new(GlobalLimitExec::new( child, @@ -162,22 +162,22 @@ impl PhysicalOptimizerRule for LimitedDistinctAggregation { plan: Arc, config: &ConfigOptions, ) -> Result> { - let plan = if config.optimizer.enable_distinct_aggregation_soft_limit { + if config.optimizer.enable_distinct_aggregation_soft_limit { plan.transform_down(&|plan| { Ok( if let Some(plan) = LimitedDistinctAggregation::transform_limit(plan.clone()) { - Transformed::Yes(plan) + Transformed::yes(plan) } else { - Transformed::No(plan) + Transformed::no(plan) }, ) - })? + }) + .data() } else { - plan - }; - Ok(plan) + Ok(plan) + } } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index 7fea375725a5..bd71b3e8ed80 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -29,7 +29,7 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, Statistics}; use datafusion_physical_expr::{Distribution, LexRequirement, PhysicalSortRequirement}; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -193,15 +193,17 @@ impl PhysicalOptimizerRule for OutputRequirements { ) -> Result> { match self.mode { RuleMode::Add => require_top_ordering(plan), - RuleMode::Remove => plan.transform_up(&|plan| { - if let Some(sort_req) = - plan.as_any().downcast_ref::() - { - Ok(Transformed::Yes(sort_req.input())) - } else { - Ok(Transformed::No(plan)) - } - }), + RuleMode::Remove => plan + .transform_up(&|plan| { + if let Some(sort_req) = + plan.as_any().downcast_ref::() + { + Ok(Transformed::yes(sort_req.input())) + } else { + Ok(Transformed::no(plan)) + } + }) + .data(), } } diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index e783f75378b1..1dc8bc5042bf 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -28,7 +28,7 @@ use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use datafusion_common::config::OptimizerOptions; use datafusion_common::plan_err; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; @@ -51,6 +51,7 @@ impl PhysicalOptimizerRule for PipelineChecker { config: &ConfigOptions, ) -> Result> { plan.transform_up(&|p| check_finiteness_requirements(p, &config.optimizer)) + .data() } fn name(&self) -> &str { @@ -82,7 +83,7 @@ pub fn check_finiteness_requirements( input ) } else { - Ok(Transformed::No(input)) + Ok(Transformed::no(input)) } } diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 4ed265d59526..17d30a2b4ec1 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -43,7 +43,9 @@ use crate::physical_plan::{Distribution, ExecutionPlan, ExecutionPlanProperties} use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{DataFusionError, JoinSide}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ @@ -73,7 +75,7 @@ impl PhysicalOptimizerRule for ProjectionPushdown { plan: Arc, _config: &ConfigOptions, ) -> Result> { - plan.transform_down(&remove_unnecessary_projections) + plan.transform_down(&remove_unnecessary_projections).data() } fn name(&self) -> &str { @@ -98,7 +100,7 @@ pub fn remove_unnecessary_projections( // If the projection does not cause any change on the input, we can // safely remove it: if is_projection_removable(projection) { - return Ok(Transformed::Yes(projection.input().clone())); + return Ok(Transformed::yes(projection.input().clone())); } // If it does, check if we can push it under its child(ren): let input = projection.input().as_any(); @@ -111,8 +113,10 @@ pub fn remove_unnecessary_projections( return if let Some(new_plan) = maybe_unified { // To unify 3 or more sequential projections: remove_unnecessary_projections(new_plan) + .data() + .map(Transformed::yes) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; } else if let Some(output_req) = input.downcast_ref::() { try_swapping_with_output_req(projection, output_req)? @@ -148,10 +152,10 @@ pub fn remove_unnecessary_projections( None } } else { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); }; - Ok(maybe_modified.map_or(Transformed::No(plan), Transformed::Yes)) + Ok(maybe_modified.map_or(Transformed::no(plan), Transformed::yes)) } /// Tries to embed `projection` to its input (`csv`). If possible, returns @@ -271,7 +275,7 @@ fn try_unifying_projections( if let Some(column) = expr.as_any().downcast_ref::() { *column_ref_map.entry(column.clone()).or_default() += 1; } - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); @@ -893,16 +897,16 @@ fn update_expr( .clone() .transform_up_mut(&mut |expr: Arc| { if state == RewriteState::RewrittenInvalid { - return Ok(Transformed::No(expr)); + return Ok(Transformed::no(expr)); } let Some(column) = expr.as_any().downcast_ref::() else { - return Ok(Transformed::No(expr)); + return Ok(Transformed::no(expr)); }; if sync_with_child { state = RewriteState::RewrittenValid; // Update the index of `column`: - Ok(Transformed::Yes(projected_exprs[column.index()].0.clone())) + Ok(Transformed::yes(projected_exprs[column.index()].0.clone())) } else { // default to invalid, in case we can't find the relevant column state = RewriteState::RewrittenInvalid; @@ -923,11 +927,12 @@ fn update_expr( ) }) .map_or_else( - || Ok(Transformed::No(expr)), - |c| Ok(Transformed::Yes(c)), + || Ok(Transformed::no(expr)), + |c| Ok(Transformed::yes(c)), ) } - }); + }) + .data(); new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) } @@ -1044,7 +1049,7 @@ fn new_columns_for_join_on( }) .map(|(index, (_, alias))| Column::new(alias, index)); if let Some(new_column) = new_column { - Ok(Transformed::Yes(Arc::new(new_column))) + Ok(Transformed::yes(Arc::new(new_column))) } else { // If the column is not found in the projection expressions, // it means that the column is not projected. In this case, @@ -1055,9 +1060,10 @@ fn new_columns_for_join_on( ))) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } }) + .data() .ok() }) .collect::>(); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 37f705d8a82f..05d2d852e057 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -29,20 +29,22 @@ use crate::{ logical_expr::Operator, physical_plan::{ColumnarValue, PhysicalExpr}, }; -use arrow::record_batch::RecordBatchOptions; + use arrow::{ array::{new_null_array, ArrayRef, BooleanArray}, datatypes::{DataType, Field, Schema, SchemaRef}, - record_batch::RecordBatch, + record_batch::{RecordBatch, RecordBatchOptions}, }; use arrow_array::cast::AsArray; +use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ - internal_err, plan_err, + internal_err, plan_datafusion_err, plan_err, tree_node::{Transformed, TreeNode}, + ScalarValue, }; -use datafusion_common::{plan_datafusion_err, ScalarValue}; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; + use log::trace; /// A source of runtime statistical information to [`PruningPredicate`]s. @@ -1034,12 +1036,13 @@ fn rewrite_column_expr( e.transform(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { - return Ok(Transformed::Yes(Arc::new(column_new.clone()))); + return Ok(Transformed::yes(Arc::new(column_new.clone()))); } } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .data() } fn reverse_operator(op: Operator) -> Result { diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index c0abde26c300..e8b6a78b929e 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -236,7 +236,7 @@ pub(crate) fn replace_with_order_preserving_variants( ) -> Result> { update_children(&mut requirements); if !(is_sort(&requirements.plan) && requirements.children[0].data) { - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } // For unbounded cases, we replace with the order-preserving variant in any @@ -260,13 +260,13 @@ pub(crate) fn replace_with_order_preserving_variants( for child in alternate_plan.children.iter_mut() { child.data = false; } - Ok(Transformed::Yes(alternate_plan)) + Ok(Transformed::yes(alternate_plan)) } else { // The alternate plan does not help, use faster order-breaking variants: alternate_plan = plan_with_order_breaking_variants(alternate_plan)?; alternate_plan.data = false; requirements.children = vec![alternate_plan]; - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } } @@ -293,7 +293,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::tree_node::TreeNode; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{Result, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{JoinType, Operator}; @@ -395,7 +395,7 @@ mod tests { // Run the rule top-down let config = SessionConfig::new().with_prefer_existing_sort($PREFER_EXISTING_SORT); let plan_with_pipeline_fixer = OrderPreservationContext::new_default(physical_plan); - let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options())).and_then(check_integrity)?; + let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options())).data().and_then(check_integrity)?; let optimized_physical_plan = parallel.plan; // Get string representation of the plan diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index ff82319fba19..c527819e7746 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -87,7 +87,7 @@ pub(crate) fn pushdown_sorts( } // Can push down requirements child.data = None; - return Ok(Transformed::Yes(child)); + return Ok(Transformed::yes(child)); } else { // Can not push down requirements requirements.children = vec![child]; @@ -112,7 +112,7 @@ pub(crate) fn pushdown_sorts( requirements = add_sort_above(requirements, sort_reqs, None); assign_initial_requirements(&mut requirements); } - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } fn pushdown_requirement_to_children( diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 3898fb6345f0..d944cedb0f96 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -40,7 +40,7 @@ use crate::physical_plan::{ExecutionPlan, InputOrderMode, Partitioning}; use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{JoinType, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; @@ -376,15 +376,19 @@ pub fn sort_exec( /// TODO: Once [`ExecutionPlan`] implements [`PartialEq`], string comparisons should be /// replaced with direct plan equality checks. pub fn check_integrity(context: PlanContext) -> Result> { - context.transform_up(&|node| { - let children_plans = node.plan.children(); - assert_eq!(node.children.len(), children_plans.len()); - for (child_plan, child_node) in children_plans.iter().zip(node.children.iter()) { - assert_eq!( - displayable(child_plan.as_ref()).one_line().to_string(), - displayable(child_node.plan.as_ref()).one_line().to_string() - ); - } - Ok(Transformed::No(node)) - }) + context + .transform_up(&|node| { + let children_plans = node.plan.children(); + assert_eq!(node.children.len(), children_plans.len()); + for (child_plan, child_node) in + children_plans.iter().zip(node.children.iter()) + { + assert_eq!( + displayable(child_plan.as_ref()).one_line().to_string(), + displayable(child_node.plan.as_ref()).one_line().to_string() + ); + } + Ok(Transformed::no(node)) + }) + .data() } diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index 0ca709e56bcb..c47e5e25d143 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -29,7 +29,7 @@ use crate::physical_plan::ExecutionPlan; use arrow_schema::DataType; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalSortExpr; @@ -104,13 +104,13 @@ impl TopKAggregation { let mut cardinality_preserved = true; let mut closure = |plan: Arc| { if !cardinality_preserved { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } if let Some(aggr) = plan.as_any().downcast_ref::() { // either we run into an Aggregate and transform it match Self::transform_agg(aggr, order, limit) { None => cardinality_preserved = false, - Some(plan) => return Ok(Transformed::Yes(plan)), + Some(plan) => return Ok(Transformed::yes(plan)), } } else { // or we continue down whitelisted nodes of other types @@ -118,9 +118,9 @@ impl TopKAggregation { cardinality_preserved = false; } } - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; - let child = child.clone().transform_down_mut(&mut closure).ok()?; + let child = child.clone().transform_down_mut(&mut closure).data().ok()?; let sort = SortExec::new(sort.expr().to_vec(), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); @@ -140,20 +140,20 @@ impl PhysicalOptimizerRule for TopKAggregation { plan: Arc, config: &ConfigOptions, ) -> Result> { - let plan = if config.optimizer.enable_topk_aggregation { + if config.optimizer.enable_topk_aggregation { plan.transform_down(&|plan| { Ok( if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) { - Transformed::Yes(plan) + Transformed::yes(plan) } else { - Transformed::No(plan) + Transformed::no(plan) }, ) - })? + }) + .data() } else { - plan - }; - Ok(plan) + Ok(plan) + } } fn name(&self) -> &str { diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 6b371b782cb5..59905d859dc8 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -17,19 +17,12 @@ use std::sync::Arc; -use arrow::array::{ArrayRef, Int64Array}; +use arrow::array::{Array, ArrayRef, AsArray, Int64Array}; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use arrow_array::cast::AsArray; use arrow_array::types::Int64Type; -use arrow_array::Array; -use hashbrown::HashMap; -use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; -use tokio::task::JoinSet; - use datafusion::common::Result; use datafusion::datasource::MemTable; use datafusion::physical_plan::aggregates::{ @@ -38,12 +31,17 @@ use datafusion::physical_plan::aggregates::{ use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; -use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion_physical_expr::expressions::{col, Sum}; use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr}; use datafusion_physical_plan::InputOrderMode; use test_utils::{add_empty_batches, StringBatchGenerator}; +use hashbrown::HashMap; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use tokio::task::JoinSet; + /// Tests that streaming aggregate and batch (non streaming) aggregate produce /// same results #[tokio::test(flavor = "multi_thread")] @@ -315,8 +313,9 @@ async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { let mut visitor = Visitor { expected_sort }; impl TreeNodeVisitor for Visitor { - type N = Arc; - fn pre_visit(&mut self, node: &Self::N) -> Result { + type Node = Arc; + + fn f_down(&mut self, node: &Self::Node) -> Result { if let Some(exec) = node.as_any().downcast_ref::() { if self.expected_sort { assert!(matches!( @@ -327,7 +326,7 @@ async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { assert!(matches!(exec.input_order_mode(), InputOrderMode::Linear)); } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index d1687cbd6f29..68b123ab1f28 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,28 +17,27 @@ //! Logical Expressions: [`Expr`] +use std::collections::HashSet; +use std::fmt::{self, Display, Formatter, Write}; +use std::hash::Hash; +use std::str::FromStr; +use std::sync::Arc; + use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; +use crate::{ + aggregate_function, built_in_function, built_in_window_function, udaf, + BuiltinScalarFunction, ExprSchemable, Operator, Signature, +}; -use crate::Operator; -use crate::{aggregate_function, ExprSchemable}; -use crate::{built_in_function, BuiltinScalarFunction}; -use crate::{built_in_window_function, udaf}; use arrow::datatypes::DataType; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; -use datafusion_common::{plan_err, Column, Result, ScalarValue}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{ + internal_err, plan_err, Column, DFSchema, OwnedTableReference, Result, ScalarValue, +}; use sqlparser::ast::NullTreatment; -use std::collections::HashSet; -use std::fmt; -use std::fmt::{Display, Formatter, Write}; -use std::hash::Hash; -use std::str::FromStr; -use std::sync::Arc; - -use crate::Signature; /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS @@ -1275,8 +1274,9 @@ impl Expr { rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; } - Ok(Transformed::Yes(expr)) + Ok(Transformed::yes(expr)) }) + .data() } /// Returns true if some of this `exprs` subexpressions may not be evaluated diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 3f7388c3c3d5..cd9a8344dec4 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -17,16 +17,19 @@ //! Expression rewriter -use crate::expr::{Alias, Unnest}; -use crate::logical_plan::Projection; -use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; -use datafusion_common::Result; -use datafusion_common::{Column, DFSchema}; use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; +use crate::expr::{Alias, Unnest}; +use crate::logical_plan::Projection; +use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; + +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRewriter, +}; +use datafusion_common::{Column, DFSchema, Result}; + mod order_by; pub use order_by::rewrite_sort_cols_by_aggs; @@ -37,12 +40,13 @@ pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() } /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions @@ -61,12 +65,13 @@ pub fn normalize_col_with_schemas( Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, using_columns)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() } /// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage @@ -90,12 +95,13 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( if let Expr::Column(c) = expr { let col = c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() } /// Recursively normalize all [`Column`] expressions in a list of expression trees @@ -116,14 +122,15 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { - Some(new_c) => Transformed::Yes(Expr::Column((*new_c).to_owned())), - None => Transformed::No(expr), + Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())), + None => Transformed::no(expr), } } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() } /// Recursively 'unnormalize' (remove all qualifiers) from an @@ -139,12 +146,13 @@ pub fn unnormalize_col(expr: Expr) -> Expr { relation: None, name: c.name, }; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() .expect("Unnormalize is infallable") } @@ -177,12 +185,13 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { expr.transform(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() .expect("strip_outer_reference is infallable") } @@ -260,22 +269,24 @@ pub fn unalias(expr: Expr) -> Expr { /// schema of plan nodes don't change after optimization pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result where - R: TreeNodeRewriter, + R: TreeNodeRewriter, { let original_name = expr.name_for_alias()?; - let expr = expr.rewrite(rewriter)?; + let expr = expr.rewrite(rewriter)?.data; expr.alias_if_changed(original_name) } #[cfg(test)] mod test { + use std::ops::Add; + use super::*; use crate::expr::Sort; use crate::{col, lit, Cast}; + use arrow::datatypes::DataType; - use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; + use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter}; use datafusion_common::{DFField, DFSchema, ScalarValue}; - use std::ops::Add; #[derive(Default)] struct RecordingRewriter { @@ -283,16 +294,16 @@ mod test { } impl TreeNodeRewriter for RecordingRewriter { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result> { self.v.push(format!("Previsited {expr}")); - Ok(RewriteRecursion::Continue) + Ok(Transformed::no(expr)) } - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { self.v.push(format!("Mutated {expr}")); - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -307,19 +318,27 @@ mod test { } else { utf8_val }; - Ok(Transformed::Yes(lit(utf8_val))) + Ok(Transformed::yes(lit(utf8_val))) } // otherwise, return None - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } }; // rewrites "foo" --> "bar" - let rewritten = col("state").eq(lit("foo")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("foo")) + .transform(&transformer) + .data() + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); // doesn't rewrite - let rewritten = col("state").eq(lit("baz")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("baz")) + .transform(&transformer) + .data() + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } @@ -454,10 +473,10 @@ mod test { } impl TreeNodeRewriter for TestRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, _: Expr) -> Result { - Ok(self.rewrite_to.clone()) + fn f_up(&mut self, _: Expr) -> Result> { + Ok(Transformed::yes(self.rewrite_to.clone())) } } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index c87a724d5646..b1bc11a83f90 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -20,7 +20,8 @@ use crate::expr::{Alias, Sort}; use crate::expr_rewriter::normalize_col; use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; -use datafusion_common::tree_node::{Transformed, TreeNode}; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; /// Rewrite sort on aggregate expressions to sort on the column of aggregate output @@ -91,7 +92,7 @@ fn rewrite_in_terms_of_projection( .to_field(input.schema()) .map(|f| f.qualified_column())?, ); - return Ok(Transformed::Yes(col)); + return Ok(Transformed::yes(col)); } // if that doesn't work, try to match the expression as an @@ -103,7 +104,7 @@ fn rewrite_in_terms_of_projection( e } else { // The expr is not based on Aggregate plan output. Skip it. - return Ok(Transformed::No(expr)); + return Ok(Transformed::no(expr)); }; // expr is an actual expr like min(t.c2), but we are looking @@ -118,7 +119,7 @@ fn rewrite_in_terms_of_projection( // look for the column named the same as this expr if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) { let found = found.clone(); - return Ok(Transformed::Yes(match normalized_expr { + return Ok(Transformed::yes(match normalized_expr { Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { expr: Box::new(found), data_type, @@ -131,8 +132,9 @@ fn rewrite_in_terms_of_projection( })); } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .data() } /// Does the underlying expr match e? diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 112dbf74dba1..e0cb44626e24 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -16,12 +16,14 @@ // under the License. //! This module provides logic for displaying LogicalPlans in various styles +use std::fmt; + use crate::LogicalPlan; + use arrow::datatypes::Schema; use datafusion_common::display::GraphvizBuilder; -use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; +use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::DataFusionError; -use std::fmt; /// Formats plans with a single line per node. For example: /// @@ -49,12 +51,12 @@ impl<'a, 'b> IndentVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit( + fn f_down( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { if self.indent > 0 { writeln!(self.f)?; } @@ -69,15 +71,15 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { } self.indent += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit( + fn f_up( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { self.indent -= 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -171,12 +173,12 @@ impl<'a, 'b> GraphvizVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit( + fn f_down( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { let id = self.graphviz_builder.next_id(); // Create a new graph node for `plan` such as @@ -204,18 +206,18 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { } self.parent_ids.push(id); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit( + fn f_up( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { // always be non-empty as pre_visit always pushes // So it should always be Ok(true) let res = self.parent_ids.pop(); res.ok_or(DataFusionError::Internal("Fail to format".to_string())) - .map(|_| VisitRecursion::Continue) + .map(|_| TreeNodeRecursion::Continue) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 5cce8f9cd45c..ca021c4bfc28 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -46,8 +46,7 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, TreeNodeVisitor, - VisitRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, @@ -476,7 +475,7 @@ impl LogicalPlan { })?; using_columns.push(columns); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(using_columns) @@ -649,31 +648,24 @@ impl LogicalPlan { // Decimal128(Some(69999999999999),30,15) // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - struct RemoveAliases {} - - impl TreeNodeRewriter for RemoveAliases { - type N = Expr; - - fn pre_visit(&mut self, expr: &Expr) -> Result { + let predicate = predicate + .transform_down(&|expr| { match expr { Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { // subqueries could contain aliases so we don't recurse into those - Ok(RewriteRecursion::Stop) + Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) } - Expr::Alias(_) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), + Expr::Alias(_) => Ok(Transformed::new( + expr.unalias(), + true, + TreeNodeRecursion::Jump, + )), + _ => Ok(Transformed::no(expr)), } - } - - fn mutate(&mut self, expr: Expr) -> Result { - Ok(expr.unalias()) - } - } - - let mut remove_aliases = RemoveAliases {}; - let predicate = predicate.rewrite(&mut remove_aliases)?; + }) + .data()?; Filter::try_new(predicate, Arc::new(inputs.swap_remove(0))) .map(LogicalPlan::Filter) @@ -1125,9 +1117,9 @@ impl LogicalPlan { impl LogicalPlan { /// applies `op` to any subqueries in the plan - pub(crate) fn apply_subqueries(&self, op: &mut F) -> datafusion_common::Result<()> + pub(crate) fn apply_subqueries(&self, op: &mut F) -> Result<()> where - F: FnMut(&Self) -> datafusion_common::Result, + F: FnMut(&Self) -> Result, { self.inspect_expressions(|expr| { // recursively look for subqueries @@ -1151,9 +1143,9 @@ impl LogicalPlan { } /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries(&self, v: &mut V) -> datafusion_common::Result<()> + pub(crate) fn visit_subqueries(&self, v: &mut V) -> Result<()> where - V: TreeNodeVisitor, + V: TreeNodeVisitor, { self.inspect_expressions(|expr| { // recursively look for subqueries @@ -1226,11 +1218,11 @@ impl LogicalPlan { _ => {} } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok::<(), DataFusionError>(()) })?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(param_types) @@ -1247,19 +1239,20 @@ impl LogicalPlan { Expr::Placeholder(Placeholder { id, .. }) => { let value = param_values.get_placeholders_with_values(id)?; // Replace the placeholder with the value - Ok(Transformed::Yes(Expr::Literal(value))) + Ok(Transformed::yes(Expr::Literal(value))) } Expr::ScalarSubquery(qry) => { let subquery = Arc::new(qry.subquery.replace_params_with_values(param_values)?); - Ok(Transformed::Yes(Expr::ScalarSubquery(Subquery { + Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery, outer_ref_columns: qry.outer_ref_columns.clone(), }))) } - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } }) + .data() } } @@ -2842,9 +2835,9 @@ digraph { } impl TreeNodeVisitor for OkVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "pre_visit Projection", LogicalPlan::Filter { .. } => "pre_visit Filter", @@ -2855,10 +2848,10 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "post_visit Projection", LogicalPlan::Filter { .. } => "post_visit Filter", @@ -2869,7 +2862,7 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -2925,23 +2918,23 @@ digraph { } impl TreeNodeVisitor for StoppingVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_pre_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } - self.inner.pre_visit(plan)?; + self.inner.f_down(plan)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_post_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } - self.inner.post_visit(plan) + self.inner.f_up(plan) } } @@ -2994,22 +2987,22 @@ digraph { } impl TreeNodeVisitor for ErrorVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { return not_impl_err!("Error in pre_visit"); } - self.inner.pre_visit(plan) + self.inner.f_down(plan) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { return not_impl_err!("Error in post_visit"); } - self.inner.post_visit(plan) + self.inner.f_up(plan) } } @@ -3317,10 +3310,11 @@ digraph { Arc::new(LogicalPlan::TableScan(table)), ) .unwrap(); - Ok(Transformed::Yes(LogicalPlan::Filter(filter))) + Ok(Transformed::yes(LogicalPlan::Filter(filter))) } - x => Ok(Transformed::No(x)), + x => Ok(Transformed::no(x)), }) + .data() .unwrap(); let expected = "Explain\ diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 81949f2178f6..67d48f986f13 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -24,14 +24,16 @@ use crate::expr::{ }; use crate::{Expr, GetFieldAccess}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::tree_node::{ + Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, +}; +use datafusion_common::{handle_visit_recursion, internal_err, Result}; impl TreeNode for Expr { - fn apply_children Result>( + fn apply_children Result>( &self, - op: &mut F, - ) -> Result { + f: &mut F, + ) -> Result { let children = match self { Expr::Alias(Alias{expr,..}) | Expr::Not(expr) @@ -129,21 +131,19 @@ impl TreeNode for Expr { } }; + let mut tnr = TreeNodeRecursion::Continue; for child in children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } + tnr = f(child)?; + handle_visit_recursion!(tnr, DOWN); } - Ok(VisitRecursion::Continue) + Ok(tnr) } - fn map_children Result>( - self, - mut transform: F, - ) -> Result { + fn map_children(self, mut f: F) -> Result> + where + F: FnMut(Self) -> Result>, + { Ok(match self { Expr::Column(_) | Expr::Wildcard { .. } @@ -153,27 +153,29 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Unnest(_) - | Expr::Literal(_) => self, + | Expr::Literal(_) => Transformed::no(self), Expr::Alias(Alias { expr, relation, name, - }) => Expr::Alias(Alias::new(transform(*expr)?, relation, name)), + }) => f(*expr)?.update_data(|e| Expr::Alias(Alias::new(e, relation, name))), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => Expr::InSubquery(InSubquery::new( - transform_boxed(expr, &mut transform)?, - subquery, - negated, - )), + }) => transform_box(expr, &mut f)?.update_data(|be| { + Expr::InSubquery(InSubquery::new(be, subquery, negated)) + }), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - Expr::BinaryExpr(BinaryExpr::new( - transform_boxed(left, &mut transform)?, - op, - transform_boxed(right, &mut transform)?, - )) + transform_box(left, &mut f)? + .update_data(|new_left| (new_left, right)) + .try_transform_node(|(new_left, right)| { + Ok(transform_box(right, &mut f)? + .update_data(|new_right| (new_left, new_right))) + })? + .update_data(|(new_left, new_right)| { + Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) + }) } Expr::Like(Like { negated, @@ -181,102 +183,136 @@ impl TreeNode for Expr { pattern, escape_char, case_insensitive, - }) => Expr::Like(Like::new( - negated, - transform_boxed(expr, &mut transform)?, - transform_boxed(pattern, &mut transform)?, - escape_char, - case_insensitive, - )), + }) => transform_box(expr, &mut f)? + .update_data(|new_expr| (new_expr, pattern)) + .try_transform_node(|(new_expr, pattern)| { + Ok(transform_box(pattern, &mut f)? + .update_data(|new_pattern| (new_expr, new_pattern))) + })? + .update_data(|(new_expr, new_pattern)| { + Expr::Like(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }), Expr::SimilarTo(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => Expr::SimilarTo(Like::new( - negated, - transform_boxed(expr, &mut transform)?, - transform_boxed(pattern, &mut transform)?, - escape_char, - case_insensitive, - )), - Expr::Not(expr) => Expr::Not(transform_boxed(expr, &mut transform)?), + }) => transform_box(expr, &mut f)? + .update_data(|new_expr| (new_expr, pattern)) + .try_transform_node(|(new_expr, pattern)| { + Ok(transform_box(pattern, &mut f)? + .update_data(|new_pattern| (new_expr, new_pattern))) + })? + .update_data(|(new_expr, new_pattern)| { + Expr::SimilarTo(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }), + Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not), Expr::IsNotNull(expr) => { - Expr::IsNotNull(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.update_data(Expr::IsNotNull) + } + Expr::IsNull(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsNull), + Expr::IsTrue(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsTrue), + Expr::IsFalse(expr) => { + transform_box(expr, &mut f)?.update_data(Expr::IsFalse) } - Expr::IsNull(expr) => Expr::IsNull(transform_boxed(expr, &mut transform)?), - Expr::IsTrue(expr) => Expr::IsTrue(transform_boxed(expr, &mut transform)?), - Expr::IsFalse(expr) => Expr::IsFalse(transform_boxed(expr, &mut transform)?), Expr::IsUnknown(expr) => { - Expr::IsUnknown(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.update_data(Expr::IsUnknown) } Expr::IsNotTrue(expr) => { - Expr::IsNotTrue(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.update_data(Expr::IsNotTrue) } Expr::IsNotFalse(expr) => { - Expr::IsNotFalse(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.update_data(Expr::IsNotFalse) } Expr::IsNotUnknown(expr) => { - Expr::IsNotUnknown(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.update_data(Expr::IsNotUnknown) } Expr::Negative(expr) => { - Expr::Negative(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.update_data(Expr::Negative) } Expr::Between(Between { expr, negated, low, high, - }) => Expr::Between(Between::new( - transform_boxed(expr, &mut transform)?, - negated, - transform_boxed(low, &mut transform)?, - transform_boxed(high, &mut transform)?, - )), - Expr::Case(case) => { - let expr = transform_option_box(case.expr, &mut transform)?; - let when_then_expr = case - .when_then_expr - .into_iter() - .map(|(when, then)| { - Ok(( - transform_boxed(when, &mut transform)?, - transform_boxed(then, &mut transform)?, - )) - }) - .collect::>>()?; - let else_expr = transform_option_box(case.else_expr, &mut transform)?; - - Expr::Case(Case::new(expr, when_then_expr, else_expr)) - } - Expr::Cast(Cast { expr, data_type }) => { - Expr::Cast(Cast::new(transform_boxed(expr, &mut transform)?, data_type)) - } - Expr::TryCast(TryCast { expr, data_type }) => Expr::TryCast(TryCast::new( - transform_boxed(expr, &mut transform)?, - data_type, - )), + }) => transform_box(expr, &mut f)? + .update_data(|new_expr| (new_expr, low, high)) + .try_transform_node(|(new_expr, low, high)| { + Ok(transform_box(low, &mut f)? + .update_data(|new_low| (new_expr, new_low, high))) + })? + .try_transform_node(|(new_expr, new_low, high)| { + Ok(transform_box(high, &mut f)? + .update_data(|new_high| (new_expr, new_low, new_high))) + })? + .update_data(|(new_expr, new_low, new_high)| { + Expr::Between(Between::new(new_expr, negated, new_low, new_high)) + }), + Expr::Case(Case { + expr, + when_then_expr, + else_expr, + }) => transform_option_box(expr, &mut f)? + .update_data(|new_expr| (new_expr, when_then_expr, else_expr)) + .try_transform_node(|(new_expr, when_then_expr, else_expr)| { + Ok(when_then_expr + .into_iter() + .map_until_stop_and_collect(|(when, then)| { + transform_box(when, &mut f)? + .update_data(|new_when| (new_when, then)) + .try_transform_node(|(new_when, then)| { + Ok(transform_box(then, &mut f)? + .update_data(|new_then| (new_when, new_then))) + }) + })? + .update_data(|new_when_then_expr| { + (new_expr, new_when_then_expr, else_expr) + })) + })? + .try_transform_node(|(new_expr, new_when_then_expr, else_expr)| { + Ok(transform_option_box(else_expr, &mut f)?.update_data( + |new_else_expr| (new_expr, new_when_then_expr, new_else_expr), + )) + })? + .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { + Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) + }), + Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)? + .update_data(|be| Expr::Cast(Cast::new(be, data_type))), + Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? + .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), Expr::Sort(Sort { expr, asc, nulls_first, - }) => Expr::Sort(Sort::new( - transform_boxed(expr, &mut transform)?, - asc, - nulls_first, - )), - Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction( - ScalarFunction::new(fun, transform_vec(args, &mut transform)?), - ), - ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction( - ScalarFunction::new_udf(fun, transform_vec(args, &mut transform)?), - ), - ScalarFunctionDefinition::Name(_) => { - return internal_err!("Function `Expr` with name should be resolved.") - } - }, + }) => transform_box(expr, &mut f)? + .update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))), + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + transform_vec(args, &mut f)?.map_data(|new_args| match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + } + ScalarFunctionDefinition::UDF(fun) => { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_args))) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + })? + } Expr::WindowFunction(WindowFunction { args, fun, @@ -284,112 +320,139 @@ impl TreeNode for Expr { order_by, window_frame, null_treatment, - }) => Expr::WindowFunction(WindowFunction::new( - fun, - transform_vec(args, &mut transform)?, - transform_vec(partition_by, &mut transform)?, - transform_vec(order_by, &mut transform)?, - window_frame, - null_treatment, - )), + }) => transform_vec(args, &mut f)? + .update_data(|new_args| (new_args, partition_by, order_by)) + .try_transform_node(|(new_args, partition_by, order_by)| { + Ok(transform_vec(partition_by, &mut f)?.update_data( + |new_partition_by| (new_args, new_partition_by, order_by), + )) + })? + .try_transform_node(|(new_args, new_partition_by, order_by)| { + Ok( + transform_vec(order_by, &mut f)?.update_data(|new_order_by| { + (new_args, new_partition_by, new_order_by) + }), + ) + })? + .update_data(|(new_args, new_partition_by, new_order_by)| { + Expr::WindowFunction(WindowFunction::new( + fun, + new_args, + new_partition_by, + new_order_by, + window_frame, + null_treatment, + )) + }), Expr::AggregateFunction(AggregateFunction { args, func_def, distinct, filter, order_by, - }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - Expr::AggregateFunction(AggregateFunction::new( - fun, - transform_vec(args, &mut transform)?, - distinct, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } - AggregateFunctionDefinition::UDF(fun) => { - let order_by = order_by - .map(|order_by| transform_vec(order_by, &mut transform)) - .transpose()?; - Expr::AggregateFunction(AggregateFunction::new_udf( - fun, - transform_vec(args, &mut transform)?, - false, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } - AggregateFunctionDefinition::Name(_) => { - return internal_err!("Function `Expr` with name should be resolved.") - } - }, + }) => transform_vec(args, &mut f)? + .update_data(|new_args| (new_args, filter, order_by)) + .try_transform_node(|(new_args, filter, order_by)| { + Ok(transform_option_box(filter, &mut f)? + .update_data(|new_filter| (new_args, new_filter, order_by))) + })? + .try_transform_node(|(new_args, new_filter, order_by)| { + Ok(transform_option_vec(order_by, &mut f)? + .update_data(|new_order_by| (new_args, new_filter, new_order_by))) + })? + .map_data(|(new_args, new_filter, new_order_by)| match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun, + new_args, + distinct, + new_filter, + new_order_by, + ))) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + fun, + new_args, + false, + new_filter, + new_order_by, + ))) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + })?, Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( - transform_vec(exprs, &mut transform)?, - )), - GroupingSet::Cube(exprs) => Expr::GroupingSet(GroupingSet::Cube( - transform_vec(exprs, &mut transform)?, - )), - GroupingSet::GroupingSets(lists_of_exprs) => { - Expr::GroupingSet(GroupingSet::GroupingSets( - lists_of_exprs - .into_iter() - .map(|exprs| transform_vec(exprs, &mut transform)) - .collect::>>()?, - )) - } + GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? + .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), + GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)? + .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), + GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs + .into_iter() + .map_until_stop_and_collect(|exprs| transform_vec(exprs, &mut f))? + .update_data(|new_lists_of_exprs| { + Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs)) + }), }, Expr::InList(InList { expr, list, negated, - }) => Expr::InList(InList::new( - transform_boxed(expr, &mut transform)?, - transform_vec(list, &mut transform)?, - negated, - )), + }) => transform_box(expr, &mut f)? + .update_data(|new_expr| (new_expr, list)) + .try_transform_node(|(new_expr, list)| { + Ok(transform_vec(list, &mut f)? + .update_data(|new_list| (new_expr, new_list))) + })? + .update_data(|(new_expr, new_list)| { + Expr::InList(InList::new(new_expr, new_list, negated)) + }), Expr::GetIndexedField(GetIndexedField { expr, field }) => { - Expr::GetIndexedField(GetIndexedField::new( - transform_boxed(expr, &mut transform)?, - field, - )) + transform_box(expr, &mut f)?.update_data(|be| { + Expr::GetIndexedField(GetIndexedField::new(be, field)) + }) } }) } } -fn transform_boxed Result>( - boxed_expr: Box, - transform: &mut F, -) -> Result> { - // TODO: It might be possible to avoid an allocation (the Box::new) below by reusing the box. - transform(*boxed_expr).map(Box::new) +fn transform_box(be: Box, f: &mut F) -> Result>> +where + F: FnMut(Expr) -> Result>, +{ + Ok(f(*be)?.update_data(Box::new)) } -fn transform_option_box Result>( - option_box: Option>, - transform: &mut F, -) -> Result>> { - option_box - .map(|expr| transform_boxed(expr, transform)) - .transpose() +fn transform_option_box( + obe: Option>, + f: &mut F, +) -> Result>>> +where + F: FnMut(Expr) -> Result>, +{ + obe.map_or(Ok(Transformed::no(None)), |be| { + Ok(transform_box(be, f)?.update_data(Some)) + }) } /// &mut transform a Option<`Vec` of `Expr`s> -fn transform_option_vec Result>( - option_box: Option>, - transform: &mut F, -) -> Result>> { - option_box - .map(|exprs| transform_vec(exprs, transform)) - .transpose() +fn transform_option_vec( + ove: Option>, + f: &mut F, +) -> Result>>> +where + F: FnMut(Expr) -> Result>, +{ + ove.map_or(Ok(Transformed::no(None)), |ve| { + Ok(transform_vec(ve, f)?.update_data(Some)) + }) } /// &mut transform a `Vec` of `Expr`s -fn transform_vec Result>( - v: Vec, - transform: &mut F, -) -> Result> { - v.into_iter().map(transform).collect() +fn transform_vec(ve: Vec, f: &mut F) -> Result>> +where + F: FnMut(Expr) -> Result>, +{ + ve.into_iter().map_until_stop_and_collect(f) } diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index c35a09874a62..02d5d1851289 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -19,19 +19,21 @@ use crate::LogicalPlan; -use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; -use datafusion_common::{handle_tree_recursion, Result}; +use datafusion_common::tree_node::{ + Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeVisitor, +}; +use datafusion_common::{handle_visit_recursion, Result}; impl TreeNode for LogicalPlan { - fn apply Result>( + fn apply Result>( &self, - op: &mut F, - ) -> Result { + f: &mut F, + ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::apply_subqueries`] before visiting its children - handle_tree_recursion!(op(self)?); - self.apply_subqueries(op)?; - self.apply_children(&mut |node| node.apply(op)) + handle_visit_recursion!(f(self)?, DOWN); + self.apply_subqueries(f)?; + self.apply_children(&mut |n| n.apply(f)) } /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke @@ -54,48 +56,58 @@ impl TreeNode for LogicalPlan { /// visitor.post_visit(Filter) /// visitor.post_visit(Projection) /// ``` - fn visit>( + fn visit>( &self, visitor: &mut V, - ) -> Result { + ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::visit_subqueries`] before visiting its children - handle_tree_recursion!(visitor.pre_visit(self)?); - self.visit_subqueries(visitor)?; - handle_tree_recursion!(self.apply_children(&mut |node| node.visit(visitor))?); - visitor.post_visit(self) + match visitor.f_down(self)? { + TreeNodeRecursion::Continue => { + self.visit_subqueries(visitor)?; + handle_visit_recursion!( + self.apply_children(&mut |n| n.visit(visitor))?, + UP + ); + visitor.f_up(self) + } + TreeNodeRecursion::Jump => { + self.visit_subqueries(visitor)?; + visitor.f_up(self) + } + TreeNodeRecursion::Stop => Ok(TreeNodeRecursion::Stop), + } } - fn apply_children Result>( + fn apply_children Result>( &self, - op: &mut F, - ) -> Result { + f: &mut F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; for child in self.inputs() { - handle_tree_recursion!(op(child)?) + tnr = f(child)?; + handle_visit_recursion!(tnr, DOWN) } - Ok(VisitRecursion::Continue) + Ok(tnr) } - fn map_children(self, transform: F) -> Result + fn map_children(self, f: F) -> Result> where - F: FnMut(Self) -> Result, + F: FnMut(Self) -> Result>, { - let old_children = self.inputs(); - let new_children = old_children + let new_children = self + .inputs() .iter() .map(|&c| c.clone()) - .map(transform) - .collect::>>()?; - - // if any changes made, make a new child - if old_children - .into_iter() - .zip(new_children.iter()) - .any(|(c1, c2)| c1 != c2) - { - self.with_new_exprs(self.expressions(), new_children) + .map_until_stop_and_collect(f)?; + // Propagate up `new_children.transformed` and `new_children.tnr` + // along with the node containing transformed children. + if new_children.transformed { + new_children.map_data(|new_children| { + self.with_new_exprs(self.expressions(), new_children) + }) } else { - Ok(self) + Ok(new_children.update_data(|_| self)) } } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index fe9297b32a8e..dfd90e470965 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -31,7 +31,7 @@ use crate::{ }; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::utils::get_at_indices; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, @@ -665,10 +665,10 @@ where exprs.push(expr.clone()) } // stop recursing down this expr once we find a match - return Ok(VisitRecursion::Skip); + return Ok(TreeNodeRecursion::Jump); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); @@ -685,10 +685,10 @@ where if let Err(e) = f(expr) { // save the error for later (it may not be a DataFusionError err = Err(e); - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } else { // keep going - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } }) // The closure always returns OK, so this will always too diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 9242e68562c6..93b24d71c496 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -15,9 +15,14 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::analyzer::AnalyzerRule; + use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRewriter, +}; use datafusion_common::Result; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; @@ -27,7 +32,6 @@ use datafusion_expr::{ aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, Sort, Subquery, }; -use std::sync::Arc; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// @@ -43,7 +47,7 @@ impl CountWildcardRule { impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down(&analyze_internal) + plan.transform_down(&analyze_internal).data() } fn name(&self) -> &str { @@ -61,7 +65,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes( + Ok(Transformed::yes( LogicalPlanBuilder::from((*window.input).clone()) .window(window_expr)? .build()?, @@ -74,7 +78,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Aggregate( + Ok(Transformed::yes(LogicalPlan::Aggregate( Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?, ))) } @@ -83,7 +87,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .iter() .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Sort(Sort { + Ok(Transformed::yes(LogicalPlan::Sort(Sort { expr: sort_expr, input, fetch, @@ -95,7 +99,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .iter() .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Projection( + Ok(Transformed::yes(LogicalPlan::Projection( Projection::try_new(projection_expr, projection.input)?, ))) } @@ -103,22 +107,22 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { predicate, input, .. }) => { let predicate = rewrite_preserving_name(predicate, &mut rewriter)?; - Ok(Transformed::Yes(LogicalPlan::Filter(Filter::try_new( + Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( predicate, input, )?))) } - _ => Ok(Transformed::No(plan)), + _ => Ok(Transformed::no(plan)), } } struct CountWildcardRewriter {} impl TreeNodeRewriter for CountWildcardRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, old_expr: Expr) -> Result { - let new_expr = match old_expr.clone() { + fn f_up(&mut self, old_expr: Expr) -> Result> { + Ok(match old_expr.clone() { Expr::WindowFunction(expr::WindowFunction { fun: expr::WindowFunctionDefinition::AggregateFunction( @@ -131,7 +135,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { null_treatment, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { - Expr::WindowFunction(expr::WindowFunction { + Transformed::yes(Expr::WindowFunction(expr::WindowFunction { fun: expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), @@ -140,10 +144,10 @@ impl TreeNodeRewriter for CountWildcardRewriter { order_by, window_frame, null_treatment, - }) + })) } - _ => old_expr, + _ => Transformed::no(old_expr), }, Expr::AggregateFunction(AggregateFunction { func_def: @@ -156,68 +160,65 @@ impl TreeNodeRewriter for CountWildcardRewriter { order_by, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { - Expr::AggregateFunction(AggregateFunction::new( + Transformed::yes(Expr::AggregateFunction(AggregateFunction::new( aggregate_function::AggregateFunction::Count, vec![lit(COUNT_STAR_EXPANSION)], distinct, filter, order_by, - )) + ))) } - _ => old_expr, + _ => Transformed::no(old_expr), }, ScalarSubquery(Subquery { subquery, outer_ref_columns, - }) => { - let new_plan = subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - }) - } + }) => subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? + .update_data(|new_plan| { + ScalarSubquery(Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns, + }) + }), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => { - let new_plan = subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - - Expr::InSubquery(InSubquery::new( - expr, - Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - )) - } - Expr::Exists(expr::Exists { subquery, negated }) => { - let new_plan = subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - - Expr::Exists(expr::Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - }) - } - _ => old_expr, - }; - Ok(new_expr) + }) => subquery + .subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? + .update_data(|new_plan| { + Expr::InSubquery(InSubquery::new( + expr, + Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, + negated, + )) + }), + Expr::Exists(expr::Exists { subquery, negated }) => subquery + .subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? + .update_data(|new_plan| { + Expr::Exists(expr::Exists { + subquery: Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, + negated, + }) + }), + _ => Transformed::no(old_expr), + }) } } diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 90af7aec8293..b21ec851dfcd 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -20,11 +20,11 @@ use std::sync::Arc; use crate::analyzer::AnalyzerRule; + use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; -use datafusion_expr::expr::Exists; -use datafusion_expr::expr::InSubquery; +use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::{ logical_plan::LogicalPlan, Expr, Filter, LogicalPlanBuilder, TableScan, }; @@ -42,7 +42,7 @@ impl InlineTableScan { impl AnalyzerRule for InlineTableScan { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_up(&analyze_internal) + plan.transform_up(&analyze_internal).data() } fn name(&self) -> &str { @@ -51,7 +51,7 @@ impl AnalyzerRule for InlineTableScan { } fn analyze_internal(plan: LogicalPlan) -> Result> { - Ok(match plan { + match plan { // Match only on scans without filter / projection / fetch // Views and DataFrames won't have those added // during the early stage of planning @@ -64,33 +64,31 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { }) if filters.is_empty() && source.get_logical_plan().is_some() => { let sub_plan = source.get_logical_plan().unwrap(); let projection_exprs = generate_projection_expr(&projection, sub_plan)?; - let plan = LogicalPlanBuilder::from(sub_plan.clone()) + LogicalPlanBuilder::from(sub_plan.clone()) .project(projection_exprs)? // Ensures that the reference to the inlined table remains the // same, meaning we don't have to change any of the parent nodes // that reference this table. .alias(table_name)? - .build()?; - Transformed::Yes(plan) + .build() + .map(Transformed::yes) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform(&rewrite_subquery)?; - Transformed::Yes(LogicalPlan::Filter(Filter::try_new( - new_expr, - filter.input, - )?)) + let new_expr = filter.predicate.transform(&rewrite_subquery).data()?; + Filter::try_new(new_expr, filter.input) + .map(|e| Transformed::yes(LogicalPlan::Filter(e))) } - _ => Transformed::No(plan), - }) + _ => Ok(Transformed::no(plan)), + } } fn rewrite_subquery(expr: Expr) -> Result> { match expr { Expr::Exists(Exists { subquery, negated }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal).data()?; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::Exists(Exists { subquery, negated }))) + Ok(Transformed::yes(Expr::Exists(Exists { subquery, negated }))) } Expr::InSubquery(InSubquery { expr, @@ -98,19 +96,19 @@ fn rewrite_subquery(expr: Expr) -> Result> { negated, }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal).data()?; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::InSubquery(InSubquery::new( + Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( expr, subquery, negated, )))) } Expr::ScalarSubquery(subquery) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal).data()?; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::ScalarSubquery(subquery))) + Ok(Transformed::yes(Expr::ScalarSubquery(subquery))) } - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } } @@ -135,13 +133,12 @@ fn generate_projection_expr( mod tests { use std::{sync::Arc, vec}; - use arrow::datatypes::{DataType, Field, Schema}; - - use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, TableSource}; - use crate::analyzer::inline_table_scan::InlineTableScan; use crate::test::assert_analyzed_plan_eq; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, TableSource}; + pub struct RawTableSource {} impl TableSource for RawTableSource { diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 4c480017fc3a..08caa4be60a9 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -29,7 +29,7 @@ use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; @@ -136,7 +136,7 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { })?; } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index eedfc40a7f80..41ebcd8e501a 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -19,21 +19,19 @@ use std::sync::Arc; +use super::AnalyzerRule; + use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::utils::list_ndims; -use datafusion_common::DFSchema; -use datafusion_common::DFSchemaRef; -use datafusion_common::Result; +use datafusion_common::{DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::merge_schema; -use datafusion_expr::BuiltinScalarFunction; -use datafusion_expr::Operator; -use datafusion_expr::ScalarFunctionDefinition; -use datafusion_expr::{BinaryExpr, Expr, LogicalPlan}; - -use super::AnalyzerRule; +use datafusion_expr::{ + BinaryExpr, BuiltinScalarFunction, Expr, LogicalPlan, Operator, + ScalarFunctionDefinition, +}; #[derive(Default)] pub struct OperatorToFunction {} @@ -94,41 +92,34 @@ pub(crate) struct OperatorToFunctionRewriter { } impl TreeNodeRewriter for OperatorToFunctionRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { - match expr { - Expr::BinaryExpr(BinaryExpr { - ref left, + fn f_up(&mut self, expr: Expr) -> Result> { + if let Expr::BinaryExpr(BinaryExpr { + ref left, + op, + ref right, + }) = expr + { + if let Some(fun) = rewrite_array_concat_operator_to_func_for_column( + left.as_ref(), op, - ref right, - }) => { - if let Some(fun) = rewrite_array_concat_operator_to_func_for_column( - left.as_ref(), - op, - right.as_ref(), - self.schema.as_ref(), - )? - .or_else(|| { - rewrite_array_concat_operator_to_func( - left.as_ref(), - op, - right.as_ref(), - ) - }) { - // Convert &Box -> Expr - let left = (**left).clone(); - let right = (**right).clone(); - return Ok(Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - args: vec![left, right], - })); - } - - Ok(expr) + right.as_ref(), + self.schema.as_ref(), + )? + .or_else(|| { + rewrite_array_concat_operator_to_func(left.as_ref(), op, right.as_ref()) + }) { + // Convert &Box -> Expr + let left = (**left).clone(); + let right = (**right).clone(); + return Ok(Transformed::yes(Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args: vec![left, right], + }))); } - _ => Ok(expr), } + Ok(Transformed::no(expr)) } } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index a0e972fc703c..b7f513727d39 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -15,9 +15,12 @@ // specific language governing permissions and limitations // under the License. +use std::ops::Deref; + use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; + +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::utils::split_conjunction; @@ -25,7 +28,6 @@ use datafusion_expr::{ Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, Window, }; -use std::ops::Deref; /// Do necessary check on subquery expressions and fail the invalid plan /// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions, @@ -146,7 +148,7 @@ fn check_inner_plan( LogicalPlan::Aggregate(_) => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -171,7 +173,7 @@ fn check_inner_plan( check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -188,7 +190,7 @@ fn check_inner_plan( | LogicalPlan::SubqueryAlias(_) => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -206,7 +208,7 @@ fn check_inner_plan( is_aggregate, can_contain_outer_ref, )?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -221,7 +223,7 @@ fn check_inner_plan( JoinType::Full => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, false)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -287,12 +289,11 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { .into_iter() .partition(|e| e.contains_outer()); - correlated - .into_iter() - .for_each(|expr| exprs.push(strip_outer_reference(expr.clone()))); - return Ok(VisitRecursion::Continue); + for expr in correlated { + exprs.push(strip_outer_reference(expr.clone())); + } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(exprs) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index d469e0f8ce0d..08f49ed15b09 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -19,10 +19,11 @@ use std::sync::Arc; -use arrow::datatypes::{DataType, IntervalUnit}; +use crate::analyzer::AnalyzerRule; +use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -50,8 +51,6 @@ use datafusion_expr::{ WindowFrameBound, WindowFrameUnits, }; -use crate::analyzer::AnalyzerRule; - #[derive(Default)] pub struct TypeCoercion {} @@ -126,13 +125,9 @@ pub(crate) struct TypeCoercionRewriter { } impl TreeNodeRewriter for TypeCoercionRewriter { - type N = Expr; - - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) - } + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match expr { Expr::Unnest(_) => internal_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" @@ -142,20 +137,20 @@ impl TreeNodeRewriter for TypeCoercionRewriter { outer_ref_columns, }) => { let new_plan = analyze_internal(&self.schema, &subquery)?; - Ok(Expr::ScalarSubquery(Subquery { + Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, - })) + }))) } Expr::Exists(Exists { subquery, negated }) => { let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - Ok(Expr::Exists(Exists { + Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }, negated, - })) + }))) } Expr::InSubquery(InSubquery { expr, @@ -173,42 +168,34 @@ impl TreeNodeRewriter for TypeCoercionRewriter { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }; - Ok(Expr::InSubquery(InSubquery::new( + Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( Box::new(expr.cast_to(&common_type, &self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, - ))) - } - Expr::Not(expr) => { - let expr = not(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsTrue(expr) => { - let expr = is_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotTrue(expr) => { - let expr = is_not_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsFalse(expr) => { - let expr = is_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotFalse(expr) => { - let expr = - is_not_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsUnknown(expr) => { - let expr = is_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotUnknown(expr) => { - let expr = - is_not_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) + )))) } + Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( + &expr, + &self.schema, + )?))), + Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), Expr::Like(Like { negated, expr, @@ -230,14 +217,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter { })?; let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); - let expr = Expr::Like(Like::new( + Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, pattern, escape_char, case_insensitive, - )); - Ok(expr) + )))) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left_type, right_type) = get_input_types( @@ -245,12 +231,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &op, &right.get_type(&self.schema)?, )?; - - Ok(Expr::BinaryExpr(BinaryExpr::new( + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( Box::new(left.cast_to(&left_type, &self.schema)?), op, Box::new(right.cast_to(&right_type, &self.schema)?), - ))) + )))) } Expr::Between(Between { expr, @@ -280,13 +265,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) })?; - let expr = Expr::Between(Between::new( + Ok(Transformed::yes(Expr::Between(Between::new( Box::new(expr.cast_to(&coercion_type, &self.schema)?), negated, Box::new(low.cast_to(&coercion_type, &self.schema)?), Box::new(high.cast_to(&coercion_type, &self.schema)?), - )); - Ok(expr) + )))) } Expr::InList(InList { expr, @@ -313,18 +297,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter { list_expr.cast_to(&coerced_type, &self.schema) }) .collect::>>()?; - let expr = Expr::InList(InList ::new( + Ok(Transformed::yes(Expr::InList(InList ::new( Box::new(cast_expr), cast_list_expr, negated, - )); - Ok(expr) + )))) } } } Expr::Case(case) => { let case = coerce_case_expression(case, &self.schema)?; - Ok(Expr::Case(case)) + Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { @@ -338,7 +321,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun, )?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + Ok(Transformed::yes(Expr::ScalarFunction(ScalarFunction::new( + fun, new_args, + )))) } ScalarFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -346,7 +331,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + Ok(Transformed::yes(Expr::ScalarFunction( + ScalarFunction::new_udf(fun, new_expr), + ))) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -366,10 +353,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) + Ok(Transformed::yes(Expr::AggregateFunction( + expr::AggregateFunction::new( + fun, new_expr, distinct, filter, order_by, + ), + ))) } AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -377,10 +365,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fun, new_expr, false, filter, order_by, - )); - Ok(expr) + Ok(Transformed::yes(Expr::AggregateFunction( + expr::AggregateFunction::new_udf( + fun, new_expr, false, filter, order_by, + ), + ))) } AggregateFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -409,15 +398,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { _ => args, }; - let expr = Expr::WindowFunction(WindowFunction::new( + Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( fun, args, partition_by, order_by, window_frame, null_treatment, - )); - Ok(expr) + )))) } Expr::Alias(_) | Expr::Column(_) @@ -434,7 +422,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::OuterReferenceColumn(_, _) => Ok(expr), + | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), } } } @@ -764,31 +752,26 @@ mod test { use std::any::Any; use std::sync::{Arc, OnceLock}; - use arrow::array::{FixedSizeListArray, Int32Array}; - use arrow::datatypes::{DataType, TimeUnit}; + use crate::analyzer::type_coercion::{ + cast_expr, coerce_case_expression, TypeCoercion, TypeCoercionRewriter, + }; + use crate::test::assert_analyzed_plan_eq; - use arrow::datatypes::Field; - use datafusion_common::tree_node::TreeNode; + use arrow::array::{FixedSizeListArray, Int32Array}; + use arrow::datatypes::{DataType, Field, TimeUnit}; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; + use datafusion_expr::logical_plan::{EmptyRelation, Projection}; use datafusion_expr::{ - cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, - AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, - SimpleAggregateUDF, Subquery, - }; - use datafusion_expr::{ - lit, - logical_plan::{EmptyRelation, Projection}, - Expr, LogicalPlan, ScalarUDF, Signature, Volatility, + cast, col, concat, concat_ws, create_udaf, is_true, lit, + AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, + BuiltinScalarFunction, Case, ColumnarValue, Expr, ExprSchemable, Filter, + LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + Subquery, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; - use crate::analyzer::type_coercion::{ - cast_expr, coerce_case_expression, TypeCoercion, TypeCoercionRewriter, - }; - use crate::test::assert_analyzed_plan_eq; - fn empty() -> Arc { Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -1289,7 +1272,7 @@ mod test { std::collections::HashMap::new(), )?); let mut rewriter = TypeCoercionRewriter { schema }; - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter).data()?; let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( @@ -1324,7 +1307,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // eq @@ -1335,7 +1318,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // lt @@ -1346,7 +1329,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index ae720bc68998..30c184a28e33 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -25,10 +25,12 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{ - RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + TreeNodeVisitor, }; use datafusion_common::{ - internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, + internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef, + DataFusionError, Result, }; use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{ @@ -642,51 +644,52 @@ impl ExprIdentifierVisitor<'_> { /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it. - fn pop_enter_mark(&mut self) -> (usize, Identifier) { + fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> { let mut desc = String::new(); while let Some(item) = self.visit_stack.pop() { match item { VisitRecord::EnterMark(idx) => { - return (idx, desc); + return Some((idx, desc)); } VisitRecord::ExprItem(s) => { desc.push_str(&s); } } } - - unreachable!("Enter mark should paired with node number"); + None } } impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: &Expr) -> Result { // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(VisitRecursion::Skip); + return Ok(TreeNodeRecursion::Jump); } self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; // put placeholder self.id_array.push((0, "".to_string())); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, expr: &Expr) -> Result { + fn f_up(&mut self, expr: &Expr) -> Result { self.series_number += 1; - let (idx, sub_expr_desc) = self.pop_enter_mark(); + let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else { + return Ok(TreeNodeRecursion::Continue); + }; // skip exprs should not be recognize. if self.expr_mask.ignores(expr) { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } let mut desc = Self::desc_expr(expr); desc.push_str(&sub_expr_desc); @@ -700,7 +703,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) .1 += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -743,74 +746,83 @@ struct CommonSubexprRewriter<'a> { } impl TreeNodeRewriter for CommonSubexprRewriter<'_> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result> { // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate // the `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. - if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(RewriteRecursion::Stop); + if expr.short_circuits() || is_volatile_expression(&expr)? { + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { - return Ok(RewriteRecursion::Stop); + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } let curr_id = &self.id_array[self.curr_index].1; // skip `Expr`s without identifier (empty identifier). if curr_id.is_empty() { self.curr_index += 1; - return Ok(RewriteRecursion::Skip); + return Ok(Transformed::no(expr)); } match self.expr_set.get(curr_id) { Some((_, counter, _)) => { if *counter > 1 { self.affected_id.insert(curr_id.clone()); - Ok(RewriteRecursion::Mutate) + + // This expr tree is finished. + if self.curr_index >= self.id_array.len() { + return Ok(Transformed::new( + expr, + false, + TreeNodeRecursion::Jump, + )); + } + + let (series_number, id) = &self.id_array[self.curr_index]; + self.curr_index += 1; + // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. + let expr_set_item = self.expr_set.get(id).ok_or_else(|| { + internal_datafusion_err!("expr_set invalid state") + })?; + if *series_number < self.max_series_number + || id.is_empty() + || expr_set_item.1 <= 1 + { + return Ok(Transformed::new( + expr, + false, + TreeNodeRecursion::Jump, + )); + } + + self.max_series_number = *series_number; + // step index to skip all sub-node (which has smaller series number). + while self.curr_index < self.id_array.len() + && *series_number > self.id_array[self.curr_index].0 + { + self.curr_index += 1; + } + + let expr_name = expr.display_name()?; + // Alias this `Column` expr to it original "expr name", + // `projection_push_down` optimizer use "expr name" to eliminate useless + // projections. + Ok(Transformed::new( + col(id).alias(expr_name), + true, + TreeNodeRecursion::Jump, + )) } else { self.curr_index += 1; - Ok(RewriteRecursion::Skip) + Ok(Transformed::no(expr)) } } _ => internal_err!("expr_set invalid state"), } } - - fn mutate(&mut self, expr: Expr) -> Result { - // This expr tree is finished. - if self.curr_index >= self.id_array.len() { - return Ok(expr); - } - - let (series_number, id) = &self.id_array[self.curr_index]; - self.curr_index += 1; - // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. - let expr_set_item = self.expr_set.get(id).ok_or_else(|| { - DataFusionError::Internal("expr_set invalid state".to_string()) - })?; - if *series_number < self.max_series_number - || id.is_empty() - || expr_set_item.1 <= 1 - { - return Ok(expr); - } - - self.max_series_number = *series_number; - // step index to skip all sub-node (which has smaller series number). - while self.curr_index < self.id_array.len() - && *series_number > self.id_array[self.curr_index].0 - { - self.curr_index += 1; - } - - let expr_name = expr.display_name()?; - // Alias this `Column` expr to it original "expr name", - // `projection_push_down` optimizer use "expr name" to eliminate useless - // projections. - Ok(col(id).alias(expr_name)) - } } fn replace_common_expr( @@ -826,6 +838,7 @@ fn replace_common_expr( max_series_number: 0, curr_index: 0, }) + .data() } #[cfg(test)] diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 0f4b39d9eee3..fd548ba4948e 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -15,19 +15,20 @@ // specific language governing permissions and limitations // under the License. +use std::collections::{BTreeSet, HashMap}; +use std::ops::Deref; + use crate::simplify_expressions::{ExprSimplifier, SimplifyContext}; use crate::utils::collect_subquery_cols; + use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{plan_err, Result}; -use datafusion_common::{Column, DFSchemaRef, ScalarValue}; +use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; -use std::collections::{BTreeSet, HashMap}; -use std::ops::Deref; /// This struct rewrite the sub query plan by pull up the correlated expressions(contains outer reference columns) from the inner subquery's 'Filter'. /// It adds the inner reference columns to the 'Projection' or 'Aggregate' of the subquery if they are missing, so that they can be evaluated by the parent operator as the join condition. @@ -56,19 +57,19 @@ pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true"; pub type ExprResultMap = HashMap; impl TreeNodeRewriter for PullUpCorrelatedExpr { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: LogicalPlan) -> Result> { match plan { - LogicalPlan::Filter(_) => Ok(RewriteRecursion::Continue), + LogicalPlan::Filter(_) => Ok(Transformed::no(plan)), LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); if plan_hold_outer { // the unsupported case self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) } else { - Ok(RewriteRecursion::Continue) + Ok(Transformed::no(plan)) } } LogicalPlan::Limit(_) => { @@ -77,21 +78,21 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { (false, true) => { // the unsupported case self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) } - _ => Ok(RewriteRecursion::Continue), + _ => Ok(Transformed::no(plan)), } } _ if plan.expressions().iter().any(|expr| expr.contains_outer()) => { // the unsupported cases, the plan expressions contain out reference columns(like window expressions) self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) } - _ => Ok(RewriteRecursion::Continue), + _ => Ok(Transformed::no(plan)), } } - fn mutate(&mut self, plan: LogicalPlan) -> Result { + fn f_up(&mut self, plan: LogicalPlan) -> Result> { let subquery_schema = plan.schema().clone(); match &plan { LogicalPlan::Filter(plan_filter) => { @@ -140,7 +141,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { .build()?; self.correlated_subquery_cols_map .insert(new_plan.clone(), correlated_subquery_cols); - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } (None, _) => { // if the subquery still has filter expressions, restore them. @@ -152,7 +153,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = plan.build()?; self.correlated_subquery_cols_map .insert(new_plan.clone(), correlated_subquery_cols); - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } } } @@ -196,7 +197,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(new_plan.clone(), expr_result_map_for_count_bug); } - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } LogicalPlan::Aggregate(aggregate) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => @@ -240,7 +241,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(new_plan.clone(), expr_result_map_for_count_bug); } - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } LogicalPlan::SubqueryAlias(alias) => { let mut local_correlated_cols = BTreeSet::new(); @@ -262,7 +263,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(plan.clone(), input_map.clone()); } - Ok(plan) + Ok(Transformed::no(plan)) } LogicalPlan::Limit(limit) => { let input_expr_map = self @@ -273,7 +274,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) { // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) - (true, false) => { + (true, false) => Transformed::yes( if limit.fetch.filter(|limit_row| *limit_row == 0).is_some() { LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -281,17 +282,17 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { }) } else { LogicalPlanBuilder::from((*limit.input).clone()).build()? - } - } - _ => plan, + }, + ), + _ => Transformed::no(plan), }; if let Some(input_map) = input_expr_map { self.collected_count_expr_map - .insert(new_plan.clone(), input_map); + .insert(new_plan.data.clone(), input_map); } Ok(new_plan) } - _ => Ok(plan), + _ => Ok(Transformed::no(plan)), } } } @@ -370,31 +371,34 @@ fn agg_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for e in agg_expr.iter() { - let result_expr = e.clone().transform_up(&|expr| { - let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { - match func_def { + let result_expr = e + .clone() + .transform_up(&|expr| { + let new_expr = match expr { + Expr::AggregateFunction(expr::AggregateFunction { + func_def, .. + }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( + Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( 0, )))) } else { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } } AggregateFunctionDefinition::UDF { .. } => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } AggregateFunctionDefinition::Name(_) => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } - } - } - _ => Transformed::No(expr), - }; - Ok(new_expr) - })?; + }, + _ => Transformed::no(expr), + }; + Ok(new_expr) + }) + .data()?; let result_expr = result_expr.unalias(); let props = ExecutionProps::new(); @@ -415,17 +419,23 @@ fn proj_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for expr in proj_expr.iter() { - let result_expr = expr.clone().transform_up(&|expr| { - if let Expr::Column(Column { name, .. }) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { - Ok(Transformed::Yes(result_expr.clone())) + let result_expr = expr + .clone() + .transform_up(&|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = + input_expr_result_map_for_count_bug.get(name) + { + Ok(Transformed::yes(result_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::No(expr)) - } - })?; + }) + .data()?; + if result_expr.ne(expr) { let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema.clone()); @@ -448,17 +458,21 @@ fn filter_exprs_evaluation_result_on_empty_batch( input_expr_result_map_for_count_bug: &ExprResultMap, expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result> { - let result_expr = filter_expr.clone().transform_up(&|expr| { - if let Expr::Column(Column { name, .. }) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { - Ok(Transformed::Yes(result_expr.clone())) + let result_expr = filter_expr + .clone() + .transform_up(&|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { + Ok(Transformed::yes(result_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::No(expr)) - } - })?; + }) + .data()?; + let pull_up_expr = if result_expr.ne(filter_expr) { let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema); diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index a9e1f1228e5e..b94cf37c5c12 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -15,12 +15,17 @@ // specific language governing permissions and limitations // under the License. +use std::collections::BTreeSet; +use std::ops::Deref; +use std::sync::Arc; + use crate::decorrelate::PullUpCorrelatedExpr; use crate::optimizer::ApplyOrder; use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::alias::AliasGenerator; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{plan_err, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; @@ -30,10 +35,8 @@ use datafusion_expr::{ exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, }; + use log::debug; -use std::collections::BTreeSet; -use std::ops::Deref; -use std::sync::Arc; /// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins #[derive(Default)] @@ -228,7 +231,7 @@ fn build_join( collected_count_expr_map: Default::default(), pull_up_having_expr: None, }; - let new_plan = subquery.clone().rewrite(&mut pull_up)?; + let new_plan = subquery.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); } @@ -321,8 +324,11 @@ impl SubqueryInfo { #[cfg(test)] mod tests { + use std::ops::Add; + use super::*; use crate::test::*; + use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ @@ -330,7 +336,6 @@ mod tests { logical_plan::LogicalPlanBuilder, not_exists, not_in_subquery, or, out_ref_col, Operator, }; - use std::ops::Add; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 07f495a7262d..4143d52a053e 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, num::NonZeroUsize, }; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_expr::LogicalPlan; /// Non-unique identifier of a [`LogicalPlan`]. @@ -75,7 +75,7 @@ fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize { let mut node_number = 0; plan.apply(&mut |_plan| { node_number += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // Closure always return Ok .unwrap(); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 40156d43c572..a63133c5166f 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -22,7 +22,9 @@ use crate::optimizer::ApplyOrder; use crate::utils::is_volatile_expression; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DFSchemaRef, JoinConstraint, Result, @@ -222,7 +224,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::Column(_) | Expr::Literal(_) | Expr::Placeholder(_) - | Expr::ScalarVariable(_, _) => Ok(VisitRecursion::Skip), + | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. } | Expr::InSubquery(_) | Expr::ScalarSubquery(_) @@ -233,7 +235,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { .. }) => { is_evaluate = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } Expr::Alias(_) | Expr::BinaryExpr(_) @@ -255,7 +257,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::ScalarFunction(..) - | Expr::InList { .. } => Ok(VisitRecursion::Continue), + | Expr::InList { .. } => Ok(TreeNodeRecursion::Continue), Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) @@ -992,13 +994,14 @@ pub fn replace_cols_by_name( e.transform_up(&|expr| { Ok(if let Expr::Column(c) = &expr { match replace_map.get(&c.flat_name()) { - Some(new_c) => Transformed::Yes(new_c.clone()), - None => Transformed::No(expr), + Some(new_c) => Transformed::yes(new_c.clone()), + None => Transformed::no(expr), } } else { - Transformed::No(expr) + Transformed::no(expr) }) }) + .data() } /// check whether the expression uses the columns in `check_map`. @@ -1009,12 +1012,12 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { match check_map.get(&c.flat_name()) { Some(_) => { is_contain = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } - None => VisitRecursion::Continue, + None => TreeNodeRecursion::Continue, } } else { - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 9aa08c37fa35..8acc36e479ca 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -15,21 +15,23 @@ // specific language governing permissions and limitations // under the License. +use std::collections::{BTreeSet, HashMap}; +use std::sync::Arc; + use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; use crate::optimizer::ApplyOrder; use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; -use std::collections::{BTreeSet, HashMap}; -use std::sync::Arc; /// Optimizer rule for rewriting subquery filters to joins #[derive(Default)] @@ -56,8 +58,11 @@ impl ScalarSubqueryToJoin { sub_query_info: vec![], alias_gen, }; - let new_expr = predicate.clone().rewrite(&mut extract)?; - Ok((extract.sub_query_info, new_expr)) + predicate + .clone() + .rewrite(&mut extract) + .data() + .map(|new_expr| (extract.sub_query_info, new_expr)) } } @@ -86,20 +91,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { build_join(&subquery, &cur_input, &alias)? { if !expr_check_map.is_empty() { - rewrite_expr = - rewrite_expr.clone().transform_up(&|expr| { + rewrite_expr = rewrite_expr + .clone() + .transform_up(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) { - Ok(Transformed::Yes(map_expr.clone())) + Ok(Transformed::yes(map_expr.clone())) } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - })?; + }) + .data()?; } cur_input = optimized_subquery; } else { @@ -141,20 +148,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { if let Some(rewrite_expr) = expr_to_rewrite_expr_map.get(expr) { - let new_expr = - rewrite_expr.clone().transform_up(&|expr| { + let new_expr = rewrite_expr + .clone() + .transform_up(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) { - Ok(Transformed::Yes(map_expr.clone())) + Ok(Transformed::yes(map_expr.clone())) } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - })?; + }) + .data()?; expr_to_rewrite_expr_map.insert(expr, new_expr); } } @@ -201,16 +210,9 @@ struct ExtractScalarSubQuery { } impl TreeNodeRewriter for ExtractScalarSubQuery { - type N = Expr; - - fn pre_visit(&mut self, expr: &Expr) -> Result { - match expr { - Expr::ScalarSubquery(_) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), - } - } + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result> { match expr { Expr::ScalarSubquery(subquery) => { let subqry_alias = self.alias_gen.next("__scalar_sq"); @@ -220,12 +222,16 @@ impl TreeNodeRewriter for ExtractScalarSubQuery { .subquery .head_output_expr()? .map_or(plan_err!("single expression required."), Ok)?; - Ok(Expr::Column(create_col_from_scalar_expr( - &scalar_expr, - subqry_alias, - )?)) + Ok(Transformed::new( + Expr::Column(create_col_from_scalar_expr( + &scalar_expr, + subqry_alias, + )?), + true, + TreeNodeRecursion::Jump, + )) } - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } } @@ -282,7 +288,7 @@ fn build_join( collected_count_expr_map: Default::default(), pull_up_having_expr: None, }; - let new_plan = subquery_plan.clone().rewrite(&mut pull_up)?; + let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); } @@ -371,15 +377,17 @@ fn build_join( #[cfg(test)] mod tests { + use std::ops::Add; + use super::*; use crate::test::*; + use arrow::datatypes::DataType; use datafusion_common::Result; + use datafusion_expr::logical_plan::LogicalPlanBuilder; use datafusion_expr::{ - col, lit, logical_plan::LogicalPlanBuilder, max, min, out_ref_col, - scalar_subquery, sum, Between, + col, lit, max, min, out_ref_col, scalar_subquery, sum, Between, }; - use std::ops::Add; /// Test multiple correlated subqueries #[test] diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 51647188fd93..6b5dd1b4681e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -19,15 +19,21 @@ use std::ops::Not; +use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier}; +use super::utils::*; +use crate::analyzer::type_coercion::TypeCoercionRewriter; +use crate::simplify_expressions::guarantees::GuaranteeRewriter; +use crate::simplify_expressions::regex::simplify_regex_expr; +use crate::simplify_expressions::SimplifyInfo; + use arrow::{ array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; - use datafusion_common::{ cast::{as_large_list_array, as_list_array}, - tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -39,14 +45,6 @@ use datafusion_expr::{ use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; -use crate::analyzer::type_coercion::TypeCoercionRewriter; -use crate::simplify_expressions::guarantees::GuaranteeRewriter; -use crate::simplify_expressions::regex::simplify_regex_expr; -use crate::simplify_expressions::SimplifyInfo; - -use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier}; -use super::utils::*; - /// This structure handles API for expression simplification pub struct ExprSimplifier { info: S, @@ -131,31 +129,36 @@ impl ExprSimplifier { /// let expr = simplifier.simplify(expr).unwrap(); /// assert_eq!(expr, b_lt_2); /// ``` - pub fn simplify(&self, expr: Expr) -> Result { + pub fn simplify(&self, mut expr: Expr) -> Result { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); let mut inlist_simplifier = InListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); - let expr = if self.canonicalize { - expr.rewrite(&mut Canonicalizer::new())? - } else { - expr - }; + if self.canonicalize { + expr = expr.rewrite(&mut Canonicalizer::new()).data()? + } // TODO iterate until no changes are made during rewrite // (evaluating constants can enable new simplifications and // simplifications can enable new constant evaluation) // https://github.com/apache/arrow-datafusion/issues/1160 - expr.rewrite(&mut const_evaluator)? - .rewrite(&mut simplifier)? - .rewrite(&mut inlist_simplifier)? - .rewrite(&mut shorten_in_list_simplifier)? - .rewrite(&mut guarantee_rewriter)? + expr.rewrite(&mut const_evaluator) + .data()? + .rewrite(&mut simplifier) + .data()? + .rewrite(&mut inlist_simplifier) + .data()? + .rewrite(&mut shorten_in_list_simplifier) + .data()? + .rewrite(&mut guarantee_rewriter) + .data()? // run both passes twice to try an minimize simplifications that we missed - .rewrite(&mut const_evaluator)? + .rewrite(&mut const_evaluator) + .data()? .rewrite(&mut simplifier) + .data() } /// Apply type coercion to an [`Expr`] so that it can be @@ -171,7 +174,7 @@ impl ExprSimplifier { pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite) + expr.rewrite(&mut expr_rewrite).data() } /// Input guarantees about the values of columns. @@ -303,32 +306,36 @@ impl Canonicalizer { } impl TreeNodeRewriter for Canonicalizer { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else { - return Ok(expr); + return Ok(Transformed::no(expr)); }; match (left.as_ref(), right.as_ref(), op.swap()) { // (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op)) if right_col > left_col => { - Ok(Expr::BinaryExpr(BinaryExpr { + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, right: left, - })) + }))) } // (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => { - Ok(Expr::BinaryExpr(BinaryExpr { + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, right: left, - })) + }))) } - _ => Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })), + _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))), } } } @@ -367,9 +374,9 @@ enum ConstSimplifyResult { } impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result> { // Default to being able to evaluate this node self.can_evaluate.push(true); @@ -377,7 +384,7 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { // stack as not ok (as all parents have at least one child or // descendant that can not be evaluated - if !Self::can_evaluate(expr) { + if !Self::can_evaluate(&expr) { // walk back up stack, marking first parent that is not mutable let parent_iter = self.can_evaluate.iter_mut().rev(); for p in parent_iter { @@ -393,10 +400,10 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { // NB: do not short circuit recursion even if we find a non // evaluatable node (so we can fold other children, args to // functions, etc) - Ok(RewriteRecursion::Continue) + Ok(Transformed::no(expr)) } - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match self.can_evaluate.pop() { // Certain expressions such as `CASE` and `COALESCE` are short circuiting // and may not evalute all their sub expressions. Thus if @@ -405,11 +412,15 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { Some(true) => { let result = self.evaluate_to_scalar(expr); match result { - ConstSimplifyResult::Simplified(s) => Ok(Expr::Literal(s)), - ConstSimplifyResult::SimplifyRuntimeError(_, expr) => Ok(expr), + ConstSimplifyResult::Simplified(s) => { + Ok(Transformed::yes(Expr::Literal(s))) + } + ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { + Ok(Transformed::yes(expr)) + } } } - Some(false) => Ok(expr), + Some(false) => Ok(Transformed::no(expr)), _ => internal_err!("Failed to pop can_evaluate"), } } @@ -566,10 +577,10 @@ impl<'a, S> Simplifier<'a, S> { } impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { - type N = Expr; + type Node = Expr; /// rewrite the expression simplifying any constant expressions - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { use datafusion_expr::Operator::{ And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor, Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch, @@ -577,7 +588,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { }; let info = self.info; - let new_expr = match expr { + Ok(match expr { // // Rules for Eq // @@ -590,11 +601,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Eq, right, }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { - match as_bool_lit(*left)? { + Transformed::yes(match as_bool_lit(*left)? { Some(true) => *right, Some(false) => Expr::Not(right), None => lit_bool_null(), - } + }) } // A = true --> A // A = false --> !A @@ -604,11 +615,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Eq, right, }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { - match as_bool_lit(*right)? { + Transformed::yes(match as_bool_lit(*right)? { Some(true) => *left, Some(false) => Expr::Not(left), None => lit_bool_null(), - } + }) } // Rules for NotEq // @@ -621,11 +632,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: NotEq, right, }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { - match as_bool_lit(*left)? { + Transformed::yes(match as_bool_lit(*left)? { Some(true) => Expr::Not(right), Some(false) => *right, None => lit_bool_null(), - } + }) } // A != true --> !A // A != false --> A @@ -635,11 +646,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: NotEq, right, }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { - match as_bool_lit(*right)? { + Transformed::yes(match as_bool_lit(*right)? { Some(true) => Expr::Not(left), Some(false) => *left, None => lit_bool_null(), - } + }) } // @@ -651,32 +662,32 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Or, right: _, - }) if is_true(&left) => *left, + }) if is_true(&left) => Transformed::yes(*left), // false OR A --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if is_false(&left) => *right, + }) if is_false(&left) => Transformed::yes(*right), // A OR true --> true (even if A is null) Expr::BinaryExpr(BinaryExpr { left: _, op: Or, right, - }) if is_true(&right) => *right, + }) if is_true(&right) => Transformed::yes(*right), // A OR false --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if is_false(&right) => *left, + }) if is_false(&right) => Transformed::yes(*left), // A OR !A ---> true (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, }) if is_not_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::Boolean(Some(true))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true)))) } // !A OR A ---> true (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -684,32 +695,36 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Or, right, }) if is_not_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::Boolean(Some(true))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true)))) } // (..A..) OR A --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if expr_contains(&left, &right, Or) => *left, + }) if expr_contains(&left, &right, Or) => Transformed::yes(*left), // A OR (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if expr_contains(&right, &left, Or) => *right, + }) if expr_contains(&right, &left, Or) => Transformed::yes(*right), // A OR (A AND B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => *left, + }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => { + Transformed::yes(*left) + } // (A AND B) OR A --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&left)? && is_op_with(And, &left, &right) => *right, + }) if !info.nullable(&left)? && is_op_with(And, &left, &right) => { + Transformed::yes(*right) + } // // Rules for AND @@ -720,32 +735,32 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: And, right, - }) if is_true(&left) => *right, + }) if is_true(&left) => Transformed::yes(*right), // false AND A --> false (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: And, right: _, - }) if is_false(&left) => *left, + }) if is_false(&left) => Transformed::yes(*left), // A AND true --> A Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if is_true(&right) => *left, + }) if is_true(&right) => Transformed::yes(*left), // A AND false --> false (even if A is null) Expr::BinaryExpr(BinaryExpr { left: _, op: And, right, - }) if is_false(&right) => *right, + }) if is_false(&right) => Transformed::yes(*right), // A AND !A ---> false (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: And, right, }) if is_not_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::Boolean(Some(false))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false)))) } // !A AND A ---> false (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -753,32 +768,36 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: And, right, }) if is_not_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::Boolean(Some(false))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false)))) } // (..A..) AND A --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if expr_contains(&left, &right, And) => *left, + }) if expr_contains(&left, &right, And) => Transformed::yes(*left), // A AND (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if expr_contains(&right, &left, And) => *right, + }) if expr_contains(&right, &left, And) => Transformed::yes(*right), // A AND (A OR B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => *left, + }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => { + Transformed::yes(*left) + } // (A OR B) AND A --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&left)? && is_op_with(Or, &left, &right) => *right, + }) if !info.nullable(&left)? && is_op_with(Or, &left, &right) => { + Transformed::yes(*right) + } // // Rules for Multiply @@ -789,25 +808,25 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Multiply, right, - }) if is_one(&right) => *left, + }) if is_one(&right) => Transformed::yes(*left), // 1 * A --> A Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right, - }) if is_one(&left) => *right, + }) if is_one(&left) => Transformed::yes(*right), // A * null --> null Expr::BinaryExpr(BinaryExpr { left: _, op: Multiply, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null * A --> null Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN) Expr::BinaryExpr(BinaryExpr { @@ -818,7 +837,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_zero(&right) => { - *right + Transformed::yes(*right) } // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN) Expr::BinaryExpr(BinaryExpr { @@ -829,7 +848,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&right)?.is_floating() && is_zero(&left) => { - *left + Transformed::yes(*left) } // @@ -841,19 +860,19 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Divide, right, - }) if is_one(&right) => *left, + }) if is_one(&right) => Transformed::yes(*left), // null / A --> null Expr::BinaryExpr(BinaryExpr { left, op: Divide, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A / null --> null Expr::BinaryExpr(BinaryExpr { left: _, op: Divide, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // // Rules for Modulo @@ -864,13 +883,13 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: Modulo, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null % A --> null Expr::BinaryExpr(BinaryExpr { left, op: Modulo, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN) Expr::BinaryExpr(BinaryExpr { left, @@ -880,7 +899,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { - lit(0) + Transformed::yes(lit(0)) } // @@ -892,28 +911,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseAnd, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null & A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A & 0 -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if !info.nullable(&left)? && is_zero(&right) => *right, + }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right), // 0 & A -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if !info.nullable(&right)? && is_zero(&left) => *left, + }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left), // !A & A -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -921,7 +940,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?, + )?)) } // A & !A -> 0 (if A not nullable) @@ -930,7 +951,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?, + )?)) } // (..A..) & A --> (..A..) @@ -938,14 +961,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, - }) if expr_contains(&left, &right, BitwiseAnd) => *left, + }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left), // A & (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if expr_contains(&right, &left, BitwiseAnd) => *right, + }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right), // A & (A | B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { @@ -953,7 +976,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => { - *left + Transformed::yes(*left) } // (A | B) & A --> A (if B not null) @@ -962,7 +985,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => { - *right + Transformed::yes(*right) } // @@ -974,28 +997,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseOr, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null | A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A | 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // 0 | A -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if is_zero(&left) => *right, + }) if is_zero(&left) => Transformed::yes(*right), // !A | A -> -1 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -1003,7 +1026,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // A | !A -> -1 (if A not nullable) @@ -1012,7 +1037,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // (..A..) | A --> (..A..) @@ -1020,14 +1047,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, - }) if expr_contains(&left, &right, BitwiseOr) => *left, + }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left), // A | (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if expr_contains(&right, &left, BitwiseOr) => *right, + }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right), // A | (A & B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { @@ -1035,7 +1062,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => { - *left + Transformed::yes(*left) } // (A & B) | A --> A (if B not null) @@ -1044,7 +1071,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => { - *right + Transformed::yes(*right) } // @@ -1056,28 +1083,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseXor, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null ^ A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A ^ 0 -> A (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right, - }) if !info.nullable(&left)? && is_zero(&right) => *left, + }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left), // 0 ^ A -> A (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right, - }) if !info.nullable(&right)? && is_zero(&left) => *right, + }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right), // !A ^ A -> -1 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -1085,7 +1112,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // A ^ !A -> -1 (if A not nullable) @@ -1094,7 +1123,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) @@ -1104,11 +1135,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { right, }) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); - if expr == *right { + Transformed::yes(if expr == *right { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) } else { expr - } + }) } // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A) @@ -1118,11 +1149,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { right, }) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); - if expr == *left { + Transformed::yes(if expr == *left { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) } else { expr - } + }) } // @@ -1134,21 +1165,21 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseShiftRight, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null >> A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftRight, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A >> 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftRight, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // // Rules for BitwiseShiftRight @@ -1159,31 +1190,31 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseShiftLeft, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null << A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftLeft, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A << 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftLeft, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // // Rules for Not // - Expr::Not(inner) => negate_clause(*inner), + Expr::Not(inner) => Transformed::yes(negate_clause(*inner)), // // Rules for Negative // - Expr::Negative(inner) => distribute_negation(*inner), + Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)), // // Rules for Case @@ -1237,19 +1268,19 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), args, - }) => simpl_log(args, info)?, + }) => Transformed::yes(simpl_log(args, info)?), // power Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), args, - }) => simpl_power(args, info)?, + }) => Transformed::yes(simpl_power(args, info)?), // concat Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), args, - }) => simpl_concat(args)?, + }) => Transformed::yes(simpl_concat(args)?), // concat_ws Expr::ScalarFunction(ScalarFunction { @@ -1259,11 +1290,13 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ), args, }) => match &args[..] { - [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, - _ => Expr::ScalarFunction(ScalarFunction::new( + [delimiter, vals @ ..] => { + Transformed::yes(simpl_concat_ws(delimiter, vals)?) + } + _ => Transformed::yes(Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::ConcatWithSeparator, args, - )), + ))), }, // @@ -1272,18 +1305,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // a between 3 and 5 --> a >= 3 AND a <=5 // a not between 3 and 5 --> a < 3 OR a > 5 - Expr::Between(between) => { - if between.negated { - let l = *between.expr.clone(); - let r = *between.expr; - or(l.lt(*between.low), r.gt(*between.high)) - } else { - and( - between.expr.clone().gt_eq(*between.low), - between.expr.lt_eq(*between.high), - ) - } - } + Expr::Between(between) => Transformed::yes(if between.negated { + let l = *between.expr.clone(); + let r = *between.expr; + or(l.lt(*between.low), r.gt(*between.high)) + } else { + and( + between.expr.clone().gt_eq(*between.low), + between.expr.lt_eq(*between.high), + ) + }), // // Rules for regexes @@ -1292,7 +1323,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), right, - }) => simplify_regex_expr(left, op, right)?, + }) => Transformed::yes(simplify_regex_expr(left, op, right)?), // Rules for Like Expr::Like(Like { @@ -1307,25 +1338,24 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::Literal(ScalarValue::Utf8(Some(pattern_str))) if pattern_str == "%" ) => { - lit(!negated) + Transformed::yes(lit(!negated)) } // a is not null/unknown --> true (if a is not nullable) Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr) if !info.nullable(&expr)? => { - lit(true) + Transformed::yes(lit(true)) } // a is null/unknown --> false (if a is not nullable) Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => { - lit(false) + Transformed::yes(lit(false)) } // no additional rewrites possible - expr => expr, - }; - Ok(new_expr) + expr => Transformed::no(expr), + }) } } @@ -1337,16 +1367,15 @@ mod tests { sync::Arc, }; + use super::*; + use crate::simplify_expressions::SimplifyContext; + use crate::test::test_table_scan_with_name; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{assert_contains, DFField, ToDFSchema}; use datafusion_expr::{interval_arithmetic::Interval, *}; use datafusion_physical_expr::execution_props::ExecutionProps; - use crate::simplify_expressions::SimplifyContext; - use crate::test::test_table_scan_with_name; - - use super::*; - // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index aa7bb4f78a93..6eb583257dcb 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -21,7 +21,8 @@ use std::{borrow::Cow, collections::HashMap}; -use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; @@ -57,23 +58,25 @@ impl<'a> GuaranteeRewriter<'a> { } impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { if self.guarantees.is_empty() { - return Ok(expr); + return Ok(Transformed::no(expr)); } match &expr { Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(lit(true)), - Some(NullableInterval::NotNull { .. }) => Ok(lit(false)), - _ => Ok(expr), + Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))), + Some(NullableInterval::NotNull { .. }) => { + Ok(Transformed::yes(lit(false))) + } + _ => Ok(Transformed::no(expr)), }, Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(lit(false)), - Some(NullableInterval::NotNull { .. }) => Ok(lit(true)), - _ => Ok(expr), + Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(false))), + Some(NullableInterval::NotNull { .. }) => Ok(Transformed::yes(lit(true))), + _ => Ok(Transformed::no(expr)), }, Expr::Between(Between { expr: inner, @@ -93,14 +96,14 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { let contains = expr_interval.contains(*interval)?; if contains.is_certainly_true() { - Ok(lit(!negated)) + Ok(Transformed::yes(lit(!negated))) } else if contains.is_certainly_false() { - Ok(lit(*negated)) + Ok(Transformed::yes(lit(*negated))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } else { - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -135,23 +138,23 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { let result = left_interval.apply_operator(op, right_interval.as_ref())?; if result.is_certainly_true() { - Ok(lit(true)) + Ok(Transformed::yes(lit(true))) } else if result.is_certainly_false() { - Ok(lit(false)) + Ok(Transformed::yes(lit(false))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } // Columns (if interval is collapsed to a single value) Expr::Column(_) => { if let Some(interval) = self.guarantees.get(&expr) { - Ok(interval.single_value().map_or(expr, lit)) + Ok(Transformed::yes(interval.single_value().map_or(expr, lit))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -181,17 +184,17 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { }) .collect::>()?; - Ok(Expr::InList(InList { + Ok(Transformed::yes(Expr::InList(InList { expr: inner.clone(), list: new_list, negated: *negated, - })) + }))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } } @@ -201,7 +204,8 @@ mod tests { use super::*; use arrow::datatypes::DataType; - use datafusion_common::{tree_node::TreeNode, ScalarValue}; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; + use datafusion_common::ScalarValue; use datafusion_expr::{col, lit, Operator}; #[test] @@ -221,12 +225,12 @@ mod tests { // x IS NULL => guaranteed false let expr = col("x").is_null(); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); assert_eq!(output, lit(false)); // x IS NOT NULL => guaranteed true let expr = col("x").is_not_null(); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); assert_eq!(output, lit(true)); } @@ -236,7 +240,7 @@ mod tests { T: Clone, { for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).unwrap(); + let output = expr.clone().rewrite(rewriter).data().unwrap(); let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, @@ -248,7 +252,7 @@ mod tests { fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { for expr in cases { - let output = expr.clone().rewrite(rewriter).unwrap(); + let output = expr.clone().rewrite(rewriter).data().unwrap(); assert_eq!( &output, expr, "{} was simplified to {}, but expected it to be unchanged", @@ -478,7 +482,7 @@ mod tests { let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - let output = col("x").rewrite(&mut rewriter).unwrap(); + let output = col("x").rewrite(&mut rewriter).data().unwrap(); assert_eq!(output, Expr::Literal(scalar.clone())); } } @@ -522,7 +526,7 @@ mod tests { .collect(), *negated, ); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); let expected_list = expected_list .iter() .map(|v| lit(ScalarValue::Int32(Some(*v)))) diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 710c24f66e33..fa1d7cfc1239 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -17,17 +17,17 @@ //! This module implements a rule that simplifies the values for `InList`s +use super::utils::{is_null, lit_bool_null}; +use super::THRESHOLD_INLINE_INLIST; + use std::borrow::Cow; use std::collections::HashSet; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::expr::{InList, InSubquery}; use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; -use super::utils::{is_null, lit_bool_null}; -use super::THRESHOLD_INLINE_INLIST; - pub(super) struct ShortenInListSimplifier {} impl ShortenInListSimplifier { @@ -37,9 +37,9 @@ impl ShortenInListSimplifier { } impl TreeNodeRewriter for ShortenInListSimplifier { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { // if expr is a single column reference: // expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR (expr = C) if let Expr::InList(InList { @@ -61,7 +61,7 @@ impl TreeNodeRewriter for ShortenInListSimplifier { { let first_val = list[0].clone(); if negated { - return Ok(list.into_iter().skip(1).fold( + return Ok(Transformed::yes(list.into_iter().skip(1).fold( (*expr.clone()).not_eq(first_val), |acc, y| { // Note that `A and B and C and D` is a left-deep tree structure @@ -83,20 +83,20 @@ impl TreeNodeRewriter for ShortenInListSimplifier { // The code below maintain the left-deep tree structure. acc.and((*expr.clone()).not_eq(y)) }, - )); + ))); } else { - return Ok(list.into_iter().skip(1).fold( + return Ok(Transformed::yes(list.into_iter().skip(1).fold( (*expr.clone()).eq(first_val), |acc, y| { // Same reasoning as above acc.or((*expr.clone()).eq(y)) }, - )); + ))); } } } - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -109,9 +109,9 @@ impl InListSimplifier { } impl TreeNodeRewriter for InListSimplifier { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { if let Expr::InList(InList { expr, mut list, @@ -121,11 +121,11 @@ impl TreeNodeRewriter for InListSimplifier { // expr IN () --> false // expr NOT IN () --> true if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) { - return Ok(lit(negated)); + return Ok(Transformed::yes(lit(negated))); // null in (x, y, z) --> null // null not in (x, y, z) --> null } else if is_null(&expr) { - return Ok(lit_bool_null()); + return Ok(Transformed::yes(lit_bool_null())); // expr IN ((subquery)) -> expr IN (subquery), see ##5529 } else if list.len() == 1 && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) @@ -133,7 +133,9 @@ impl TreeNodeRewriter for InListSimplifier { let Expr::ScalarSubquery(subquery) = list.remove(0) else { unreachable!() }; - return Ok(Expr::InSubquery(InSubquery::new(expr, subquery, negated))); + return Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( + expr, subquery, negated, + )))); } } // Combine multiple OR expressions into a single IN list expression if possible @@ -165,7 +167,7 @@ impl TreeNodeRewriter for InListSimplifier { list, negated: false, }; - return Ok(Expr::InList(merged_inlist)); + return Ok(Transformed::yes(Expr::InList(merged_inlist))); } } } @@ -191,40 +193,40 @@ impl TreeNodeRewriter for InListSimplifier { (Expr::InList(l1), Operator::And, Expr::InList(l2)) if l1.expr == l2.expr && !l1.negated && !l2.negated => { - return inlist_intersection(l1, l2, false); + return inlist_intersection(l1, l2, false).map(Transformed::yes); } (Expr::InList(l1), Operator::And, Expr::InList(l2)) if l1.expr == l2.expr && l1.negated && l2.negated => { - return inlist_union(l1, l2, true); + return inlist_union(l1, l2, true).map(Transformed::yes); } (Expr::InList(l1), Operator::And, Expr::InList(l2)) if l1.expr == l2.expr && !l1.negated && l2.negated => { - return inlist_except(l1, l2); + return inlist_except(l1, l2).map(Transformed::yes); } (Expr::InList(l1), Operator::And, Expr::InList(l2)) if l1.expr == l2.expr && l1.negated && !l2.negated => { - return inlist_except(l2, l1); + return inlist_except(l2, l1).map(Transformed::yes); } (Expr::InList(l1), Operator::Or, Expr::InList(l2)) if l1.expr == l2.expr && l1.negated && l2.negated => { - return inlist_intersection(l1, l2, true); + return inlist_intersection(l1, l2, true).map(Transformed::yes); } (left, op, right) => { // put the expression back together - return Ok(Expr::BinaryExpr(BinaryExpr { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { left: Box::new(left), op, right: Box::new(right), - })); + }))); } } } - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 4c22742c8635..196a35ee9ae8 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -18,13 +18,18 @@ //! Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type //! of expr can be added if needed. //! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. + +use std::cmp::Ordering; +use std::sync::Arc; + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; + use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; @@ -32,8 +37,6 @@ use datafusion_expr::utils::merge_schema; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; -use std::cmp::Ordering; -use std::sync::Arc; /// [`UnwrapCastInComparison`] attempts to remove casts from /// comparisons to literals ([`ScalarValue`]s) by applying the casts @@ -125,13 +128,9 @@ struct UnwrapCastExprRewriter { } impl TreeNodeRewriter for UnwrapCastExprRewriter { - type N = Expr; - - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) - } + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match &expr { // For case: // try_cast/cast(expr as data_type) op literal @@ -159,11 +158,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(left_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the right expr - return Ok(binary_expr( + return Ok(Transformed::yes(binary_expr( lit(value), *op, expr.as_ref().clone(), - )); + ))); } } ( @@ -178,11 +177,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(right_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the left expr - return Ok(binary_expr( + return Ok(Transformed::yes(binary_expr( expr.as_ref().clone(), *op, lit(value), - )); + ))); } } (_, _) => { @@ -191,7 +190,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }; } // return the new binary op - Ok(binary_expr(left, *op, right)) + Ok(Transformed::yes(binary_expr(left, *op, right))) } // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) @@ -215,12 +214,12 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { let internal_left_type = internal_left.get_type(&self.schema); if internal_left_type.is_err() { // error data type - return Ok(expr); + return Ok(Transformed::no(expr)); } let internal_left_type = internal_left_type?; if !is_support_data_type(&internal_left_type) { // not supported data type - return Ok(expr); + return Ok(Transformed::no(expr)); } let right_exprs = list .iter() @@ -255,17 +254,19 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }) .collect::>>(); match right_exprs { - Ok(right_exprs) => { - Ok(in_list(internal_left, right_exprs, *negated)) - } - Err(_) => Ok(expr), + Ok(right_exprs) => Ok(Transformed::yes(in_list( + internal_left, + right_exprs, + *negated, + ))), + Err(_) => Ok(Transformed::no(expr)), } } else { - Ok(expr) + Ok(Transformed::no(expr)) } } // TODO: handle other expr type and dfs visit them - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } } @@ -474,15 +475,17 @@ fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option DFSchemaRef { diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 6189f9a57942..0df79550f143 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -17,16 +17,18 @@ //! Collection of utility functions that are leveraged by the query optimizer rules +use std::collections::{BTreeSet, HashMap}; + use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{Column, DFSchemaRef}; -use datafusion_common::{DFSchema, Result}; + +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{Column, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::is_volatile; use datafusion_expr::expr_rewriter::replace_col; use datafusion_expr::utils as expr_utils; use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator}; + use log::{debug, trace}; -use std::collections::{BTreeSet, HashMap}; /// Convenience rule for writing optimizers: recursively invoke /// optimize on plan's children and then return a node of the same @@ -101,9 +103,9 @@ pub(crate) fn is_volatile_expression(e: &Expr) -> Result { e.apply(&mut |expr| { Ok(if is_volatile(expr)? { is_volatile_expr = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } else { - VisitRecursion::Continue + TreeNodeRecursion::Continue }) })?; Ok(is_volatile_expr) diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 1f797018719b..280535f5e6be 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; use crate::{ expressions::Column, physical_expr::deduplicate_physical_exprs, @@ -22,9 +24,9 @@ use crate::{ LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use datafusion_common::tree_node::TreeNode; -use datafusion_common::{tree_node::Transformed, JoinType}; -use std::sync::Arc; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::JoinType; /// An `EquivalenceClass` is a set of [`Arc`]s that are known /// to have the same value for all tuples in a relation. These are generated by @@ -263,11 +265,12 @@ impl EquivalenceGroup { .transform(&|expr| { for cls in self.iter() { if cls.contains(&expr) { - return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); + return Ok(Transformed::yes(cls.canonical_expr().unwrap())); } } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .data() .unwrap_or(expr) } @@ -458,11 +461,12 @@ impl EquivalenceGroup { column.index() + left_size, )) as _; - return Ok(Transformed::Yes(new_column)); + return Ok(Transformed::yes(new_column)); } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .data() .unwrap(); result.add_equal_conditions(&new_lhs, &new_rhs); } @@ -477,15 +481,14 @@ impl EquivalenceGroup { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::equivalence::tests::create_test_params; use crate::equivalence::{EquivalenceClass, EquivalenceGroup}; - use crate::expressions::lit; - use crate::expressions::Column; - use crate::expressions::Literal; - use datafusion_common::Result; - use datafusion_common::ScalarValue; - use std::sync::Arc; + use crate::expressions::{lit, Column, Literal}; + + use datafusion_common::{Result, ScalarValue}; #[test] fn test_bridge_groups() -> Result<()> { diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index a31be06ecf0b..46909f23616f 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -15,18 +15,22 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + +use crate::expressions::Column; +use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; + mod class; mod ordering; mod projection; mod properties; -use crate::expressions::Column; -use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; + pub use class::{EquivalenceClass, EquivalenceGroup}; -use datafusion_common::tree_node::{Transformed, TreeNode}; pub use ordering::OrderingEquivalenceClass; pub use projection::ProjectionMapping; pub use properties::{join_equivalence_properties, EquivalenceProperties}; -use std::sync::Arc; /// This function constructs a duplicate-free `LexOrderingReq` by filtering out /// duplicate entries that have same physical expression inside. For example, @@ -48,12 +52,13 @@ pub fn add_offset_to_expr( offset: usize, ) -> Arc { expr.transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( + Some(col) => Ok(Transformed::yes(Arc::new(Column::new( col.name(), offset + col.index(), )))), - None => Ok(Transformed::No(e)), + None => Ok(Transformed::no(e)), }) + .data() .unwrap() // Note that we can safely unwrap here since our transform always returns // an `Ok` value. @@ -61,19 +66,22 @@ pub fn add_offset_to_expr( #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::expressions::{col, Column}; use crate::PhysicalSortExpr; + use arrow::compute::{lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::{plan_datafusion_err, Result}; + use itertools::izip; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; - use std::sync::Arc; pub fn output_schema( mapping: &ProjectionMapping, diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 30a31b0ad402..96c919667d84 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -21,7 +21,7 @@ use crate::expressions::Column; use crate::PhysicalExpr; use arrow::datatypes::SchemaRef; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; /// Stores the mapping between source expressions and target expressions for a @@ -68,10 +68,11 @@ impl ProjectionMapping { let matching_input_field = input_schema.field(idx); let matching_input_column = Column::new(matching_input_field.name(), idx); - Ok(Transformed::Yes(Arc::new(matching_input_column))) + Ok(Transformed::yes(Arc::new(matching_input_column))) } - None => Ok(Transformed::No(e)), + None => Ok(Transformed::no(e)), }) + .data() .map(|source_expr| (source_expr, target_expr)) }) .collect::>>() @@ -108,6 +109,8 @@ impl ProjectionMapping { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::equivalence::tests::{ apply_projection, convert_to_orderings, convert_to_orderings_owned, @@ -119,12 +122,13 @@ mod tests { use crate::expressions::{col, BinaryExpr, Literal}; use crate::functions::create_physical_expr; use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{SortOptions, TimeUnit}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; - use std::sync::Arc; #[test] fn project_orderings() -> Result<()> { diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 5a9a4f64876d..f234a1fa08cd 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -15,11 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::CastExpr; -use arrow_schema::SchemaRef; -use datafusion_common::{JoinSide, JoinType, Result}; -use indexmap::{IndexMap, IndexSet}; -use itertools::Itertools; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -27,7 +22,7 @@ use super::ordering::collapse_lex_ordering; use crate::equivalence::{ collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; -use crate::expressions::Literal; +use crate::expressions::{CastExpr, Literal}; use crate::sort_properties::{ExprOrdering, SortProperties}; use crate::{ physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, @@ -35,8 +30,12 @@ use crate::{ PhysicalSortRequirement, }; -use arrow_schema::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use arrow_schema::{SchemaRef, SortOptions}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{JoinSide, JoinType, Result}; + +use indexmap::{IndexMap, IndexSet}; +use itertools::Itertools; /// A `EquivalenceProperties` object stores useful information related to a schema. /// Currently, it keeps track of: @@ -848,6 +847,7 @@ impl EquivalenceProperties { pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { ExprOrdering::new_default(expr.clone()) .transform_up(&|expr| Ok(update_ordering(expr, self))) + .data() // Guaranteed to always return `Ok`. .unwrap() } @@ -886,9 +886,9 @@ fn update_ordering( // We have a Literal, which is the other possible leaf node type: node.data = node.expr.get_ordering(&[]); } else { - return Transformed::No(node); + return Transformed::no(node); } - Transformed::Yes(node) + Transformed::yes(node) } /// This function determines whether the provided expression is constant @@ -1297,10 +1297,12 @@ mod tests { use crate::expressions::{col, BinaryExpr, Column}; use crate::functions::create_physical_expr; use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{Fields, SortOptions, TimeUnit}; use datafusion_common::Result; use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; #[test] diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 6a168e2f1e5f..609349509b86 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -19,18 +19,18 @@ use std::borrow::Cow; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::expressions::try_cast; -use crate::expressions::NoOp; +use crate::expressions::{try_cast, NoOp}; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; + use arrow::array::*; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{cast::as_boolean_array, internal_err, DataFusionError, Result}; -use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; use itertools::Itertools; @@ -414,17 +414,15 @@ pub fn case( #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; - use crate::expressions::lit; - use crate::expressions::{binary, cast}; + use crate::expressions::{binary, cast, col, lit}; + use arrow::array::StringArray; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; use datafusion_common::cast::{as_float64_array, as_int32_array}; - use datafusion_common::plan_err; - use datafusion_common::tree_node::{Transformed, TreeNode}; - use datafusion_common::ScalarValue; + use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; + use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::Operator; @@ -972,11 +970,12 @@ mod tests { _ => None, }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(e) + Transformed::no(e) }) }) + .data() .unwrap(); let expr3 = expr @@ -993,11 +992,12 @@ mod tests { _ => None, }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(e) + Transformed::no(e) }) }) + .data() .unwrap(); assert!(expr.ne(&expr2)); diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 567054e2b59e..39b8de81af56 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -252,7 +252,7 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) { - expr.with_new_children(children) + Ok(expr.with_new_children(children)?) } else { Ok(expr) } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index e14ff2692146..b8e99403d695 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -30,7 +30,7 @@ use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::Result; use datafusion_expr::Operator; @@ -130,11 +130,10 @@ pub fn get_indices_of_exprs_strict>>( pub type ExprTreeNode = ExprContext>; -/// This struct facilitates the [TreeNodeRewriter] mechanism to convert a -/// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting -/// identical expressions in one node. Caller specifies the node type in the -/// DAEG via the `constructor` argument, which constructs nodes in the DAEG -/// from the [ExprTreeNode] ancillary object. +/// This struct is used to convert a [`PhysicalExpr`] tree into a DAEG (i.e. an expression +/// DAG) by collecting identical expressions in one node. Caller specifies the node type +/// in the DAEG via the `constructor` argument, which constructs nodes in the DAEG from +/// the [`ExprTreeNode`] ancillary object. struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result> { // The resulting DAEG (expression DAG). graph: StableGraph, @@ -144,16 +143,15 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result< constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter - for PhysicalExprDAEGBuilder<'a, T, F> +impl<'a, T, F: Fn(&ExprTreeNode) -> Result> + PhysicalExprDAEGBuilder<'a, T, F> { - type N = ExprTreeNode; // This method mutates an expression node by transforming it to a physical expression // and adding it to the graph. The method returns the mutated expression node. fn mutate( &mut self, mut node: ExprTreeNode, - ) -> Result> { + ) -> Result>> { // Get the expression associated with the input expression node. let expr = &node.expr; @@ -176,7 +174,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter // Set the data field of the input expression node to the corresponding node index. node.data = Some(node_idx); // Return the mutated expression node. - Ok(node) + Ok(Transformed::yes(node)) } } @@ -197,7 +195,9 @@ where constructor, }; // Use the builder to transform the expression tree node into a DAG. - let root = init.rewrite(&mut builder)?; + let root = init + .transform_up_mut(&mut |node| builder.mutate(node)) + .data()?; // Return a tuple containing the root node index and the DAG. Ok((root.data.unwrap(), builder.graph)) } @@ -211,7 +211,7 @@ pub fn collect_columns(expr: &Arc) -> HashSet { columns.insert(column.clone()); } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); @@ -234,13 +234,14 @@ pub fn reassign_predicate_columns( Err(_) if ignore_not_found => usize::MAX, Err(e) => return Err(e.into()), }; - return Ok(Transformed::Yes(Arc::new(Column::new( + return Ok(Transformed::yes(Arc::new(Column::new( column.name(), index, )))); } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .data() } /// Reverses the ORDER BY expression, which is useful during equivalent window diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index ebd92efb4cd2..4ff79cdaae70 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -182,7 +182,7 @@ mod tests { let schema = test::aggr_test_schema(); let empty = Arc::new(EmptyExec::new(schema.clone())); - let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.into(); + let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?; assert_eq!(empty.schema(), empty2.schema()); let too_many_kids = vec![empty2]; diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 9a4c98927683..9824c723d9d1 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -31,7 +31,7 @@ use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, @@ -284,14 +284,16 @@ pub fn convert_sort_expr_with_filter_schema( if all_columns_are_included { // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. - let converted_filter_expr = expr.transform_up(&|p| { - convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { - match transformed { - Some(transformed) => Transformed::Yes(transformed), - None => Transformed::No(p), - } + let converted_filter_expr = expr + .transform_up(&|p| { + convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { + match transformed { + Some(transformed) => Transformed::yes(transformed), + None => Transformed::no(p), + } + }) }) - })?; + .data()?; // Search the converted `PhysicalExpr` in filter expression; if an exact // match is found, use this sorted expression in graph traversals. if check_filter_expr_contains_sort_information( diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 3dac0107d3ef..1cb2b100e2d6 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -39,6 +39,7 @@ use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray}; use arrow_buffer::ArrowNativeType; use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult, }; @@ -50,7 +51,6 @@ use datafusion_physical_expr::{ LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use hashbrown::raw::RawTable; @@ -475,13 +475,17 @@ fn replace_on_columns_of_right_ordering( ) -> Result<()> { for (left_col, right_col) in on_columns { for item in right_ordering.iter_mut() { - let new_expr = item.expr.clone().transform(&|e| { - if e.eq(right_col) { - Ok(Transformed::Yes(left_col.clone())) - } else { - Ok(Transformed::No(e)) - } - })?; + let new_expr = item + .expr + .clone() + .transform(&|e| { + if e.eq(right_col) { + Ok(Transformed::yes(left_col.clone())) + } else { + Ok(Transformed::no(e)) + } + }) + .data()?; item.expr = new_expr; } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index ac864668a1c8..6334a4a211d4 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -30,7 +30,6 @@ use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::Transformed; use datafusion_common::utils::DataPtr; use datafusion_common::Result; use datafusion_execution::TaskContext; @@ -652,7 +651,7 @@ pub fn need_data_exchange(plan: Arc) -> bool { pub fn with_new_children_if_necessary( plan: Arc, children: Vec>, -) -> Result>> { +) -> Result> { let old_children = plan.children(); if children.len() != old_children.len() { internal_err!("Wrong number of children") @@ -662,9 +661,9 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) { - Ok(Transformed::Yes(plan.with_new_children(children)?)) + plan.with_new_children(children) } else { - Ok(Transformed::No(plan)) + Ok(plan) } } diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 37d209a3b473..3880cf3d77af 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -184,8 +184,7 @@ mod tests { let placeholder = Arc::new(PlaceholderRowExec::new(schema)); - let placeholder_2 = - with_new_children_if_necessary(placeholder.clone(), vec![])?.into(); + let placeholder_2 = with_new_children_if_necessary(placeholder.clone(), vec![])?; assert_eq!(placeholder.schema(), placeholder_2.schema()); let too_many_kids = vec![placeholder_2]; diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 9786b1cbf6fd..2e4b97bc224b 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -30,7 +30,7 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; @@ -317,16 +317,17 @@ fn assign_work_table( ) } else { work_table_refs += 1; - Ok(Transformed::Yes(Arc::new( + Ok(Transformed::yes(Arc::new( exec.with_work_table(work_table.clone()), ))) } } else if plan.as_any().is::() { not_impl_err!("Recursive queries cannot be nested") } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } }) + .data() } impl Stream for RecursiveQueryStream { diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index b8a5f95c5325..52a52f81bdaf 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use crate::{displayable, with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode, Transformed}; +use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; use datafusion_common::Result; impl DynTreeNode for dyn ExecutionPlan { @@ -35,7 +35,7 @@ impl DynTreeNode for dyn ExecutionPlan { arc_self: Arc, new_children: Vec>, ) -> Result> { - with_new_children_if_necessary(arc_self, new_children).map(Transformed::into) + with_new_children_if_necessary(arc_self, new_children) } } @@ -63,7 +63,7 @@ impl PlanContext { pub fn update_plan_from_children(mut self) -> Result { let children_plans = self.children.iter().map(|c| c.plan.clone()).collect(); - self.plan = with_new_children_if_necessary(self.plan, children_plans)?.into(); + self.plan = with_new_children_if_necessary(self.plan, children_plans)?; Ok(self) } } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 0dc1258ebabe..abb896ab113e 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -17,34 +17,36 @@ //! SQL Utility Functions +use std::collections::HashMap; + use arrow_schema::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use sqlparser::ast::Ident; - -use datafusion_common::{exec_err, internal_err, plan_err}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{ + exec_err, internal_err, plan_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::expr::{Alias, GroupingSet, WindowFunction}; -use datafusion_expr::expr_vec_fmt; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; -use datafusion_expr::{Expr, LogicalPlan}; -use std::collections::HashMap; +use datafusion_expr::{expr_vec_fmt, Expr, LogicalPlan}; +use sqlparser::ast::Ident; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { - expr.clone().transform_up(&|nested_expr| { - match nested_expr { - Expr::Column(col) => { - let field = plan.schema().field_from_column(&col)?; - Ok(Transformed::Yes(Expr::Column(field.qualified_column()))) - } - _ => { - // keep recursing - Ok(Transformed::No(nested_expr)) + expr.clone() + .transform_up(&|nested_expr| { + match nested_expr { + Expr::Column(col) => { + let field = plan.schema().field_from_column(&col)?; + Ok(Transformed::yes(Expr::Column(field.qualified_column()))) + } + _ => { + // keep recursing + Ok(Transformed::no(nested_expr)) + } } - } - }) + }) + .data() } /// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s. @@ -66,13 +68,15 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { - expr.clone().transform_down(&|nested_expr| { - if base_exprs.contains(&nested_expr) { - Ok(Transformed::Yes(expr_as_column_expr(&nested_expr, plan)?)) - } else { - Ok(Transformed::No(nested_expr)) - } - }) + expr.clone() + .transform_down(&|nested_expr| { + if base_exprs.contains(&nested_expr) { + Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?)) + } else { + Ok(Transformed::no(nested_expr)) + } + }) + .data() } /// Determines if the set of `Expr`'s are a valid projection on the input @@ -170,16 +174,18 @@ pub(crate) fn resolve_aliases_to_exprs( expr: &Expr, aliases: &HashMap, ) -> Result { - expr.clone().transform_up(&|nested_expr| match nested_expr { - Expr::Column(c) if c.relation.is_none() => { - if let Some(aliased_expr) = aliases.get(&c.name) { - Ok(Transformed::Yes(aliased_expr.clone())) - } else { - Ok(Transformed::No(Expr::Column(c))) + expr.clone() + .transform_up(&|nested_expr| match nested_expr { + Expr::Column(c) if c.relation.is_none() => { + if let Some(aliased_expr) = aliases.get(&c.name) { + Ok(Transformed::yes(aliased_expr.clone())) + } else { + Ok(Transformed::no(Expr::Column(c))) + } } - } - _ => Ok(Transformed::No(nested_expr)), - }) + _ => Ok(Transformed::no(nested_expr)), + }) + .data() } /// given a slice of window expressions sharing the same sort key, find their common partition diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 79877fa421e3..906926a5a9ab 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -3214,7 +3214,7 @@ JOIN sales_global AS e ON s.currency = e.currency AND s.ts >= e.ts GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency -ORDER BY s.sn +ORDER BY s.sn, s.zip_code ---- 0 GRC 0 2022-01-01T06:00:00 EUR 30 0 GRC 4 2022-01-03T10:00:00 EUR 80 diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index b7e9248a7c1f..a839420aa5b2 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -92,7 +92,7 @@ In our example, we'll use rewriting to update our `add_one` UDF, to be rewritten ### Rewriting with `transform` -To implement the inlining, we'll need to write a function that takes an `Expr` and returns a `Result`. If the expression is _not_ to be rewritten `Transformed::No` is used to wrap the original `Expr`. If the expression _is_ to be rewritten, `Transformed::Yes` is used to wrap the new `Expr`. +To implement the inlining, we'll need to write a function that takes an `Expr` and returns a `Result`. If the expression is _not_ to be rewritten `Transformed::no` is used to wrap the original `Expr`. If the expression _is_ to be rewritten, `Transformed::yes` is used to wrap the new `Expr`. ```rust fn rewrite_add_one(expr: Expr) -> Result { @@ -102,9 +102,9 @@ fn rewrite_add_one(expr: Expr) -> Result { let input_arg = scalar_fun.args[0].clone(); let new_expression = input_arg + lit(1i64); - Transformed::Yes(new_expression) + Transformed::yes(new_expression) } - _ => Transformed::No(expr), + _ => Transformed::no(expr), }) }) }