From aef232b1ac559fc1597a64d9f5f75c2f29f4c286 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 20 Nov 2024 08:29:54 +0100 Subject: [PATCH] Add `Container` trait and to simplify `Expr` and `LogicalPlan` apply and map methods (#13467) * Add `Container` trait and its blanket implementations, remove `map_until_stop_and_collect` macro, simplify apply and map logic with `Container`s where possible * fix clippy * rename `Container` to `TreeNodeContainer` * add docs to containers * clarify when we need a temporary `TreeNodeRefContainer` * code and docs cleanup --- datafusion/common/src/tree_node.rs | 363 ++++++++++++++--- datafusion/expr/src/expr.rs | 36 +- datafusion/expr/src/logical_plan/ddl.rs | 50 ++- datafusion/expr/src/logical_plan/plan.rs | 20 +- datafusion/expr/src/logical_plan/statement.rs | 51 +-- datafusion/expr/src/logical_plan/tree_node.rs | 347 +++++++--------- datafusion/expr/src/tree_node.rs | 372 ++++++------------ .../optimizer/src/optimize_projections/mod.rs | 4 +- datafusion/sql/src/unparser/rewrite.rs | 24 +- 9 files changed, 687 insertions(+), 580 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index c8ec7f18339a..0c153583e34b 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -17,11 +17,12 @@ //! [`TreeNode`] for visiting and rewriting expression and plan trees +use crate::Result; use recursive::recursive; +use std::collections::HashMap; +use std::hash::Hash; use std::sync::Arc; -use crate::Result; - /// These macros are used to determine continuation during transforming traversals. macro_rules! handle_transform_recursion { ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ @@ -769,6 +770,297 @@ impl Transformed { } } +/// [`TreeNodeContainer`] contains elements that a function can be applied on or mapped. +/// The elements of the container are siblings so the continuation rules are similar to +/// [`TreeNodeRecursion::visit_sibling`] / [`Transformed::transform_sibling`]. +pub trait TreeNodeContainer<'a, T: 'a>: Sized { + /// Applies `f` to all elements of the container. + /// This method is usually called from [`TreeNode::apply_children`] implementations as + /// a node is actually a container of the node's children. + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result; + + /// Maps all elements of the container with `f`. + /// This method is usually called from [`TreeNode::map_children`] implementations as + /// a node is actually a container of the node's children. + fn map_elements Result>>( + self, + f: F, + ) -> Result>; +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Box { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.as_ref().apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + (*self).map_elements(f)?.map_data(|c| Ok(Self::new(c))) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone> TreeNodeContainer<'a, T> for Arc { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.as_ref().apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + Arc::unwrap_or_clone(self) + .map_elements(f)? 
+ .map_data(|c| Ok(Arc::new(c))) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Option { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + match self { + Some(t) => t.apply_elements(f), + None => Ok(TreeNodeRecursion::Continue), + } + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.map_or(Ok(Transformed::no(None)), |c| { + c.map_elements(f)?.map_data(|c| Ok(Some(c))) + }) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Vec { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + let mut tnr = TreeNodeRecursion::Continue; + let mut transformed = false; + self.into_iter() + .map(|c| match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + c.map_elements(&mut f).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + result.data + }) + } + TreeNodeRecursion::Stop => Ok(c), + }) + .collect::>>() + .map(|data| Transformed::new(data, transformed, tnr)) + } +} + +impl<'a, T: 'a, K: Eq + Hash, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> + for HashMap +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self.values() { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + let mut tnr = TreeNodeRecursion::Continue; + let mut transformed = false; + self.into_iter() + .map(|(k, c)| match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + c.map_elements(&mut f).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + (k, result.data) + }) + } + TreeNodeRecursion::Stop => Ok((k, c)), + }) + .collect::>>() + .map(|data| Transformed::new(data, transformed, tnr)) + } +} + +impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> + TreeNodeContainer<'a, T> for (C0, C1) +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f)) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + self.0 + .map_elements(&mut f)? + .map_data(|new_c0| Ok((new_c0, self.1)))? + .transform_sibling(|(new_c0, c1)| { + c1.map_elements(&mut f)? + .map_data(|new_c1| Ok((new_c0, new_c1))) + }) + } +} + +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + > TreeNodeContainer<'a, T> for (C0, C1, C2) +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f)) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + self.0 + .map_elements(&mut f)? + .map_data(|new_c0| Ok((new_c0, self.1, self.2)))? + .transform_sibling(|(new_c0, c1, c2)| { + c1.map_elements(&mut f)? + .map_data(|new_c1| Ok((new_c0, new_c1, c2))) + })? 
+ .transform_sibling(|(new_c0, new_c1, c2)| { + c2.map_elements(&mut f)? + .map_data(|new_c2| Ok((new_c0, new_c1, new_c2))) + }) + } +} + +/// [`TreeNodeRefContainer`] contains references to elements that a function can be +/// applied on. The elements of the container are siblings so the continuation rules are +/// similar to [`TreeNodeRecursion::visit_sibling`]. +/// +/// This container is similar to [`TreeNodeContainer`], but the lifetime of the reference +/// elements (`T`) are not derived from the container's lifetime. +/// A typical usage of this container is in `Expr::apply_children` when we need to +/// construct a temporary container to be able to call `apply_ref_elements` on a +/// collection of tree node references. But in that case the container's temporary +/// lifetime is different to the lifetime of tree nodes that we put into it. +/// Please find an example usecase in `Expr::apply_children` with the `Expr::Case` case. +/// +/// Most of the cases we don't need to create a temporary container with +/// `TreeNodeRefContainer`, but we can just call `TreeNodeContainer::apply_elements`. +/// Please find an example usecase in `Expr::apply_children` with the `Expr::GroupingSet` +/// case. +pub trait TreeNodeRefContainer<'a, T: 'a>: Sized { + /// Applies `f` to all elements of the container. + /// This method is usually called from [`TreeNode::apply_children`] implementations as + /// a node is actually a container of the node's children. + fn apply_ref_elements Result>( + &self, + f: F, + ) -> Result; +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeRefContainer<'a, T> for Vec<&'a C> { + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } +} + +impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> + TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1) +{ + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f)) + } +} + +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2) +{ + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f)) + } +} + /// Transformation helper to process a sequence of iterable tree nodes that are siblings. pub trait TreeNodeIterator: Iterator { /// Apples `f` to each item in this iterator @@ -843,50 +1135,6 @@ impl TreeNodeIterator for I { } } -/// Transformation helper to process a heterogeneous sequence of tree node containing -/// expressions. -/// -/// This macro is very similar to [TreeNodeIterator::map_until_stop_and_collect] to -/// process nodes that are siblings, but it accepts an initial transformation (`F0`) and -/// a sequence of pairs. Each pair is made of an expression (`EXPR`) and its -/// transformation (`F`). -/// -/// The macro builds up a tuple that contains `Transformed.data` result of `F0` as the -/// first element and further elements from the sequence of pairs. 
An element from a pair -/// is either the value of `EXPR` or the `Transformed.data` result of `F`, depending on -/// the `Transformed.tnr` result of previous `F`s (`F0` initially). -/// -/// # Returns -/// Error if any of the transformations returns an error -/// -/// Ok(Transformed<(data0, ..., dataN)>) such that: -/// 1. `transformed` is true if any of the transformations had transformed true -/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and -/// `data1` ... `dataN` are from either `EXPR` or the `Transformed.data` of `F` -/// 3. `tnr` from `F0` or the last invocation of `F` -#[macro_export] -macro_rules! map_until_stop_and_collect { - ($F0:expr, $($EXPR:expr, $F:expr),*) => {{ - $F0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| { - let all_datas = ( - data0, - $( - if tnr == TreeNodeRecursion::Continue || tnr == TreeNodeRecursion::Jump { - $F.map(|result| { - tnr = result.tnr; - transformed |= result.transformed; - result.data - })? - } else { - $EXPR - }, - )* - ); - Ok(Transformed::new(all_datas, transformed, tnr)) - }) - }} -} - /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. /// /// # Example @@ -1021,7 +1269,7 @@ pub(crate) mod tests { use std::fmt::Display; use crate::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use crate::Result; @@ -1054,7 +1302,7 @@ pub(crate) mod tests { &'n self, f: F, ) -> Result { - self.children.iter().apply_until_stop(f) + self.children.apply_elements(f) } fn map_children Result>>( @@ -1063,8 +1311,7 @@ pub(crate) mod tests { ) -> Result> { Ok(self .children - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(|new_children| Self { children: new_children, ..self @@ -1072,6 +1319,22 @@ pub(crate) mod tests { } } + impl<'a, T: 'a> TreeNodeContainer<'a, Self> for TestTreeNode { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } + } + // J // | // I diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 83d35c3d25b1..8490c08a70bb 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -32,7 +32,7 @@ use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; use datafusion_common::{ plan_err, Column, DFSchema, HashMap, Result, ScalarValue, TableReference, @@ -351,6 +351,22 @@ impl<'a> From<(Option<&'a TableReference>, &'a FieldRef)> for Expr { } } +impl<'a> TreeNodeContainer<'a, Self> for Expr { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } +} + /// UNNEST expression. #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Unnest { @@ -653,6 +669,24 @@ impl Display for Sort { } } +impl<'a> TreeNodeContainer<'a, Expr> for Sort { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.expr.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.expr + .map_elements(f)? 
+ .map_data(|expr| Ok(Self { expr, ..self })) + } +} + /// Aggregate function /// /// See also [`ExprFunctionExt`] to set these fields on `Expr` diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 93e8b5fd045e..8c64a017988e 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -26,7 +26,10 @@ use std::{ use crate::expr::Sort; use arrow::datatypes::DataType; -use datafusion_common::{Constraints, DFSchemaRef, SchemaReference, TableReference}; +use datafusion_common::tree_node::{Transformed, TreeNodeContainer, TreeNodeRecursion}; +use datafusion_common::{ + Constraints, DFSchemaRef, Result, SchemaReference, TableReference, +}; use sqlparser::ast::Ident; /// Various types of DDL (CREATE / DROP) catalog manipulation @@ -487,6 +490,28 @@ pub struct OperateFunctionArg { pub data_type: DataType, pub default_expr: Option, } + +impl<'a> TreeNodeContainer<'a, Expr> for OperateFunctionArg { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.default_expr.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.default_expr.map_elements(f)?.map_data(|default_expr| { + Ok(Self { + default_expr, + ..self + }) + }) + } +} + #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct CreateFunctionBody { /// LANGUAGE lang_name @@ -497,6 +522,29 @@ pub struct CreateFunctionBody { pub function_body: Option, } +impl<'a> TreeNodeContainer<'a, Expr> for CreateFunctionBody { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.function_body.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.function_body + .map_elements(f)? + .map_data(|function_body| { + Ok(Self { + function_body, + ..self + }) + }) + } +} + #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct DropFunction { pub name: String, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6ee99b22c7f3..e9f4f1f80972 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -45,7 +45,9 @@ use crate::{ }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, +}; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, @@ -287,6 +289,22 @@ impl Default for LogicalPlan { } } +impl<'a> TreeNodeContainer<'a, Self> for LogicalPlan { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } +} + impl LogicalPlan { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index 05e2b1af14d3..26df379f5e4a 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -16,12 +16,10 @@ // under the License. 
use arrow::datatypes::DataType; -use datafusion_common::tree_node::{Transformed, TreeNodeIterator}; -use datafusion_common::{DFSchema, DFSchemaRef, Result}; +use datafusion_common::{DFSchema, DFSchemaRef}; use std::fmt::{self, Display}; use std::sync::{Arc, OnceLock}; -use super::tree_node::rewrite_arc; use crate::{expr_vec_fmt, Expr, LogicalPlan}; /// Statements have a unchanging empty schema. @@ -80,53 +78,6 @@ impl Statement { } } - /// Rewrites input LogicalPlans in the current `Statement` using `f`. - pub(super) fn map_inputs< - F: FnMut(LogicalPlan) -> Result>, - >( - self, - f: F, - ) -> Result> { - match self { - Statement::Prepare(Prepare { - input, - name, - data_types, - }) => Ok(rewrite_arc(input, f)?.update_data(|input| { - Statement::Prepare(Prepare { - input, - name, - data_types, - }) - })), - _ => Ok(Transformed::no(self)), - } - } - - /// Returns a iterator over all expressions in the current `Statement`. - pub(super) fn expression_iter(&self) -> impl Iterator { - match self { - Statement::Execute(Execute { parameters, .. }) => parameters.iter(), - _ => [].iter(), - } - } - - /// Rewrites all expressions in the current `Statement` using `f`. - pub(super) fn map_expressions Result>>( - self, - f: F, - ) -> Result> { - match self { - Statement::Execute(Execute { name, parameters }) => Ok(parameters - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|parameters| { - Statement::Execute(Execute { parameters, name }) - })), - _ => Ok(Transformed::no(self)), - } - } - /// Return a `format`able structure with the a human readable /// description of this LogicalPlan node per node, not including /// children. diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index e7dfe8791924..6850c30f4f81 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -36,32 +36,30 @@ //! (Re)creation APIs (these require substantial cloning and thus are slow): //! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions //! 
* [`LogicalPlan::expressions`]: Return a copy of the plan's expressions + use crate::{ dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, - Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, Join, Limit, - LogicalPlan, Partitioning, Projection, RecursiveQuery, Repartition, Sort, Subquery, - SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window, + Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, + Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, + Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, + UserDefinedLogicalNode, Values, Window, }; +use datafusion_common::tree_node::TreeNodeRefContainer; use recursive::recursive; -use std::ops::Deref; -use std::sync::Arc; use crate::expr::{Exists, InSubquery}; -use crate::tree_node::{transform_sort_option_vec, transform_sort_vec}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, - TreeNodeVisitor, -}; -use datafusion_common::{ - internal_err, map_until_stop_and_collect, DataFusionError, Result, + Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion, + TreeNodeRewriter, TreeNodeVisitor, }; +use datafusion_common::{internal_err, Result}; impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( &'n self, f: F, ) -> Result { - self.inputs().into_iter().apply_until_stop(f) + self.inputs().apply_ref_elements(f) } /// Applies `f` to each child (input) of this plan node, rewriting them *in place.* @@ -74,14 +72,14 @@ impl TreeNode for LogicalPlan { /// [`Expr::Exists`]: crate::Expr::Exists fn map_children Result>>( self, - mut f: F, + f: F, ) -> Result> { Ok(match self { LogicalPlan::Projection(Projection { expr, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Projection(Projection { expr, input, @@ -92,7 +90,7 @@ impl TreeNode for LogicalPlan { predicate, input, having, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Filter(Filter { predicate, input, @@ -102,7 +100,7 @@ impl TreeNode for LogicalPlan { LogicalPlan::Repartition(Repartition { input, partitioning_scheme, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -112,7 +110,7 @@ impl TreeNode for LogicalPlan { input, window_expr, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Window(Window { input, window_expr, @@ -124,7 +122,7 @@ impl TreeNode for LogicalPlan { group_expr, aggr_expr, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Aggregate(Aggregate { input, group_expr, @@ -132,7 +130,8 @@ impl TreeNode for LogicalPlan { schema, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => rewrite_arc(input, f)? + LogicalPlan::Sort(Sort { expr, input, fetch }) => input + .map_elements(f)? .update_data(|input| LogicalPlan::Sort(Sort { expr, input, fetch })), LogicalPlan::Join(Join { left, @@ -143,12 +142,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equals_null, - }) => map_until_stop_and_collect!( - rewrite_arc(left, &mut f), - right, - rewrite_arc(right, &mut f) - )? 
- .update_data(|(left, right)| { + }) => (left, right).map_elements(f)?.update_data(|(left, right)| { LogicalPlan::Join(Join { left, right, @@ -160,12 +154,13 @@ impl TreeNode for LogicalPlan { null_equals_null, }) }), - LogicalPlan::Limit(Limit { skip, fetch, input }) => rewrite_arc(input, f)? + LogicalPlan::Limit(Limit { skip, fetch, input }) => input + .map_elements(f)? .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, - }) => rewrite_arc(subquery, f)?.update_data(|subquery| { + }) => subquery.map_elements(f)?.update_data(|subquery| { LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, @@ -175,7 +170,7 @@ impl TreeNode for LogicalPlan { input, alias, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, @@ -184,17 +179,18 @@ impl TreeNode for LogicalPlan { }), LogicalPlan::Extension(extension) => rewrite_extension_inputs(extension, f)? .update_data(LogicalPlan::Extension), - LogicalPlan::Union(Union { inputs, schema }) => rewrite_arcs(inputs, f)? + LogicalPlan::Union(Union { inputs, schema }) => inputs + .map_elements(f)? .update_data(|inputs| LogicalPlan::Union(Union { inputs, schema })), LogicalPlan::Distinct(distinct) => match distinct { - Distinct::All(input) => rewrite_arc(input, f)?.update_data(Distinct::All), + Distinct::All(input) => input.map_elements(f)?.update_data(Distinct::All), Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { Distinct::On(DistinctOn { on_expr, select_expr, @@ -211,7 +207,7 @@ impl TreeNode for LogicalPlan { stringified_plans, schema, logical_optimization_succeeded, - }) => rewrite_arc(plan, f)?.update_data(|plan| { + }) => plan.map_elements(f)?.update_data(|plan| { LogicalPlan::Explain(Explain { verbose, plan, @@ -224,7 +220,7 @@ impl TreeNode for LogicalPlan { verbose, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Analyze(Analyze { verbose, input, @@ -237,7 +233,7 @@ impl TreeNode for LogicalPlan { op, input, output_schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Dml(DmlStatement { table_name, table_schema, @@ -252,7 +248,7 @@ impl TreeNode for LogicalPlan { partition_by, file_type, options, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Copy(CopyTo { input, output_url, @@ -271,7 +267,7 @@ impl TreeNode for LogicalPlan { or_replace, column_defaults, temporary, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { DdlStatement::CreateMemoryTable(CreateMemoryTable { name, constraints, @@ -288,7 +284,7 @@ impl TreeNode for LogicalPlan { or_replace, definition, temporary, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { DdlStatement::CreateView(CreateView { name, input, @@ -318,7 +314,7 @@ impl TreeNode for LogicalPlan { dependency_indices, schema, options, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Unnest(Unnest { input, exec_columns: input_columns, @@ -334,22 +330,24 @@ impl TreeNode for LogicalPlan { 
static_term, recursive_term, is_distinct, - }) => map_until_stop_and_collect!( - rewrite_arc(static_term, &mut f), - recursive_term, - rewrite_arc(recursive_term, &mut f) - )? - .update_data(|(static_term, recursive_term)| { - LogicalPlan::RecursiveQuery(RecursiveQuery { - name, - static_term, - recursive_term, - is_distinct, - }) - }), - LogicalPlan::Statement(stmt) => { - stmt.map_inputs(f)?.update_data(LogicalPlan::Statement) + }) => (static_term, recursive_term).map_elements(f)?.update_data( + |(static_term, recursive_term)| { + LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term, + recursive_term, + is_distinct, + }) + }, + ), + LogicalPlan::Statement(stmt) => match stmt { + Statement::Prepare(p) => p + .input + .map_elements(f)? + .update_data(|input| Statement::Prepare(Prepare { input, ..p })), + _ => Transformed::no(stmt), } + .update_data(LogicalPlan::Statement), // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } @@ -359,24 +357,6 @@ impl TreeNode for LogicalPlan { } } -/// Applies `f` to rewrite a `Arc` without copying, if possible -pub(super) fn rewrite_arc Result>>( - plan: Arc, - mut f: F, -) -> Result>> { - f(Arc::unwrap_or_clone(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan))) -} - -/// rewrite a `Vec` of `Arc` without copying, if possible -fn rewrite_arcs Result>>( - input_plans: Vec>, - mut f: F, -) -> Result>>> { - input_plans - .into_iter() - .map_until_stop_and_collect(|plan| rewrite_arc(plan, &mut f)) -} - /// Rewrites all inputs for an Extension node "in place" /// (it currently has to copy values because there are no APIs for in place modification) /// @@ -423,54 +403,40 @@ impl LogicalPlan { mut f: F, ) -> Result { match self { - LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().apply_until_stop(f) - } - LogicalPlan::Values(Values { values, .. }) => values - .iter() - .apply_until_stop(|value| value.iter().apply_until_stop(&mut f)), + LogicalPlan::Projection(Projection { expr, .. }) => expr.apply_elements(f), + LogicalPlan::Values(Values { values, .. }) => values.apply_elements(f), LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { - expr.iter().apply_until_stop(f) + expr.apply_elements(f) } Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().apply_until_stop(f) + window_expr.apply_elements(f) } LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. - }) => group_expr - .iter() - .chain(aggr_expr.iter()) - .apply_until_stop(f), + }) => (group_expr, aggr_expr).apply_ref_elements(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). LogicalPlan::Join(Join { on, filter, .. }) => { - on.iter() - // TODO: why we need to create an `Expr::eq`? Cloning `Expr` is costly... - // it not ideal to create an expr here to analyze them, but could cache it on the Join itself - .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .apply_until_stop(|e| f(&e))? - .visit_sibling(|| filter.iter().apply_until_stop(f)) - } - LogicalPlan::Sort(Sort { expr, .. 
}) => { - expr.iter().apply_until_stop(|sort| f(&sort.expr)) + (on, filter).apply_ref_elements(f) } + LogicalPlan::Sort(Sort { expr, .. }) => expr.apply_elements(f), LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - extension.node.expressions().iter().apply_until_stop(f) + extension.node.expressions().apply_elements(f) } LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().apply_until_stop(f) + filters.apply_elements(f) } LogicalPlan::Unnest(unnest) => { let columns = unnest.exec_columns.clone(); @@ -479,24 +445,23 @@ impl LogicalPlan { .iter() .map(|c| Expr::Column(c.clone())) .collect::>(); - exprs.iter().apply_until_stop(f) + exprs.apply_elements(f) } LogicalPlan::Distinct(Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, .. - })) => on_expr - .iter() - .chain(select_expr.iter()) - .chain(sort_expr.iter().flatten().map(|sort| &sort.expr)) - .apply_until_stop(f), - LogicalPlan::Limit(Limit { skip, fetch, .. }) => skip - .iter() - .chain(fetch.iter()) - .map(|e| e.deref()) - .apply_until_stop(f), - LogicalPlan::Statement(stmt) => stmt.expression_iter().apply_until_stop(f), + })) => (on_expr, select_expr, sort_expr).apply_ref_elements(f), + LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + (skip, fetch).apply_ref_elements(f) + } + LogicalPlan::Statement(stmt) => match stmt { + Statement::Execute(Execute { parameters, .. }) => { + parameters.apply_elements(f) + } + _ => Ok(TreeNodeRecursion::Continue), + }, // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) @@ -529,21 +494,15 @@ impl LogicalPlan { expr, input, schema, - }) => expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - }) - }), + }) => expr.map_elements(f)?.update_data(|expr| { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) + }), LogicalPlan::Values(Values { schema, values }) => values - .into_iter() - .map_until_stop_and_collect(|value| { - value.into_iter().map_until_stop_and_collect(&mut f) - })? + .map_elements(f)? .update_data(|values| LogicalPlan::Values(Values { schema, values })), LogicalPlan::Filter(Filter { predicate, @@ -561,12 +520,10 @@ impl LogicalPlan { partitioning_scheme, }) => match partitioning_scheme { Partitioning::Hash(expr, usize) => expr - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(|expr| Partitioning::Hash(expr, usize)), Partitioning::DistributeBy(expr) => expr - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(Partitioning::DistributeBy), Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), } @@ -580,34 +537,28 @@ impl LogicalPlan { input, window_expr, schema, - }) => window_expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|window_expr| { - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) - }), + }) => window_expr.map_elements(f)?.update_data(|window_expr| { + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) + }), LogicalPlan::Aggregate(Aggregate { input, group_expr, aggr_expr, schema, - }) => map_until_stop_and_collect!( - group_expr.into_iter().map_until_stop_and_collect(&mut f), - aggr_expr, - aggr_expr.into_iter().map_until_stop_and_collect(&mut f) - )? 
- .update_data(|(group_expr, aggr_expr)| { - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - }) - }), + }) => (group_expr, aggr_expr).map_elements(f)?.update_data( + |(group_expr, aggr_expr)| { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) + }, + ), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. @@ -621,16 +572,7 @@ impl LogicalPlan { join_constraint, schema, null_equals_null, - }) => map_until_stop_and_collect!( - on.into_iter().map_until_stop_and_collect( - |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1)) - ), - filter, - filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { - Ok(f(e)?.update_data(Some)) - }) - )? - .update_data(|(on, filter)| { + }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { LogicalPlan::Join(Join { left, right, @@ -642,17 +584,13 @@ impl LogicalPlan { null_equals_null, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => { - transform_sort_vec(expr, &mut f)? - .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })) - } + LogicalPlan::Sort(Sort { expr, input, fetch }) => expr + .map_elements(f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), LogicalPlan::Extension(Extension { node }) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - let exprs = node - .expressions() - .into_iter() - .map_until_stop_and_collect(f)?; + let exprs = node.expressions().map_elements(f)?; let plan = LogicalPlan::Extension(Extension { node: UserDefinedLogicalNode::with_exprs_and_inputs( node.as_ref(), @@ -669,64 +607,47 @@ impl LogicalPlan { projected_schema, filters, fetch, - }) => filters - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|filters| { - LogicalPlan::TableScan(TableScan { - table_name, - source, - projection, - projected_schema, - filters, - fetch, - }) - }), + }) => filters.map_elements(f)?.update_data(|filters| { + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) + }), LogicalPlan::Distinct(Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, input, schema, - })) => map_until_stop_and_collect!( - on_expr.into_iter().map_until_stop_and_collect(&mut f), - select_expr, - select_expr.into_iter().map_until_stop_and_collect(&mut f), - sort_expr, - transform_sort_option_vec(sort_expr, &mut f) - )? - .update_data(|(on_expr, select_expr, sort_expr)| { - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - input, - schema, - })) - }), - LogicalPlan::Limit(Limit { skip, fetch, input }) => { - let skip = skip.map(|e| *e); - let fetch = fetch.map(|e| *e); - map_until_stop_and_collect!( - skip.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { - Ok(f(e)?.update_data(Some)) - }), - fetch, - fetch.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { - Ok(f(e)?.update_data(Some)) - }) - )? - .update_data(|(skip, fetch)| { - LogicalPlan::Limit(Limit { - skip: skip.map(Box::new), - fetch: fetch.map(Box::new), + })) => (on_expr, select_expr, sort_expr) + .map_elements(f)? 
+ .update_data(|(on_expr, select_expr, sort_expr)| { + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, input, - }) + schema, + })) + }), + LogicalPlan::Limit(Limit { skip, fetch, input }) => { + (skip, fetch).map_elements(f)?.update_data(|(skip, fetch)| { + LogicalPlan::Limit(Limit { skip, fetch, input }) }) } - LogicalPlan::Statement(stmt) => { - stmt.map_expressions(f)?.update_data(LogicalPlan::Statement) + LogicalPlan::Statement(stmt) => match stmt { + Statement::Execute(e) => { + e.parameters.map_elements(f)?.update_data(|parameters| { + Statement::Execute(Execute { parameters, ..e }) + }) + } + _ => Transformed::no(stmt), } + .update_data(LogicalPlan::Statement), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Unnest(_) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index e964091aae66..eacace5ed046 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -19,14 +19,14 @@ use crate::expr::{ AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, + InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, }; use crate::{Expr, ExprFunctionExt}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, }; -use datafusion_common::{map_until_stop_and_collect, Result}; +use datafusion_common::Result; /// Implementation of the [`TreeNode`] trait /// @@ -42,9 +42,9 @@ impl TreeNode for Expr { &'n self, f: F, ) -> Result { - let children = match self { - Expr::Alias(Alias{expr,..}) - | Expr::Unnest(Unnest{expr}) + match self { + Expr::Alias(Alias { expr, .. }) + | Expr::Unnest(Unnest { expr }) | Expr::Not(expr) | Expr::IsNotNull(expr) | Expr::IsTrue(expr) @@ -57,78 +57,50 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()], + | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f), Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(), - Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { - args.iter().collect() + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f), + Expr::ScalarFunction(ScalarFunction { args, .. }) => { + args.apply_elements(f) } Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.iter().flatten().collect() + lists_of_exprs.apply_elements(f) } Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) | Expr::Literal(_) - | Expr::Exists {..} + | Expr::Exists { .. } | Expr::ScalarSubquery(_) - | Expr::Wildcard {..} - | Expr::Placeholder (_) => vec![], + | Expr::Wildcard { .. } + | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - vec![left.as_ref(), right.as_ref()] + (left, right).apply_ref_elements(f) } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - vec![expr.as_ref(), pattern.as_ref()] + (expr, pattern).apply_ref_elements(f) } Expr::Between(Between { - expr, low, high, .. 
- }) => vec![expr.as_ref(), low.as_ref(), high.as_ref()], - Expr::Case(case) => { - let mut expr_vec = vec![]; - if let Some(expr) = case.expr.as_ref() { - expr_vec.push(expr.as_ref()); - }; - for (when, then) in case.when_then_expr.iter() { - expr_vec.push(when.as_ref()); - expr_vec.push(then.as_ref()); - } - if let Some(else_expr) = case.else_expr.as_ref() { - expr_vec.push(else_expr.as_ref()); - } - expr_vec - } - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - => { - let mut expr_vec = args.iter().collect::>(); - if let Some(f) = filter { - expr_vec.push(f.as_ref()); - } - if let Some(order_by) = order_by { - expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); - } - expr_vec - } + expr, low, high, .. + }) => (expr, low, high).apply_ref_elements(f), + Expr::Case(Case { expr, when_then_expr, else_expr }) => + (expr, when_then_expr, else_expr).apply_ref_elements(f), + Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => + (args, filter, order_by).apply_ref_elements(f), Expr::WindowFunction(WindowFunction { - args, - partition_by, - order_by, - .. - }) => { - let mut expr_vec = args.iter().collect::>(); - expr_vec.extend(partition_by); - expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); - expr_vec + args, + partition_by, + order_by, + .. + }) => { + (args, partition_by, order_by).apply_ref_elements(f) } Expr::InList(InList { expr, list, .. }) => { - let mut expr_vec = vec![expr.as_ref()]; - expr_vec.extend(list); - expr_vec + (expr, list).apply_ref_elements(f) } - }; - - children.into_iter().apply_until_stop(f) + } } /// Maps each child of `self` using the provided closure `f`. @@ -148,137 +120,103 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Literal(_) => Transformed::no(self), - Expr::Unnest(Unnest { expr, .. }) => transform_box(expr, &mut f)? - .update_data(|be| Expr::Unnest(Unnest::new_boxed(be))), + Expr::Unnest(Unnest { expr, .. }) => expr + .map_elements(f)? + .update_data(|expr| Expr::Unnest(Unnest { expr })), Expr::Alias(Alias { expr, relation, name, - }) => f(*expr)?.update_data(|e| Expr::Alias(Alias::new(e, relation, name))), + }) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => transform_box(expr, &mut f)?.update_data(|be| { + }) => expr.map_elements(f)?.update_data(|be| { Expr::InSubquery(InSubquery::new(be, subquery, negated)) }), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - map_until_stop_and_collect!( - transform_box(left, &mut f), - right, - transform_box(right, &mut f) - )? + Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right) + .map_elements(f)? .update_data(|(new_left, new_right)| { Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) - }) - } + }), Expr::Like(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - pattern, - transform_box(pattern, &mut f) - )? - .update_data(|(new_expr, new_pattern)| { - Expr::Like(Like::new( - negated, - new_expr, - new_pattern, - escape_char, - case_insensitive, - )) - }), + }) => { + (expr, pattern) + .map_elements(f)? 
+ .update_data(|(new_expr, new_pattern)| { + Expr::Like(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }) + } Expr::SimilarTo(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - pattern, - transform_box(pattern, &mut f) - )? - .update_data(|(new_expr, new_pattern)| { - Expr::SimilarTo(Like::new( - negated, - new_expr, - new_pattern, - escape_char, - case_insensitive, - )) - }), - Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not), - Expr::IsNotNull(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotNull) - } - Expr::IsNull(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsNull), - Expr::IsTrue(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsTrue), - Expr::IsFalse(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsFalse) - } - Expr::IsUnknown(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsUnknown) - } - Expr::IsNotTrue(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotTrue) - } - Expr::IsNotFalse(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotFalse) + }) => { + (expr, pattern) + .map_elements(f)? + .update_data(|(new_expr, new_pattern)| { + Expr::SimilarTo(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }) } + Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not), + Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull), + Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull), + Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue), + Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse), + Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown), + Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue), + Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse), Expr::IsNotUnknown(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotUnknown) - } - Expr::Negative(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::Negative) + expr.map_elements(f)?.update_data(Expr::IsNotUnknown) } + Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative), Expr::Between(Between { expr, negated, low, high, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - low, - transform_box(low, &mut f), - high, - transform_box(high, &mut f) - )? - .update_data(|(new_expr, new_low, new_high)| { - Expr::Between(Between::new(new_expr, negated, new_low, new_high)) - }), + }) => (expr, low, high).map_elements(f)?.update_data( + |(new_expr, new_low, new_high)| { + Expr::Between(Between::new(new_expr, negated, new_low, new_high)) + }, + ), Expr::Case(Case { expr, when_then_expr, else_expr, - }) => map_until_stop_and_collect!( - transform_option_box(expr, &mut f), - when_then_expr, - when_then_expr - .into_iter() - .map_until_stop_and_collect(|(when, then)| { - map_until_stop_and_collect!( - transform_box(when, &mut f), - then, - transform_box(then, &mut f) - ) - }), - else_expr, - transform_option_box(else_expr, &mut f) - )? - .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { - Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) - }), - Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)? + }) => (expr, when_then_expr, else_expr) + .map_elements(f)? 
+ .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { + Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) + }), + Expr::Cast(Cast { expr, data_type }) => expr + .map_elements(f)? .update_data(|be| Expr::Cast(Cast::new(be, data_type))), - Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? + Expr::TryCast(TryCast { expr, data_type }) => expr + .map_elements(f)? .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), Expr::ScalarFunction(ScalarFunction { func, args }) => { - transform_vec(args, &mut f)?.map_data(|new_args| { + args.map_elements(f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( func, new_args, ))) @@ -291,22 +229,17 @@ impl TreeNode for Expr { order_by, window_frame, null_treatment, - }) => map_until_stop_and_collect!( - transform_vec(args, &mut f), - partition_by, - transform_vec(partition_by, &mut f), - order_by, - transform_sort_vec(order_by, &mut f) - )? - .update_data(|(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) - .partition_by(new_partition_by) - .order_by(new_order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() - }), + }) => (args, partition_by, order_by).map_elements(f)?.update_data( + |(new_args, new_partition_by, new_order_by)| { + Expr::WindowFunction(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() + }, + ), Expr::AggregateFunction(AggregateFunction { args, func, @@ -314,31 +247,27 @@ impl TreeNode for Expr { filter, order_by, null_treatment, - }) => map_until_stop_and_collect!( - transform_vec(args, &mut f), - filter, - transform_option_box(filter, &mut f), - order_by, - transform_sort_option_vec(order_by, &mut f) - )? - .map_data(|(new_args, new_filter, new_order_by)| { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - func, - new_args, - distinct, - new_filter, - new_order_by, - null_treatment, - ))) - })?, + }) => (args, filter, order_by).map_elements(f)?.map_data( + |(new_args, new_filter, new_order_by)| { + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + func, + new_args, + distinct, + new_filter, + new_order_by, + null_treatment, + ))) + }, + )?, Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? + GroupingSet::Rollup(exprs) => exprs + .map_elements(f)? .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), - GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)? + GroupingSet::Cube(exprs) => exprs + .map_elements(f)? .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs - .into_iter() - .map_until_stop_and_collect(|exprs| transform_vec(exprs, &mut f))? + .map_elements(f)? .update_data(|new_lists_of_exprs| { Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs)) }), @@ -347,70 +276,11 @@ impl TreeNode for Expr { expr, list, negated, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - list, - transform_vec(list, &mut f) - )? - .update_data(|(new_expr, new_list)| { - Expr::InList(InList::new(new_expr, new_list, negated)) - }), + }) => (expr, list) + .map_elements(f)? + .update_data(|(new_expr, new_list)| { + Expr::InList(InList::new(new_expr, new_list, negated)) + }), }) } } - -/// Transforms a boxed expression by applying the provided closure `f`. 
-fn transform_box Result>>( - be: Box, - f: &mut F, -) -> Result>> { - Ok(f(*be)?.update_data(Box::new)) -} - -/// Transforms an optional boxed expression by applying the provided closure `f`. -fn transform_option_box Result>>( - obe: Option>, - f: &mut F, -) -> Result>>> { - obe.map_or(Ok(Transformed::no(None)), |be| { - Ok(transform_box(be, f)?.update_data(Some)) - }) -} - -/// &mut transform a Option<`Vec` of `Expr`s> -pub fn transform_option_vec Result>>( - ove: Option>, - f: &mut F, -) -> Result>>> { - ove.map_or(Ok(Transformed::no(None)), |ve| { - Ok(transform_vec(ve, f)?.update_data(Some)) - }) -} - -/// &mut transform a `Vec` of `Expr`s -fn transform_vec Result>>( - ve: Vec, - f: &mut F, -) -> Result>> { - ve.into_iter().map_until_stop_and_collect(f) -} - -/// Transforms an optional vector of sort expressions by applying the provided closure `f`. -pub fn transform_sort_option_vec Result>>( - sorts_option: Option>, - f: &mut F, -) -> Result>>> { - sorts_option.map_or(Ok(Transformed::no(None)), |sorts| { - Ok(transform_sort_vec(sorts, f)?.update_data(Some)) - }) -} - -/// Transforms an vector of sort expressions by applying the provided closure `f`. -pub fn transform_sort_vec Result>>( - sorts: Vec, - f: &mut F, -) -> Result>> { - sorts.into_iter().map_until_stop_and_collect(|s| { - Ok(f(s.expr)?.update_data(|e| Sort { expr: e, ..s })) - }) -} diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index b659e477f67e..1519c54dbf68 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -39,7 +39,7 @@ use datafusion_expr::{ use crate::optimize_projections::required_indices::RequiredIndicies; use crate::utils::NamePreserver; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; /// Optimizer rule to prune unnecessary columns from intermediate schemas @@ -484,7 +484,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result Result /// Rewrite sort expressions that have a UNION plan as their input to remove the table reference. fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { - let sort_exprs = transform_sort_vec(exprs, &mut |expr| { - expr.transform_up(|expr| { - if let Expr::Column(mut col) = expr { - col.relation = None; - Ok(Transformed::yes(Expr::Column(col))) - } else { - Ok(Transformed::no(expr)) - } + let sort_exprs = exprs + .map_elements(&mut |expr: Expr| { + expr.transform_up(|expr| { + if let Expr::Column(mut col) = expr { + col.relation = None; + Ok(Transformed::yes(Expr::Column(col))) + } else { + Ok(Transformed::no(expr)) + } + }) }) - }) - .data()?; + .data()?; Ok(sort_exprs) }
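
For readers adopting the new trait outside this patch, here is a minimal sketch, modeled on the `TestTreeNode` fixture added in the patch's tests, of how a user-defined node type can delegate child traversal to `TreeNodeContainer`. `MyNode` and the `main` driver are hypothetical names introduced only for illustration, and the snippet assumes `datafusion_common` (at a version containing this patch) as its sole dependency.

use datafusion_common::tree_node::{
    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
};
use datafusion_common::Result;

// A hypothetical user-defined tree node.
struct MyNode {
    data: String,
    children: Vec<MyNode>,
}

// Declaring `MyNode` itself as a container element lets the blanket impls
// (`Vec<C>`, `Option<C>`, `Box<C>`, `Arc<C>`, tuples, ...) do the per-element
// bookkeeping of `TreeNodeRecursion` and the `transformed` flag.
impl<'a> TreeNodeContainer<'a, Self> for MyNode {
    fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
        &'a self,
        mut f: F,
    ) -> Result<TreeNodeRecursion> {
        f(self)
    }

    fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
        self,
        mut f: F,
    ) -> Result<Transformed<Self>> {
        f(self)
    }
}

impl TreeNode for MyNode {
    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
        &'n self,
        f: F,
    ) -> Result<TreeNodeRecursion> {
        // `Vec<MyNode>` is already a `TreeNodeContainer`, so sibling
        // continuation rules (Continue / Jump / Stop) come for free.
        self.children.apply_elements(f)
    }

    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
        self,
        f: F,
    ) -> Result<Transformed<Self>> {
        // Destructure to avoid partially moving `self`, then rebuild the node
        // around the (possibly) rewritten children.
        let Self { data, children } = self;
        children
            .map_elements(f)?
            .map_data(|children| Ok(Self { data, children }))
    }
}

fn main() -> Result<()> {
    let tree = MyNode {
        data: "root".to_string(),
        children: vec![
            MyNode { data: "a".to_string(), children: vec![] },
            MyNode { data: "b".to_string(), children: vec![] },
        ],
    };

    // Count nodes with `TreeNode::apply`, which drives `apply_children`.
    let mut count = 0;
    tree.apply(|_| {
        count += 1;
        Ok(TreeNodeRecursion::Continue)
    })?;
    assert_eq!(count, 3);
    Ok(())
}

The design point the patch leans on: once a node type is a `TreeNodeContainer` element, the blanket implementations for `Box`, `Arc`, `Option`, `Vec`, `HashMap`, and tuples handle the Continue/Jump/Stop propagation and `transformed` accounting, which is what allows `Expr` and `LogicalPlan` to drop the `map_until_stop_and_collect!` macro and the ad-hoc `rewrite_arc`/`transform_*` helpers removed above.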