From 47df15a16d670d3a7190cd74b3c76c5dc9b35667 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 5 Nov 2024 10:23:41 +0100 Subject: [PATCH] Remove `Expr` clones from `SortExpr`s --- .../tests/user_defined/user_defined_plan.rs | 3 +- datafusion/expr/src/expr.rs | 8 ++++++ datafusion/expr/src/logical_plan/plan.rs | 7 +++-- datafusion/expr/src/tree_node.rs | 28 +++---------------- .../optimizer/src/common_subexpr_eliminate.rs | 19 ++++++++++--- 5 files changed, 33 insertions(+), 32 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index c96256784402..520a91aeb4d6 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -97,7 +97,6 @@ use datafusion::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; -use datafusion_expr::tree_node::replace_sort_expression; use datafusion_expr::{FetchType, Projection, SortExpr}; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; @@ -440,7 +439,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { Ok(Self { k: self.k, input: inputs.swap_remove(0), - expr: replace_sort_expression(self.expr.clone(), exprs.swap_remove(0)), + expr: self.expr.with_expr(exprs.swap_remove(0)), }) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index d3a3852a1eaa..0818c2062ea3 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -629,6 +629,14 @@ impl Sort { nulls_first: !self.nulls_first, } } + + pub fn with_expr(&self, expr: Expr) -> Self { + Self { + expr, + asc: self.asc, + nulls_first: self.nulls_first, + } + } } impl Display for Sort { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 191a42e38e3a..ea8fca3ec9d6 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -56,7 +56,6 @@ use indexmap::IndexSet; // backwards compatibility use crate::display::PgJsonVisitor; -use crate::tree_node::replace_sort_expressions; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -866,7 +865,11 @@ impl LogicalPlan { }) => { let input = self.only_input(inputs)?; Ok(LogicalPlan::Sort(Sort { - expr: replace_sort_expressions(sort_expr.clone(), expr), + expr: expr + .into_iter() + .zip(sort_expr.iter()) + .map(|(expr, sort)| sort.with_expr(expr)) + .collect(), input: Arc::new(input), fetch: *fetch, })) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 90afe5722abb..e964091aae66 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -408,29 +408,9 @@ pub fn transform_sort_option_vec Result>>( /// Transforms an vector of sort expressions by applying the provided closure `f`. pub fn transform_sort_vec Result>>( sorts: Vec, - mut f: &mut F, + f: &mut F, ) -> Result>> { - Ok(sorts - .iter() - .map(|sort| sort.expr.clone()) - .map_until_stop_and_collect(&mut f)? - .update_data(|transformed_exprs| { - replace_sort_expressions(sorts, transformed_exprs) - })) -} - -pub fn replace_sort_expressions(sorts: Vec, new_expr: Vec) -> Vec { - assert_eq!(sorts.len(), new_expr.len()); - sorts - .into_iter() - .zip(new_expr) - .map(|(sort, expr)| replace_sort_expression(sort, expr)) - .collect() -} - -pub fn replace_sort_expression(sort: Sort, new_expr: Expr) -> Sort { - Sort { - expr: new_expr, - ..sort - } + sorts.into_iter().map_until_stop_and_collect(|s| { + Ok(f(s.expr)?.update_data(|e| Sort { expr: e, ..s })) + }) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 4fe22d252744..53a0453d8001 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -34,8 +34,7 @@ use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; -use datafusion_expr::tree_node::replace_sort_expressions; -use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator}; +use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator, SortExpr}; const CSE_PREFIX: &str = "__common_expr"; @@ -91,6 +90,7 @@ impl CommonSubexprEliminate { .map(LogicalPlan::Projection) }) } + fn try_optimize_sort( &self, sort: Sort, @@ -98,12 +98,23 @@ impl CommonSubexprEliminate { ) -> Result> { let Sort { expr, input, fetch } = sort; let input = Arc::unwrap_or_clone(input); - let sort_expressions = expr.iter().map(|sort| sort.expr.clone()).collect(); + let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr + .into_iter() + .map(|sort| (sort.expr, (sort.asc, sort.nulls_first))) + .unzip(); let new_sort = self .try_unary_plan(sort_expressions, input, config)? .update_data(|(new_expr, new_input)| { LogicalPlan::Sort(Sort { - expr: replace_sort_expressions(expr, new_expr), + expr: new_expr + .into_iter() + .zip(sort_params) + .map(|(expr, (asc, nulls_first))| SortExpr { + expr, + asc, + nulls_first, + }) + .collect(), input: Arc::new(new_input), fetch, })