diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 7abe6b70b64e4..52ac5daa135d1 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -870,37 +870,7 @@ impl LogicalPlan { LogicalPlan::Filter { .. } => { assert_eq!(1, expr.len()); let predicate = expr.pop().unwrap(); - - // filter predicates should not contain aliased expressions so we remove any aliases - // before this logic was added we would have aliases within filters such as for - // benchmark q6: - // - // lineitem.l_shipdate >= Date32(\"8766\") - // AND lineitem.l_shipdate < Date32(\"9131\") - // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= - // Decimal128(Some(49999999999999),30,15) - // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= - // Decimal128(Some(69999999999999),30,15) - // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - - let predicate = predicate - .transform_down(|expr| { - match expr { - Expr::Exists { .. } - | Expr::ScalarSubquery(_) - | Expr::InSubquery(_) => { - // subqueries could contain aliases so we don't recurse into those - Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) - } - Expr::Alias(_) => Ok(Transformed::new( - expr.unalias(), - true, - TreeNodeRecursion::Jump, - )), - _ => Ok(Transformed::no(expr)), - } - }) - .data()?; + let predicate = Filter::remove_aliases(predicate)?.data; Filter::try_new(predicate, Arc::new(inputs.swap_remove(0))) .map(LogicalPlan::Filter) @@ -2230,6 +2200,40 @@ impl Filter { } false } + + /// Remove aliases from a predicate for use in a `Filter` + /// + /// filter predicates should not contain aliased expressions so we remove + /// any aliases. + /// + /// before this logic was added we would have aliases within filters such as + /// for benchmark q6: + /// + /// ```sql + /// lineitem.l_shipdate >= Date32(\"8766\") + /// AND lineitem.l_shipdate < Date32(\"9131\") + /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= + /// Decimal128(Some(49999999999999),30,15) + /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= + /// Decimal128(Some(69999999999999),30,15) + /// AND lineitem.l_quantity < Decimal128(Some(2400),15,2) + /// ``` + pub fn remove_aliases(predicate: Expr) -> Result> { + predicate.transform_down(|expr| { + match expr { + Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { + // subqueries could contain aliases so we don't recurse into those + Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) + } + Expr::Alias(_) => Ok(Transformed::new( + expr.unalias(), + true, + TreeNodeRecursion::Jump, + )), + _ => Ok(Transformed::no(expr)), + } + }) + } } /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 6820ba04f0e90..cfee867f56e26 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -20,8 +20,9 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; +use crate::optimizer::ApplyOrder; use arrow::datatypes::{DataType, Field}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, @@ -31,7 +32,10 @@ use datafusion_common::{ internal_err, qualified_name, Column, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::expr::Alias; -use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; +use datafusion_expr::logical_plan::{ + Aggregate, Filter, LogicalPlan, Projection, Sort, Window, +}; use datafusion_expr::{col, Expr, ExprSchemable}; use indexmap::IndexMap; @@ -127,21 +131,21 @@ impl CommonSubexprEliminate { /// Returns the rewritten expressions fn rewrite_exprs_list( &self, - exprs_list: &[&[Expr]], + exprs_list: Vec>, arrays_list: &[&[Vec<(usize, String)>]], expr_stats: &ExprStats, common_exprs: &mut CommonExprs, ) -> Result>> { exprs_list - .iter() + .into_iter() .zip(arrays_list.iter()) .map(|(exprs, arrays)| { exprs - .iter() - .cloned() + .into_iter() .zip(arrays.iter()) .map(|(expr, id_array)| { replace_common_expr(expr, id_array, expr_stats, common_exprs) + .data() }) .collect::>>() }) @@ -158,7 +162,7 @@ impl CommonSubexprEliminate { /// common sub-expressions that were used fn rewrite_expr( &self, - exprs_list: &[&[Expr]], + exprs_list: Vec>, arrays_list: &[&[Vec<(usize, String)>]], input: &LogicalPlan, expr_stats: &ExprStats, @@ -173,9 +177,8 @@ impl CommonSubexprEliminate { &mut common_exprs, )?; - let mut new_input = self - .try_optimize(input, config)? - .unwrap_or_else(|| input.clone()); + let mut new_input = self.rewrite(input.clone(), config)?.data; + if !common_exprs.is_empty() { new_input = build_common_expr_project_plan(new_input, common_exprs, expr_stats)?; @@ -186,7 +189,7 @@ impl CommonSubexprEliminate { fn try_optimize_window( &self, - window: &Window, + window: Window, config: &dyn OptimizerConfig, ) -> Result { let mut window_exprs = vec![]; @@ -221,7 +224,7 @@ impl CommonSubexprEliminate { arrays_per_window.push(arrays); } - let mut window_exprs = window_exprs + let window_exprs = window_exprs .iter() .map(|expr| expr.as_slice()) .collect::>(); @@ -231,14 +234,18 @@ impl CommonSubexprEliminate { .collect::>(); assert_eq!(window_exprs.len(), arrays_per_window.len()); + // todo remove clone + let num_window_exprs = window_exprs.len(); + let mut window_exprs: Vec<_> = + window_exprs.iter().map(|exprs| exprs.to_vec()).collect(); let (mut new_expr, new_input) = self.rewrite_expr( - &window_exprs, + window_exprs.clone(), &arrays_per_window, &plan, &expr_stats, config, )?; - assert_eq!(window_exprs.len(), new_expr.len()); + assert_eq!(num_window_exprs, new_expr.len()); // Construct consecutive window operator, with their corresponding new window expressions. plan = new_input; @@ -265,7 +272,7 @@ impl CommonSubexprEliminate { fn try_optimize_aggregate( &self, - aggregate: &Aggregate, + aggregate: Aggregate, config: &dyn OptimizerConfig, ) -> Result { let Aggregate { @@ -279,18 +286,18 @@ impl CommonSubexprEliminate { // rewrite inputs let input_schema = Arc::clone(input.schema()); let group_arrays = to_arrays( - group_expr, + &group_expr, Arc::clone(&input_schema), &mut expr_stats, ExprMask::Normal, )?; let aggr_arrays = - to_arrays(aggr_expr, input_schema, &mut expr_stats, ExprMask::Normal)?; + to_arrays(&aggr_expr, input_schema, &mut expr_stats, ExprMask::Normal)?; let (mut new_expr, new_input) = self.rewrite_expr( - &[group_expr, aggr_expr], + vec![group_expr.clone(), aggr_expr.clone()], &[&group_arrays, &aggr_arrays], - input, + &input, &expr_stats, config, )?; @@ -309,7 +316,7 @@ impl CommonSubexprEliminate { )?; let mut common_exprs = IndexMap::new(); let mut rewritten = self.rewrite_exprs_list( - &[&new_aggr_expr], + vec![new_aggr_expr.clone()], &[&aggr_arrays], &expr_stats, &mut common_exprs, @@ -374,42 +381,123 @@ impl CommonSubexprEliminate { } } + /// Rewrites the expr list and input to remove common subexpressions + /// + /// # Parameters + /// + /// * `exprs`: List of expressions in the node + /// * `input`: input plan (that produces the columns referred to in `exprs`) + /// + /// # Return value + /// + /// Returns `(rewritten_exprs, new_input)`. `new_input` is either: + /// + /// 1. The original `input` of no common subexpressions were extracted + /// 2. A newly added projection on top of the original input + /// that computes the common subexpressions fn try_unary_plan( &self, - plan: &LogicalPlan, + expr: Vec, + input: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result { - let expr = plan.expressions(); - let inputs = plan.inputs(); - let input = inputs[0]; - let input_schema = Arc::clone(input.schema()); + ) -> Result, LogicalPlan)>> { let mut expr_stats = ExprStats::new(); - - // Visit expr list and build expr identifier to occuring count map (`expr_stats`). - let arrays = to_arrays(&expr, input_schema, &mut expr_stats, ExprMask::Normal)?; + let arrays = to_arrays( + &expr, + Arc::clone(input.schema()), + &mut expr_stats, + ExprMask::Normal, + )?; let (mut new_expr, new_input) = - self.rewrite_expr(&[&expr], &[&arrays], input, &expr_stats, config)?; - - plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input]) + self.rewrite_expr(vec![expr], &[&arrays], &input, &expr_stats, config)?; + assert_eq!(new_expr.len(), 1); + let result = (new_expr.pop().unwrap(), new_input); + // todo pass up transformed from rewrite_expr + Ok(Transformed::yes(result)) } } impl OptimizerRule for CommonSubexprEliminate { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called CommonSubexprEliminate::rewrite") + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + let original_schema = Arc::clone(plan.schema()); + let optimized_plan = match plan { - LogicalPlan::Projection(_) - | LogicalPlan::Sort(_) - | LogicalPlan::Filter(_) => Some(self.try_unary_plan(plan, config)?), + LogicalPlan::Projection(projection) => { + let Projection { + expr, + input, + schema, + .. + } = projection; + let input = unwrap_arc(input); + self.try_unary_plan(expr, input, config)?.map_data( + |(new_expr, new_input)| { + Projection::try_new_with_schema( + new_expr, + Arc::new(new_input), + schema, + ) + .map(LogicalPlan::Projection) + }, + )? + } + LogicalPlan::Sort(sort) => { + let Sort { expr, input, fetch } = sort; + let input = unwrap_arc(input); + self.try_unary_plan(expr, input, config)?.update_data( + |(new_expr, new_input)| { + LogicalPlan::Sort(Sort { + expr: new_expr, + input: Arc::new(new_input), + fetch, + }) + }, + ) + } + LogicalPlan::Filter(filter) => { + let Filter { + predicate, input, .. + } = filter; + let input = unwrap_arc(input); + let expr = vec![predicate]; + self.try_unary_plan(expr, input, config)? + .transform_data(|(mut new_expr, new_input)| { + assert_eq!(new_expr.len(), 1); // passed in vec![predicate] + let new_predicate = new_expr.pop().unwrap(); + Ok(Filter::remove_aliases(new_predicate)? + .update_data(|new_predicate| (new_predicate, new_input))) + })? + .map_data(|(new_predicate, new_input)| { + Filter::try_new(new_predicate, Arc::new(new_input)) + .map(LogicalPlan::Filter) + })? + } LogicalPlan::Window(window) => { - Some(self.try_optimize_window(window, config)?) + Transformed::yes(self.try_optimize_window(window, config)?) } LogicalPlan::Aggregate(aggregate) => { - Some(self.try_optimize_aggregate(aggregate, config)?) + Transformed::yes(self.try_optimize_aggregate(aggregate, config)?) } LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_) @@ -433,21 +521,19 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Prepare(_) => { - // apply the optimization to all inputs of the plan - utils::optimize_children(self, plan, config)? + // ApplyOrder::TopDown handles recursion + Transformed::no(plan) } }; - let original_schema = plan.schema(); - match optimized_plan { - Some(optimized_plan) if optimized_plan.schema() != original_schema => { - // add an additional projection if the output schema changed. - Ok(Some(build_recover_project_plan( - original_schema, - optimized_plan, - )?)) - } - plan => Ok(plan), + // If we rewrote the plan, ensure the schema stays the same + if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema + { + optimized_plan.map_data(|optimized_plan| { + build_recover_project_plan(&original_schema, optimized_plan) + }) + } else { + Ok(optimized_plan) } } @@ -475,13 +561,20 @@ fn pop_expr(new_expr: &mut Vec>) -> Result> { .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string())) } +/// Returns the identifier list for each element in `exprs` +/// +/// Returns and array with 1 element for each input expr in `exprs` +/// +/// Each element is itself the result of [`expr_to_identifier`] for that expr +/// (e.g. the identifiers for each node in the tree) fn to_arrays( - expr: &[Expr], + exprs: &[Expr], input_schema: DFSchemaRef, expr_stats: &mut ExprStats, expr_mask: ExprMask, ) -> Result>> { - expr.iter() + exprs + .iter() .map(|e| { let mut id_array = vec![]; expr_to_identifier( @@ -823,14 +916,13 @@ fn replace_common_expr( id_array: &IdArray, expr_stats: &ExprStats, common_exprs: &mut CommonExprs, -) -> Result { +) -> Result> { expr.rewrite(&mut CommonSubexprRewriter { expr_stats, id_array, common_exprs, down_index: 0, }) - .data() } #[cfg(test)] @@ -853,12 +945,11 @@ mod test { use super::*; - fn assert_optimized_plan_eq(expected: &str, plan: &LogicalPlan) { + fn assert_optimized_plan_eq(expected: &str, plan: LogicalPlan) { let optimizer = CommonSubexprEliminate {}; - let optimized_plan = optimizer - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); + let optimized_plan = optimizer.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(optimized_plan.transformed, "failed to optimize plan"); + let optimized_plan = optimized_plan.data; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(expected, formatted_plan); } @@ -957,7 +1048,7 @@ mod test { \n Projection: test.a * (Int32(1) - test.b) AS {test.a * (Int32(1) - test.b)|{Int32(1) - test.b|{test.b}|{Int32(1)}}|{test.a}}, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1010,7 +1101,7 @@ mod test { \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS {AVG(test.a)|{test.a}}, my_agg(test.a) AS {my_agg(test.a)|{test.a}}, AVG(test.b) AS col3, AVG(test.c) AS {AVG(test.c)}, my_agg(test.b) AS col6, my_agg(test.c) AS {my_agg(test.c)}]]\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: trafo after aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -1029,7 +1120,7 @@ mod test { \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS {AVG(test.a)|{test.a}}, my_agg(test.a) AS {my_agg(test.a)|{test.a}}]]\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: transformation before aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -1044,7 +1135,7 @@ mod test { let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: common between agg and group let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -1061,7 +1152,7 @@ mod test { \n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: all mixed let plan = LogicalPlanBuilder::from(table_scan) @@ -1083,7 +1174,7 @@ mod test { \n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1110,7 +1201,7 @@ mod test { \n Projection: UInt32(1) + table.test.col.a AS {UInt32(1) + table.test.col.a|{table.test.col.a}|{UInt32(1)}}, table.test.col.a\ \n TableScan: table.test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1130,7 +1221,7 @@ mod test { \n Projection: Int32(1) + test.a AS {Int32(1) + test.a|{test.a}|{Int32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1146,7 +1237,7 @@ mod test { let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1164,7 +1255,7 @@ mod test { \n Projection: Int32(1) + test.a, test.a\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1264,10 +1355,9 @@ mod test { .build() .unwrap(); let rule = CommonSubexprEliminate {}; - let optimized_plan = rule - .try_optimize(&plan, &OptimizerContext::new()) - .unwrap() - .unwrap(); + let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(optimized_plan.transformed); + let optimized_plan = optimized_plan.data; let schema = optimized_plan.schema(); let fields_with_datatypes: Vec<_> = schema @@ -1306,7 +1396,7 @@ mod test { \n Projection: Int32(1) + test.a AS {Int32(1) + test.a|{test.a}|{Int32(1)}}, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) }