diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 9dfc238ab9e83..88b43ccdede7c 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -95,12 +95,12 @@ impl MyAnalyzerRule { Ok(match plan { LogicalPlan::Filter(filter) => { let predicate = Self::analyze_expr(filter.predicate.clone())?; - Transformed::Yes(LogicalPlan::Filter(Filter::try_new( + Transformed::yes(LogicalPlan::Filter(Filter::try_new( predicate, filter.input, )?)) } - _ => Transformed::No(plan), + _ => Transformed::no(plan), }) }) } @@ -111,11 +111,11 @@ impl MyAnalyzerRule { Ok(match expr { Expr::Literal(ScalarValue::Int64(i)) => { // transform to UInt64 - Transformed::Yes(Expr::Literal(ScalarValue::UInt64( + Transformed::yes(Expr::Literal(ScalarValue::UInt64( i.map(|i| i as u64), ))) } - _ => Transformed::No(expr), + _ => Transformed::no(expr), }) }) } @@ -175,12 +175,12 @@ fn my_rewrite(expr: Expr) -> Result { let low: Expr = *low; let high: Expr = *high; if negated { - Transformed::Yes(expr.clone().lt(low).or(expr.gt(high))) + Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) } else { - Transformed::Yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) + Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) } } - _ => Transformed::No(expr), + _ => Transformed::no(expr), }) }) } diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 85338a7200df5..6b84eda978afe 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -41,21 +41,6 @@ macro_rules! handle_tree_recursion { }; } -macro_rules! handle_tree_recursion_without_stop { - ($TNR:expr, $NODE:expr) => { - match $TNR { - TreeNodeRecursion::Continue => {} - // If the recursion should skip, do not apply to its children, let - // the recursion continue: - TreeNodeRecursion::Skip => return Ok($NODE), - // Stop is not (yet) supported - TreeNodeRecursion::Stop => { - panic!("Stop can't be used in `TreeNode::transform()` and `TreeNode::rewrite()`") - } - } - }; -} - /// Defines a visitable and rewriteable a tree node. This trait is /// implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as /// well as expression trees ([`PhysicalExpr`], [`Expr`]) in @@ -142,62 +127,62 @@ pub trait TreeNode: Sized { /// and please note that [`TreeNodeRecursion::Stop`] is not supported. /// /// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately. - fn transform(self, f_down: &mut FD, f_up: &mut FU) -> Result + fn transform( + self, + f_down: &mut FD, + f_up: &mut FU, + ) -> Result> where - FD: FnMut(Self) -> Result<(Transformed, TreeNodeRecursion)>, - FU: FnMut(Self) -> Result, + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, { - let (new_node, tnr) = f_down(self).map(|(t, tnr)| (t.into(), tnr))?; - handle_tree_recursion_without_stop!(tnr, new_node); - let node_with_new_children = - new_node.map_children(|node| node.transform(f_down, f_up))?; - f_up(node_with_new_children) + f_down(self)?.and_then_transform_children(|t| { + t.map_children(|node| node.transform(f_down, f_up))? + .and_then_transform_sibling(f_up) + }) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its /// children(Preorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_down(self, op: &F) -> Result + fn transform_down(self, f: &F) -> Result> where F: Fn(Self) -> Result>, { - let after_op = op(self)?.into(); - after_op.map_children(|node| node.transform_down(op)) + f(self)?.and_then_transform_children(|t| t.map_children(|n| n.transform_down(f))) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its /// children(Preorder Traversal) using a mutable function, `F`. /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_down_mut(self, op: &mut F) -> Result + fn transform_down_mut(self, f: &mut F) -> Result> where F: FnMut(Self) -> Result>, { - let after_op = op(self)?.into(); - after_op.map_children(|node| node.transform_down_mut(op)) + f(self)? + .and_then_transform_children(|t| t.map_children(|n| n.transform_down_mut(f))) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its /// children and then itself(Postorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_up(self, op: &F) -> Result + fn transform_up(self, f: &F) -> Result> where F: Fn(Self) -> Result>, { - let after_op_children = self.map_children(|node| node.transform_up(op))?; - let new_node = op(after_op_children)?.into(); - Ok(new_node) + self.map_children(|node| node.transform_up(f))? + .and_then_transform_sibling(f) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its /// children and then itself(Postorder Traversal) using a mutable function, `F`. /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_up_mut(self, op: &mut F) -> Result + fn transform_up_mut(self, f: &mut F) -> Result> where F: FnMut(Self) -> Result>, { - let after_op_children = self.map_children(|node| node.transform_up_mut(op))?; - let new_node = op(after_op_children)?.into(); - Ok(new_node) + self.map_children(|n| n.transform_up_mut(f))? + .and_then_transform_sibling(f) } /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for @@ -225,12 +210,14 @@ pub trait TreeNode: Sized { /// /// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`], /// recursion is stopped immediately. - fn rewrite>(self, rewriter: &mut R) -> Result { - let (new_node, tnr) = rewriter.f_down(self)?; - handle_tree_recursion_without_stop!(tnr, new_node); - let node_with_new_children = - new_node.map_children(|node| node.rewrite(rewriter))?; - rewriter.f_up(node_with_new_children) + fn rewrite>( + self, + rewriter: &mut R, + ) -> Result> { + rewriter.f_down(self)?.and_then_transform_children(|t| { + t.map_children(|n| n.rewrite(rewriter))? + .and_then_transform_sibling(|t| rewriter.f_up(t)) + }) } /// Apply the closure `F` to the node's children @@ -239,9 +226,9 @@ pub trait TreeNode: Sized { F: FnMut(&Self) -> Result; /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) - fn map_children(self, transform: F) -> Result + fn map_children(self, f: F) -> Result> where - F: FnMut(Self) -> Result; + F: FnMut(Self) -> Result>; } /// Implements the [visitor @@ -291,20 +278,20 @@ pub trait TreeNodeRewriter: Sized { /// Invoked while traversing down the tree before any children are rewritten / /// visited. /// Default implementation returns the node unmodified and continues recursion. - fn f_down(&mut self, node: Self::Node) -> Result<(Self::Node, TreeNodeRecursion)> { - Ok((node, TreeNodeRecursion::Continue)) + fn f_down(&mut self, node: Self::Node) -> Result> { + Ok(Transformed::no(node)) } /// Invoked while traversing up the tree after all children have been rewritten / /// visited. /// Default implementation returns the node unmodified. - fn f_up(&mut self, node: Self::Node) -> Result { - Ok(node) + fn f_up(&mut self, node: Self::Node) -> Result> { + Ok(Transformed::no(node)) } } /// Controls how [`TreeNode`] recursions should proceed. -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum TreeNodeRecursion { /// Continue recursion with the next node. Continue, @@ -314,27 +301,147 @@ pub enum TreeNodeRecursion { Stop, } -pub enum Transformed { - /// The item was transformed / rewritten somehow - Yes(T), - /// The item was not transformed - No(T), +pub struct Transformed { + pub data: T, + pub transformed: bool, + pub tnr: TreeNodeRecursion, } impl Transformed { - pub fn into(self) -> T { - match self { - Transformed::Yes(t) => t, - Transformed::No(t) => t, + pub fn new(data: T, transformed: bool, tnr: TreeNodeRecursion) -> Self { + Self { + data, + transformed, + tnr, } } - pub fn into_pair(self) -> (T, bool) { - match self { - Transformed::Yes(t) => (t, true), - Transformed::No(t) => (t, false), + pub fn yes(data: T) -> Self { + Self { + data, + transformed: true, + tnr: TreeNodeRecursion::Continue, } } + + pub fn no(data: T) -> Self { + Self { + data, + transformed: false, + tnr: TreeNodeRecursion::Continue, + } + } + + pub fn map_data U>(self, f: F) -> Transformed { + Transformed { + data: f(self.data), + transformed: self.transformed, + tnr: self.tnr, + } + } + + pub fn flat_map_data Result>( + self, + f: F, + ) -> Result> { + Ok(Transformed { + data: f(self.data)?, + transformed: self.transformed, + tnr: self.tnr, + }) + } + + fn and_then_transform Result>>( + self, + f: F, + children: bool, + ) -> Result> { + match self.tnr { + TreeNodeRecursion::Continue => {} + TreeNodeRecursion::Skip => { + // If the next transformation would happen on children return immediately + // on `Skip`. + if children { + return Ok(Transformed { + tnr: TreeNodeRecursion::Continue, + ..self + }); + } + } + TreeNodeRecursion::Stop => return Ok(self), + }; + let t = f(self.data)?; + Ok(Transformed { + transformed: t.transformed || self.transformed, + ..t + }) + } + + pub fn and_then_transform_sibling Result>>( + self, + f: F, + ) -> Result> { + self.and_then_transform(f, false) + } + + pub fn and_then_transform_children Result>>( + self, + f: F, + ) -> Result> { + self.and_then_transform(f, true) + } +} + +pub trait TransformedIterator: Iterator { + fn map_till_continue_and_collect( + self, + f: F, + ) -> Result>> + where + F: FnMut(Self::Item) -> Result>, + Self: Sized; +} + +impl TransformedIterator for I { + fn map_till_continue_and_collect( + self, + mut f: F, + ) -> Result>> + where + F: FnMut(Self::Item) -> Result>, + { + let mut new_tnr = TreeNodeRecursion::Continue; + let mut new_transformed = false; + let new_data = self + .map(|i| { + if new_tnr == TreeNodeRecursion::Continue + || new_tnr == TreeNodeRecursion::Skip + { + let Transformed { + data, + transformed, + tnr, + } = f(i)?; + new_tnr = if tnr == TreeNodeRecursion::Skip { + // Iterator always considers the elements as siblings so `Skip` + // can be safely converted to `Continue`. + TreeNodeRecursion::Continue + } else { + tnr + }; + new_transformed |= transformed; + Ok(data) + } else { + Ok(i) + } + }) + .collect::>>()?; + Ok(Transformed { + data: new_data, + transformed: new_transformed, + tnr: new_tnr, + }) + } } /// Helper trait for implementing [`TreeNode`] that have children stored as Arc's @@ -350,7 +457,7 @@ pub trait DynTreeNode { &self, arc_self: Arc, new_children: Vec>, - ) -> Result>; + ) -> Result>>; } /// Blanket implementation for Arc for any tye that implements @@ -367,18 +474,18 @@ impl TreeNode for Arc { Ok(TreeNodeRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(self, f: F) -> Result> where - F: FnMut(Self) -> Result, + F: FnMut(Self) -> Result>, { let children = self.arc_children(); if !children.is_empty() { - let new_children = - children.into_iter().map(transform).collect::>()?; + let t = children.into_iter().map_till_continue_and_collect(f)?; + // TODO: once we trust `t.transformed` don't create new node if not necessary let arc_self = Arc::clone(&self); - self.with_new_arc_children(arc_self, new_children) + self.with_new_arc_children(arc_self, t.data) } else { - Ok(self) + Ok(Transformed::no(self)) } } } @@ -409,17 +516,19 @@ impl TreeNode for T { Ok(TreeNodeRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(self, f: F) -> Result> where - F: FnMut(Self) -> Result, + F: FnMut(Self) -> Result>, { let (new_self, children) = self.take_children(); if !children.is_empty() { - let new_children = - children.into_iter().map(transform).collect::>()?; - new_self.with_new_children(new_children) + children + .into_iter() + .map_till_continue_and_collect(f)? + // TODO: once we trust `transformed` don't create new node if not necessary + .flat_map_data(|new_children| new_self.with_new_children(new_children)) } else { - Ok(new_self) + Ok(Transformed::no(new_self)) } } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index ddfeb146b876d..bdd607095f445 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -20,7 +20,9 @@ use arrow::datatypes::{DataType, Schema}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; @@ -189,7 +191,7 @@ impl<'a> FilterCandidateBuilder<'a> { metadata: &ParquetMetaData, ) -> Result> { let expr = self.expr.clone(); - let expr = expr.rewrite(&mut self)?; + let expr = expr.rewrite(&mut self)?.data; if self.non_primitive_columns || self.projected_columns { Ok(None) @@ -214,27 +216,30 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { fn f_down( &mut self, node: Arc, - ) -> Result<(Arc, TreeNodeRecursion)> { + ) -> Result>> { if let Some(column) = node.as_any().downcast_ref::() { if let Ok(idx) = self.file_schema.index_of(column.name()) { self.required_column_indices.insert(idx); if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; - return Ok((node, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(node, false, TreeNodeRecursion::Skip)); } } else if self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. self.projected_columns = true; - return Ok((node, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(node, false, TreeNodeRecursion::Skip)); } } - Ok((node, TreeNodeRecursion::Continue)) + Ok(Transformed::no(node)) } - fn f_up(&mut self, expr: Arc) -> Result> { + fn f_up( + &mut self, + expr: Arc, + ) -> Result>> { if let Some(column) = expr.as_any().downcast_ref::() { if self.file_schema.field_with_name(column.name()).is_err() { // the column expr must be in the table schema @@ -242,7 +247,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { Ok(field) => { // return the null value corresponding to the data type let null_value = ScalarValue::try_from(field.data_type())?; - Ok(Arc::new(Literal::new(null_value))) + Ok(Transformed::yes(Arc::new(Literal::new(null_value)))) } Err(e) => { // If the column is not in the table schema, should throw the error @@ -252,7 +257,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { } } - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 4fe11c14a7583..5f872831ef93b 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -27,7 +27,7 @@ use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics use crate::scalar::ScalarValue; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -85,10 +85,14 @@ impl PhysicalOptimizerRule for AggregateStatistics { Arc::new(PlaceholderRowExec::new(plan.schema())), )?)) } else { - plan.map_children(|child| self.optimize(child, _config)) + plan.map_children(|child| { + self.optimize(child, _config).map(Transformed::yes) + }) + .map(|t| t.data) } } else { - plan.map_children(|child| self.optimize(child, _config)) + plan.map_children(|child| self.optimize(child, _config).map(Transformed::yes)) + .map(|t| t.data) } } diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 7b66ca5290942..e3565e451669c 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -71,14 +71,15 @@ impl PhysicalOptimizerRule for CoalesceBatches { }) .unwrap_or(false); if wrap_in_coalesce { - Ok(Transformed::Yes(Arc::new(CoalesceBatchesExec::new( + Ok(Transformed::yes(Arc::new(CoalesceBatchesExec::new( plan, target_batch_size, )))) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } }) + .map(|t| t.data) } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index b26d9763e53a5..ccc9a2909cca4 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -109,11 +109,12 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { }); Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(plan) + Transformed::no(plan) }) }) + .map(|t| t.data) } fn name(&self) -> &str { @@ -185,11 +186,12 @@ fn discard_column_index(group_expr: Arc) -> Arc None, }; Ok(if let Some(normalized_form) = normalized_form { - Transformed::Yes(normalized_form) + Transformed::yes(normalized_form) } else { - Transformed::No(expr) + Transformed::no(expr) }) }) + .map(|t| t.data) .unwrap_or(group_expr) } diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index fab26c49c2daa..ff033d168e77b 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -197,22 +197,25 @@ impl PhysicalOptimizerRule for EnforceDistribution { let adjusted = if top_down_join_key_reordering { // Run a top-down process to adjust input key ordering recursively let plan_requirements = PlanWithKeyRequirements::new_default(plan); - let adjusted = - plan_requirements.transform_down(&adjust_input_keys_ordering)?; + let adjusted = plan_requirements + .transform_down(&adjust_input_keys_ordering)? + .data; adjusted.plan } else { // Run a bottom-up process plan.transform_up(&|plan| { - Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) + Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?)) })? + .data }; let distribution_context = DistributionContext::new_default(adjusted); // Distribution enforcement needs to be applied bottom-up. - let distribution_context = - distribution_context.transform_up(&|distribution_context| { + let distribution_context = distribution_context + .transform_up(&|distribution_context| { ensure_distribution(distribution_context, config) - })?; + })? + .data; Ok(distribution_context.plan) } @@ -306,7 +309,7 @@ fn adjust_input_keys_ordering( vec![], &join_constructor, ) - .map(Transformed::Yes); + .map(Transformed::yes); } PartitionMode::CollectLeft => { // Push down requirements to the right side @@ -368,18 +371,18 @@ fn adjust_input_keys_ordering( sort_options.clone(), &join_constructor, ) - .map(Transformed::Yes); + .map(Transformed::yes); } else if let Some(aggregate_exec) = plan.as_any().downcast_ref::() { if !requirements.data.is_empty() { if aggregate_exec.mode() == &AggregateMode::FinalPartitioned { return reorder_aggregate_keys(requirements, aggregate_exec) - .map(Transformed::Yes); + .map(Transformed::yes); } else { requirements.data.clear(); } } else { // Keep everything unchanged - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } } else if let Some(proj) = plan.as_any().downcast_ref::() { let expr = proj.expr(); @@ -407,7 +410,7 @@ fn adjust_input_keys_ordering( child.data = requirements.data.clone(); } } - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } fn reorder_partitioned_join_keys( @@ -1065,7 +1068,7 @@ fn ensure_distribution( let dist_context = update_children(dist_context)?; if dist_context.plan.children().is_empty() { - return Ok(Transformed::No(dist_context)); + return Ok(Transformed::no(dist_context)); } let target_partitions = config.execution.target_partitions; @@ -1245,7 +1248,7 @@ fn ensure_distribution( plan.with_new_children(children_plans)? }; - Ok(Transformed::Yes(DistributionContext::new( + Ok(Transformed::yes(DistributionContext::new( plan, data, children, ))) } @@ -1718,7 +1721,7 @@ pub(crate) mod tests { config.optimizer.repartition_file_scans = false; config.optimizer.repartition_file_min_size = 1024; config.optimizer.prefer_existing_sort = prefer_existing_sort; - ensure_distribution(distribution_context, &config).map(|item| item.into().plan) + ensure_distribution(distribution_context, &config).map(|item| item.data.plan) } /// Test whether plan matches with expected plan @@ -1786,22 +1789,22 @@ pub(crate) mod tests { let plan_requirements = PlanWithKeyRequirements::new_default($PLAN.clone()); let adjusted = plan_requirements - .transform_down(&adjust_input_keys_ordering) + .transform_down(&adjust_input_keys_ordering).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. adjusted.plan } else { // Run reorder_join_keys_to_inputs rule $PLAN.clone().transform_up(&|plan| { - Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) - })? + Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?)) + })?.data }; // Then run ensure_distribution rule DistributionContext::new_default(adjusted) .transform_up(&|distribution_context| { ensure_distribution(distribution_context, &config) - }) + }).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. } diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 5c46e64a22f69..7a3b2c512111a 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -158,33 +158,35 @@ impl PhysicalOptimizerRule for EnforceSorting { let plan_requirements = PlanWithCorrespondingSort::new_default(plan); // Execute a bottom-up traversal to enforce sorting requirements, // remove unnecessary sorts, and optimize sort-sensitive operators: - let adjusted = plan_requirements.transform_up(&ensure_sorting)?; + let adjusted = plan_requirements.transform_up(&ensure_sorting)?.data; let new_plan = if config.optimizer.repartition_sorts { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); - let parallel = - plan_with_coalesce_partitions.transform_up(¶llelize_sorts)?; + let parallel = plan_with_coalesce_partitions + .transform_up(¶llelize_sorts)? + .data; parallel.plan } else { adjusted.plan }; let plan_with_pipeline_fixer = OrderPreservationContext::new_default(new_plan); - let updated_plan = - plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { + let updated_plan = plan_with_pipeline_fixer + .transform_up(&|plan_with_pipeline_fixer| { replace_with_order_preserving_variants( plan_with_pipeline_fixer, false, true, config, ) - })?; + })? + .data; // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); - let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; + let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?.data; Ok(adjusted.plan) } @@ -221,7 +223,7 @@ fn parallelize_sorts( // `SortPreservingMergeExec` or a `CoalescePartitionsExec`, and they // all have a single child. Therefore, if the first child has no // connection, we can return immediately. - Ok(Transformed::No(requirements)) + Ok(Transformed::no(requirements)) } else if (is_sort(&requirements.plan) || is_sort_preserving_merge(&requirements.plan)) && requirements.plan.output_partitioning().partition_count() <= 1 @@ -250,7 +252,7 @@ fn parallelize_sorts( } let spm = SortPreservingMergeExec::new(sort_exprs, requirements.plan.clone()); - Ok(Transformed::Yes( + Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( Arc::new(spm.with_fetch(fetch)), false, @@ -264,7 +266,7 @@ fn parallelize_sorts( // For the removal of self node which is also a `CoalescePartitionsExec`. requirements = requirements.children.swap_remove(0); - Ok(Transformed::Yes( + Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( Arc::new(CoalescePartitionsExec::new(requirements.plan.clone())), false, @@ -272,7 +274,7 @@ fn parallelize_sorts( ), )) } else { - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } } @@ -285,10 +287,12 @@ fn ensure_sorting( // Perform naive analysis at the beginning -- remove already-satisfied sorts: if requirements.children.is_empty() { - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } let maybe_requirements = analyze_immediate_sort_removal(requirements); - let Transformed::No(mut requirements) = maybe_requirements else { + requirements = if !maybe_requirements.transformed { + maybe_requirements.data + } else { return Ok(maybe_requirements); }; @@ -327,17 +331,17 @@ fn ensure_sorting( // calculate the result in reverse: let child_node = &requirements.children[0]; if is_window(plan) && child_node.data { - return adjust_window_sort_removal(requirements).map(Transformed::Yes); + return adjust_window_sort_removal(requirements).map(Transformed::yes); } else if is_sort_preserving_merge(plan) && child_node.plan.output_partitioning().partition_count() <= 1 { // This `SortPreservingMergeExec` is unnecessary, input already has a // single partition. let child_node = requirements.children.swap_remove(0); - return Ok(Transformed::Yes(child_node)); + return Ok(Transformed::yes(child_node)); } - update_sort_ctx_children(requirements, false).map(Transformed::Yes) + update_sort_ctx_children(requirements, false).map(Transformed::yes) } /// Analyzes a given [`SortExec`] (`plan`) to determine whether its input @@ -367,10 +371,10 @@ fn analyze_immediate_sort_removal( child.data = false; } node.data = false; - return Transformed::Yes(node); + return Transformed::yes(node); } } - Transformed::No(node) + Transformed::no(node) } /// Adjusts a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine @@ -641,7 +645,7 @@ mod tests { { let plan_requirements = PlanWithCorrespondingSort::new_default($PLAN.clone()); let adjusted = plan_requirements - .transform_up(&ensure_sorting) + .transform_up(&ensure_sorting).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. @@ -649,7 +653,7 @@ mod tests { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); let parallel = plan_with_coalesce_partitions - .transform_up(¶llelize_sorts) + .transform_up(¶llelize_sorts).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. parallel.plan @@ -666,14 +670,14 @@ mod tests { true, state.config_options(), ) - }) + }).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); sort_pushdown - .transform_down(&pushdown_sorts) + .transform_down(&pushdown_sorts).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. } diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 02626056f6ccb..98a05b5877e04 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -241,7 +241,9 @@ impl PhysicalOptimizerRule for JoinSelection { Box::new(hash_join_convert_symmetric_subrule), Box::new(hash_join_swap_subrule), ]; - let state = pipeline.transform_up(&|p| apply_subrules(p, &subrules, config))?; + let state = pipeline + .transform_up(&|p| apply_subrules(p, &subrules, config))? + .data; // Next, we apply another subrule that tries to optimize joins using any // statistics their inputs might have. // - For a hash join with partition mode [`PartitionMode::Auto`], we will @@ -256,13 +258,16 @@ impl PhysicalOptimizerRule for JoinSelection { let config = &config.optimizer; let collect_threshold_byte_size = config.hash_join_single_partition_threshold; let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows; - state.plan.transform_up(&|plan| { - statistical_join_selection_subrule( - plan, - collect_threshold_byte_size, - collect_threshold_num_rows, - ) - }) + state + .plan + .transform_up(&|plan| { + statistical_join_selection_subrule( + plan, + collect_threshold_byte_size, + collect_threshold_num_rows, + ) + }) + .map(|t| t.data) } fn name(&self) -> &str { @@ -438,9 +443,9 @@ fn statistical_join_selection_subrule( }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(plan) + Transformed::no(plan) }) } @@ -671,7 +676,7 @@ fn apply_subrules( // etc. If this doesn't happen, the final `PipelineChecker` rule will // catch this and raise an error anyway. .unwrap_or(true); - Ok(Transformed::Yes(input)) + Ok(Transformed::yes(input)) } #[cfg(test)] @@ -836,6 +841,7 @@ mod tests_statistical { ]; let state = pipeline .transform_up(&|p| apply_subrules(p, &subrules, &ConfigOptions::new())) + .map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. let config = ConfigOptions::new().optimizer; diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 9855247151b88..caf8b61c5b2cf 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -106,7 +106,7 @@ impl LimitedDistinctAggregation { let mut rewrite_applicable = true; let mut closure = |plan: Arc| { if !rewrite_applicable { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } if let Some(aggr) = plan.as_any().downcast_ref::() { if found_match_aggr { @@ -117,7 +117,7 @@ impl LimitedDistinctAggregation { // a partial and final aggregation with different groupings disqualifies // rewriting the child aggregation rewrite_applicable = false; - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } } } @@ -128,14 +128,18 @@ impl LimitedDistinctAggregation { Some(new_aggr) => { match_aggr = plan; found_match_aggr = true; - return Ok(Transformed::Yes(new_aggr)); + return Ok(Transformed::yes(new_aggr)); } } } rewrite_applicable = false; - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; - let child = child.clone().transform_down_mut(&mut closure).ok()?; + let child = child + .clone() + .transform_down_mut(&mut closure) + .map(|t| t.data) + .ok()?; if is_global_limit { return Some(Arc::new(GlobalLimitExec::new( child, @@ -165,12 +169,13 @@ impl PhysicalOptimizerRule for LimitedDistinctAggregation { if let Some(plan) = LimitedDistinctAggregation::transform_limit(plan.clone()) { - Transformed::Yes(plan) + Transformed::yes(plan) } else { - Transformed::No(plan) + Transformed::no(plan) }, ) })? + .data } else { plan }; diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index 4d03840d3dd31..38877d0bab692 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -196,15 +196,17 @@ impl PhysicalOptimizerRule for OutputRequirements { ) -> Result> { match self.mode { RuleMode::Add => require_top_ordering(plan), - RuleMode::Remove => plan.transform_up(&|plan| { - if let Some(sort_req) = - plan.as_any().downcast_ref::() - { - Ok(Transformed::Yes(sort_req.input())) - } else { - Ok(Transformed::No(plan)) - } - }), + RuleMode::Remove => plan + .transform_up(&|plan| { + if let Some(sort_req) = + plan.as_any().downcast_ref::() + { + Ok(Transformed::yes(sort_req.input())) + } else { + Ok(Transformed::no(plan)) + } + }) + .map(|t| t.data), } } diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index bb0665c10bcc2..c09d9ada7def0 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -53,7 +53,8 @@ impl PhysicalOptimizerRule for PipelineChecker { ) -> Result> { let pipeline = PipelineStatePropagator::new_default(plan); let state = pipeline - .transform_up(&|p| check_finiteness_requirements(p, &config.optimizer))?; + .transform_up(&|p| check_finiteness_requirements(p, &config.optimizer))? + .data; Ok(state.plan) } @@ -93,7 +94,7 @@ pub fn check_finiteness_requirements( .unbounded_output(&children_unbounded(&input)) .map(|value| { input.data = value; - Transformed::Yes(input) + Transformed::yes(input) }) } diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index b2be307c3bd94..d1af2a29cf916 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -74,6 +74,7 @@ impl PhysicalOptimizerRule for ProjectionPushdown { _config: &ConfigOptions, ) -> Result> { plan.transform_down(&remove_unnecessary_projections) + .map(|t| t.data) } fn name(&self) -> &str { @@ -98,7 +99,7 @@ pub fn remove_unnecessary_projections( // If the projection does not cause any change on the input, we can // safely remove it: if is_projection_removable(projection) { - return Ok(Transformed::Yes(projection.input().clone())); + return Ok(Transformed::yes(projection.input().clone())); } // If it does, check if we can push it under its child(ren): let input = projection.input().as_any(); @@ -112,7 +113,7 @@ pub fn remove_unnecessary_projections( // To unify 3 or more sequential projections: remove_unnecessary_projections(new_plan) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; } else if let Some(output_req) = input.downcast_ref::() { try_swapping_with_output_req(projection, output_req)? @@ -148,10 +149,10 @@ pub fn remove_unnecessary_projections( None } } else { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); }; - Ok(maybe_modified.map_or(Transformed::No(plan), Transformed::Yes)) + Ok(maybe_modified.map_or(Transformed::no(plan), Transformed::yes)) } /// Tries to embed `projection` to its input (`csv`). If possible, returns @@ -896,16 +897,16 @@ fn update_expr( .clone() .transform_up_mut(&mut |expr: Arc| { if state == RewriteState::RewrittenInvalid { - return Ok(Transformed::No(expr)); + return Ok(Transformed::no(expr)); } let Some(column) = expr.as_any().downcast_ref::() else { - return Ok(Transformed::No(expr)); + return Ok(Transformed::no(expr)); }; if sync_with_child { state = RewriteState::RewrittenValid; // Update the index of `column`: - Ok(Transformed::Yes(projected_exprs[column.index()].0.clone())) + Ok(Transformed::yes(projected_exprs[column.index()].0.clone())) } else { // default to invalid, in case we can't find the relevant column state = RewriteState::RewrittenInvalid; @@ -924,11 +925,12 @@ fn update_expr( ) }) .map_or_else( - || Ok(Transformed::No(expr)), - |c| Ok(Transformed::Yes(c)), + || Ok(Transformed::no(expr)), + |c| Ok(Transformed::yes(c)), ) } - }); + }) + .map(|t| t.data); new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) } @@ -1045,7 +1047,7 @@ fn new_columns_for_join_on( }) .map(|(index, (_, alias))| Column::new(alias, index)); if let Some(new_column) = new_column { - Ok(Transformed::Yes(Arc::new(new_column))) + Ok(Transformed::yes(Arc::new(new_column))) } else { // If the column is not found in the projection expressions, // it means that the column is not projected. In this case, @@ -1056,9 +1058,10 @@ fn new_columns_for_join_on( ))) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } }) + .map(|t| t.data) .ok() }) .collect::>(); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index aa72771b1eb3f..ecf4bc0e1b7ea 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -840,12 +840,13 @@ fn rewrite_column_expr( e.transform_up(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { - return Ok(Transformed::Yes(Arc::new(column_new.clone()))); + return Ok(Transformed::yes(Arc::new(column_new.clone()))); } } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .map(|t| t.data) } fn reverse_operator(op: Operator) -> Result { diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index bc9bd0010dc58..4629152cddd91 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -236,7 +236,7 @@ pub(crate) fn replace_with_order_preserving_variants( ) -> Result> { update_children(&mut requirements); if !(is_sort(&requirements.plan) && requirements.children[0].data) { - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } // For unbounded cases, we replace with the order-preserving variant in any @@ -260,13 +260,13 @@ pub(crate) fn replace_with_order_preserving_variants( for child in alternate_plan.children.iter_mut() { child.data = false; } - Ok(Transformed::Yes(alternate_plan)) + Ok(Transformed::yes(alternate_plan)) } else { // The alternate plan does not help, use faster order-breaking variants: alternate_plan = plan_with_order_breaking_variants(alternate_plan)?; alternate_plan.data = false; requirements.children = vec![alternate_plan]; - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } } @@ -395,7 +395,7 @@ mod tests { // Run the rule top-down let config = SessionConfig::new().with_prefer_existing_sort($PREFER_EXISTING_SORT); let plan_with_pipeline_fixer = OrderPreservationContext::new_default(physical_plan); - let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options())).and_then(check_integrity)?; + let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options())).map(|t| t.data).and_then(check_integrity)?; let optimized_physical_plan = parallel.plan; // Get string representation of the plan diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 3413486c6b460..16b96fce73017 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -87,7 +87,7 @@ pub(crate) fn pushdown_sorts( } // Can push down requirements child.data = None; - return Ok(Transformed::Yes(child)); + return Ok(Transformed::yes(child)); } else { // Can not push down requirements requirements.children = vec![child]; @@ -112,7 +112,7 @@ pub(crate) fn pushdown_sorts( requirements = add_sort_above(requirements, sort_reqs, None); assign_initial_requirements(&mut requirements); } - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } fn pushdown_requirement_to_children( diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 5de6cff0b4fad..0ab1d4edfe8d9 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -374,15 +374,19 @@ pub fn sort_exec( /// TODO: Once [`ExecutionPlan`] implements [`PartialEq`], string comparisons should be /// replaced with direct plan equality checks. pub fn check_integrity(context: PlanContext) -> Result> { - context.transform_up(&|node| { - let children_plans = node.plan.children(); - assert_eq!(node.children.len(), children_plans.len()); - for (child_plan, child_node) in children_plans.iter().zip(node.children.iter()) { - assert_eq!( - displayable(child_plan.as_ref()).one_line().to_string(), - displayable(child_node.plan.as_ref()).one_line().to_string() - ); - } - Ok(Transformed::No(node)) - }) + context + .transform_up(&|node| { + let children_plans = node.plan.children(); + assert_eq!(node.children.len(), children_plans.len()); + for (child_plan, child_node) in + children_plans.iter().zip(node.children.iter()) + { + assert_eq!( + displayable(child_plan.as_ref()).one_line().to_string(), + displayable(child_node.plan.as_ref()).one_line().to_string() + ); + } + Ok(Transformed::no(node)) + }) + .map(|t| t.data) } diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index dd02614203043..245617a4d4462 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -101,13 +101,13 @@ impl TopKAggregation { let mut cardinality_preserved = true; let mut closure = |plan: Arc| { if !cardinality_preserved { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } if let Some(aggr) = plan.as_any().downcast_ref::() { // either we run into an Aggregate and transform it match Self::transform_agg(aggr, order, limit) { None => cardinality_preserved = false, - Some(plan) => return Ok(Transformed::Yes(plan)), + Some(plan) => return Ok(Transformed::yes(plan)), } } else { // or we continue down whitelisted nodes of other types @@ -115,9 +115,13 @@ impl TopKAggregation { cardinality_preserved = false; } } - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; - let child = child.clone().transform_down_mut(&mut closure).ok()?; + let child = child + .clone() + .transform_down_mut(&mut closure) + .map(|t| t.data) + .ok()?; let sort = SortExec::new(sort.expr().to_vec(), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); @@ -141,12 +145,13 @@ impl PhysicalOptimizerRule for TopKAggregation { plan.transform_down(&|plan| { Ok( if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) { - Transformed::Yes(plan) + Transformed::yes(plan) } else { - Transformed::No(plan) + Transformed::no(plan) }, ) })? + .data } else { plan }; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 3e94ecbd746a8..77ad9591d8c23 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1268,8 +1268,9 @@ impl Expr { rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; } - Ok(Transformed::Yes(expr)) + Ok(Transformed::yes(expr)) }) + .map(|t| t.data) } /// Returns true if some of this `exprs` subexpressions may not be evaluated diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 76bd51619954a..c72c0f00a7378 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -37,12 +37,13 @@ pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) } /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions @@ -61,12 +62,13 @@ pub fn normalize_col_with_schemas( Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, using_columns)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) } /// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage @@ -80,12 +82,13 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( if let Expr::Column(c) = expr { let col = c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) } /// Recursively normalize all [`Column`] expressions in a list of expression trees @@ -106,14 +109,15 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { - Some(new_c) => Transformed::Yes(Expr::Column((*new_c).to_owned())), - None => Transformed::No(expr), + Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())), + None => Transformed::no(expr), } } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) } /// Recursively 'unnormalize' (remove all qualifiers) from an @@ -129,12 +133,13 @@ pub fn unnormalize_col(expr: Expr) -> Expr { relation: None, name: c.name, }; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) .expect("Unnormalize is infallable") } @@ -167,12 +172,13 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { expr.transform_up(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) .expect("strip_outer_reference is infallable") } @@ -253,7 +259,7 @@ where R: TreeNodeRewriter, { let original_name = expr.name_for_alias()?; - let expr = expr.rewrite(rewriter)?; + let expr = expr.rewrite(rewriter)?.data; expr.alias_if_changed(original_name) } @@ -263,7 +269,7 @@ mod test { use crate::expr::Sort; use crate::{col, lit, Cast}; use arrow::datatypes::DataType; - use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeRewriter}; + use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter}; use datafusion_common::{DFField, DFSchema, ScalarValue}; use std::ops::Add; @@ -275,14 +281,14 @@ mod test { impl TreeNodeRewriter for RecordingRewriter { type Node = Expr; - fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { + fn f_down(&mut self, expr: Expr) -> Result> { self.v.push(format!("Previsited {expr}")); - Ok((expr, TreeNodeRecursion::Continue)) + Ok(Transformed::no(expr)) } - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { self.v.push(format!("Mutated {expr}")); - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -297,10 +303,10 @@ mod test { } else { utf8_val }; - Ok(Transformed::Yes(lit(utf8_val))) + Ok(Transformed::yes(lit(utf8_val))) } // otherwise, return None - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } }; @@ -308,6 +314,7 @@ mod test { let rewritten = col("state") .eq(lit("foo")) .transform_up(&transformer) + .map(|t| t.data) .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); @@ -315,6 +322,7 @@ mod test { let rewritten = col("state") .eq(lit("baz")) .transform_up(&transformer) + .map(|t| t.data) .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } @@ -452,8 +460,8 @@ mod test { impl TreeNodeRewriter for TestRewriter { type Node = Expr; - fn f_up(&mut self, _: Expr) -> Result { - Ok(self.rewrite_to.clone()) + fn f_up(&mut self, _: Expr) -> Result> { + Ok(Transformed::yes(self.rewrite_to.clone())) } } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 1e7efcafd04df..1cc35a1a4b949 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -91,7 +91,7 @@ fn rewrite_in_terms_of_projection( .to_field(input.schema()) .map(|f| f.qualified_column())?, ); - return Ok(Transformed::Yes(col)); + return Ok(Transformed::yes(col)); } // if that doesn't work, try to match the expression as an @@ -103,7 +103,7 @@ fn rewrite_in_terms_of_projection( e } else { // The expr is not based on Aggregate plan output. Skip it. - return Ok(Transformed::No(expr)); + return Ok(Transformed::no(expr)); }; // expr is an actual expr like min(t.c2), but we are looking @@ -118,7 +118,7 @@ fn rewrite_in_terms_of_projection( // look for the column named the same as this expr if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) { let found = found.clone(); - return Ok(Transformed::Yes(match normalized_expr { + return Ok(Transformed::yes(match normalized_expr { Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { expr: Box::new(found), data_type, @@ -131,8 +131,9 @@ fn rewrite_in_terms_of_projection( })); } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .map(|t| t.data) } /// Does the underlying expr match e? diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 80ce38fe93897..f7b035609e05c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -647,29 +647,29 @@ impl LogicalPlan { // Decimal128(Some(69999999999999),30,15) // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - fn unalias_down( - expr: Expr, - ) -> Result<(Transformed, TreeNodeRecursion)> { + fn unalias_down(expr: Expr) -> Result> { match expr { Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { // subqueries could contain aliases so we don't recurse into those - Ok((Transformed::No(expr), TreeNodeRecursion::Skip)) + Ok(Transformed::new(expr, false, TreeNodeRecursion::Skip)) } - Expr::Alias(_) => Ok(( - Transformed::Yes(expr.unalias()), + Expr::Alias(_) => Ok(Transformed::new( + expr.unalias(), + true, TreeNodeRecursion::Skip, )), - _ => Ok((Transformed::No(expr), TreeNodeRecursion::Continue)), + _ => Ok(Transformed::no(expr)), } } - fn dummy_up(expr: Expr) -> Result { - Ok(expr) + fn dummy_up(expr: Expr) -> Result> { + Ok(Transformed::no(expr)) } - let predicate = predicate.transform(&mut unalias_down, &mut dummy_up)?; + let predicate = + predicate.transform(&mut unalias_down, &mut dummy_up)?.data; Filter::try_new(predicate, Arc::new(inputs[0].clone())) .map(LogicalPlan::Filter) @@ -1243,19 +1243,20 @@ impl LogicalPlan { Expr::Placeholder(Placeholder { id, .. }) => { let value = param_values.get_placeholders_with_values(id)?; // Replace the placeholder with the value - Ok(Transformed::Yes(Expr::Literal(value))) + Ok(Transformed::yes(Expr::Literal(value))) } Expr::ScalarSubquery(qry) => { let subquery = Arc::new(qry.subquery.replace_params_with_values(param_values)?); - Ok(Transformed::Yes(Expr::ScalarSubquery(Subquery { + Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery, outer_ref_columns: qry.outer_ref_columns.clone(), }))) } - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } }) + .map(|t| t.data) } } @@ -3310,10 +3311,11 @@ digraph { Arc::new(LogicalPlan::TableScan(table)), ) .unwrap(); - Ok(Transformed::Yes(LogicalPlan::Filter(filter))) + Ok(Transformed::yes(LogicalPlan::Filter(filter))) } - x => Ok(Transformed::No(x)), + x => Ok(Transformed::no(x)), }) + .map(|t| t.data) .unwrap(); let expected = "Explain\ diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index f2b0b4c2d266e..5e7dd1990923d 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -24,7 +24,9 @@ use crate::expr::{ }; use crate::{Expr, GetFieldAccess}; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{handle_tree_recursion, internal_err, DataFusionError, Result}; impl TreeNode for Expr { @@ -135,10 +137,10 @@ impl TreeNode for Expr { Ok(TreeNodeRecursion::Continue) } - fn map_children Result>( - self, - mut transform: F, - ) -> Result { + fn map_children(self, mut f: F) -> Result> + where + F: FnMut(Self) -> Result>, + { Ok(match self { Expr::Column(_) | Expr::Wildcard { .. } @@ -147,27 +149,28 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) => self, + | Expr::Literal(_) => Transformed::no(self), Expr::Alias(Alias { expr, relation, name, - }) => Expr::Alias(Alias::new(transform(*expr)?, relation, name)), + }) => f(*expr)?.map_data(|e| Expr::Alias(Alias::new(e, relation, name))), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => Expr::InSubquery(InSubquery::new( - transform_boxed(expr, &mut transform)?, - subquery, - negated, - )), + }) => transform_box(expr, &mut f)? + .map_data(|be| Expr::InSubquery(InSubquery::new(be, subquery, negated))), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - Expr::BinaryExpr(BinaryExpr::new( - transform_boxed(left, &mut transform)?, - op, - transform_boxed(right, &mut transform)?, - )) + transform_box(left, &mut f)? + .map_data(|new_left| (new_left, right)) + .and_then_transform_sibling(|(new_left, right)| { + Ok(transform_box(right, &mut f)? + .map_data(|new_right| (new_left, new_right))) + })? + .map_data(|(new_left, new_right)| { + Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) + }) } Expr::Like(Like { negated, @@ -175,213 +178,281 @@ impl TreeNode for Expr { pattern, escape_char, case_insensitive, - }) => Expr::Like(Like::new( - negated, - transform_boxed(expr, &mut transform)?, - transform_boxed(pattern, &mut transform)?, - escape_char, - case_insensitive, - )), + }) => transform_box(expr, &mut f)? + .map_data(|new_expr| (new_expr, pattern)) + .and_then_transform_sibling(|(new_expr, pattern)| { + Ok(transform_box(pattern, &mut f)? + .map_data(|new_pattern| (new_expr, new_pattern))) + })? + .map_data(|(new_expr, new_pattern)| { + Expr::Like(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }), Expr::SimilarTo(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => Expr::SimilarTo(Like::new( - negated, - transform_boxed(expr, &mut transform)?, - transform_boxed(pattern, &mut transform)?, - escape_char, - case_insensitive, - )), - Expr::Not(expr) => Expr::Not(transform_boxed(expr, &mut transform)?), + }) => transform_box(expr, &mut f)? + .map_data(|new_expr| (new_expr, pattern)) + .and_then_transform_sibling(|(new_expr, pattern)| { + Ok(transform_box(pattern, &mut f)? + .map_data(|new_pattern| (new_expr, new_pattern))) + })? + .map_data(|(new_expr, new_pattern)| { + Expr::SimilarTo(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }), + Expr::Not(expr) => transform_box(expr, &mut f)?.map_data(|be| Expr::Not(be)), Expr::IsNotNull(expr) => { - Expr::IsNotNull(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotNull(be)) + } + Expr::IsNull(expr) => { + transform_box(expr, &mut f)?.map_data(|be| Expr::IsNull(be)) + } + Expr::IsTrue(expr) => { + transform_box(expr, &mut f)?.map_data(|be| Expr::IsTrue(be)) + } + Expr::IsFalse(expr) => { + transform_box(expr, &mut f)?.map_data(|be| Expr::IsFalse(be)) } - Expr::IsNull(expr) => Expr::IsNull(transform_boxed(expr, &mut transform)?), - Expr::IsTrue(expr) => Expr::IsTrue(transform_boxed(expr, &mut transform)?), - Expr::IsFalse(expr) => Expr::IsFalse(transform_boxed(expr, &mut transform)?), Expr::IsUnknown(expr) => { - Expr::IsUnknown(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.map_data(|be| Expr::IsUnknown(be)) } Expr::IsNotTrue(expr) => { - Expr::IsNotTrue(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotTrue(be)) } Expr::IsNotFalse(expr) => { - Expr::IsNotFalse(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotFalse(be)) } Expr::IsNotUnknown(expr) => { - Expr::IsNotUnknown(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotUnknown(be)) } Expr::Negative(expr) => { - Expr::Negative(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.map_data(|be| Expr::Negative(be)) } Expr::Between(Between { expr, negated, low, high, - }) => Expr::Between(Between::new( - transform_boxed(expr, &mut transform)?, - negated, - transform_boxed(low, &mut transform)?, - transform_boxed(high, &mut transform)?, - )), - Expr::Case(case) => { - let expr = transform_option_box(case.expr, &mut transform)?; - let when_then_expr = case - .when_then_expr - .into_iter() - .map(|(when, then)| { - Ok(( - transform_boxed(when, &mut transform)?, - transform_boxed(then, &mut transform)?, + }) => transform_box(expr, &mut f)? + .map_data(|new_expr| (new_expr, low, high)) + .and_then_transform_sibling(|(new_expr, low, high)| { + Ok(transform_box(low, &mut f)? + .map_data(|new_low| (new_expr, new_low, high))) + })? + .and_then_transform_sibling(|(new_expr, new_low, high)| { + Ok(transform_box(high, &mut f)? + .map_data(|new_high| (new_expr, new_low, new_high))) + })? + .map_data(|(new_expr, new_low, new_high)| { + Expr::Between(Between::new(new_expr, negated, new_low, new_high)) + }), + Expr::Case(Case { + expr, + when_then_expr, + else_expr, + }) => transform_option_box(expr, &mut f)? + .map_data(|new_expr| (new_expr, when_then_expr, else_expr)) + .and_then_transform_sibling(|(new_expr, when_then_expr, else_expr)| { + Ok(when_then_expr + .into_iter() + .map_till_continue_and_collect(|(when, then)| { + transform_box(when, &mut f)? + .map_data(|new_when| (new_when, then)) + .and_then_transform_sibling(|(new_when, then)| { + Ok(transform_box(then, &mut f)? + .map_data(|new_then| (new_when, new_then))) + }) + })? + .map_data(|new_when_then_expr| { + (new_expr, new_when_then_expr, else_expr) + })) + })? + .and_then_transform_sibling( + |(new_expr, new_when_then_expr, else_expr)| { + Ok(transform_option_box(else_expr, &mut f)?.map_data( + |new_else_expr| (new_expr, new_when_then_expr, new_else_expr), )) - }) - .collect::>>()?; - let else_expr = transform_option_box(case.else_expr, &mut transform)?; - - Expr::Case(Case::new(expr, when_then_expr, else_expr)) - } - Expr::Cast(Cast { expr, data_type }) => { - Expr::Cast(Cast::new(transform_boxed(expr, &mut transform)?, data_type)) - } - Expr::TryCast(TryCast { expr, data_type }) => Expr::TryCast(TryCast::new( - transform_boxed(expr, &mut transform)?, - data_type, - )), + }, + )? + .map_data(|(new_expr, new_when_then_expr, new_else_expr)| { + Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) + }), + Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)? + .map_data(|be| Expr::Cast(Cast::new(be, data_type))), + Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? + .map_data(|be| Expr::TryCast(TryCast::new(be, data_type))), Expr::Sort(Sort { expr, asc, nulls_first, - }) => Expr::Sort(Sort::new( - transform_boxed(expr, &mut transform)?, - asc, - nulls_first, - )), - Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction( - ScalarFunction::new(fun, transform_vec(args, &mut transform)?), - ), - ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction( - ScalarFunction::new_udf(fun, transform_vec(args, &mut transform)?), - ), - ScalarFunctionDefinition::Name(_) => { - return internal_err!("Function `Expr` with name should be resolved.") - } - }, + }) => transform_box(expr, &mut f)? + .map_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))), + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + transform_vec(args, &mut f)?.flat_map_data(|new_args| match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + } + ScalarFunctionDefinition::UDF(fun) => { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_args))) + } + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + })? + } Expr::WindowFunction(WindowFunction { args, fun, partition_by, order_by, window_frame, - }) => Expr::WindowFunction(WindowFunction::new( - fun, - transform_vec(args, &mut transform)?, - transform_vec(partition_by, &mut transform)?, - transform_vec(order_by, &mut transform)?, - window_frame, - )), + }) => transform_vec(args, &mut f)? + .map_data(|new_args| (new_args, partition_by, order_by)) + .and_then_transform_sibling(|(new_args, partition_by, order_by)| { + Ok(transform_vec(partition_by, &mut f)?.map_data( + |new_partition_by| (new_args, new_partition_by, order_by), + )) + })? + .and_then_transform_sibling(|(new_args, new_partition_by, order_by)| { + Ok(transform_vec(order_by, &mut f)?.map_data(|new_order_by| { + (new_args, new_partition_by, new_order_by) + })) + })? + .map_data(|(new_args, new_partition_by, new_order_by)| { + Expr::WindowFunction(WindowFunction::new( + fun, + new_args, + new_partition_by, + new_order_by, + window_frame, + )) + }), Expr::AggregateFunction(AggregateFunction { args, func_def, distinct, filter, order_by, - }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - Expr::AggregateFunction(AggregateFunction::new( - fun, - transform_vec(args, &mut transform)?, - distinct, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } - AggregateFunctionDefinition::UDF(fun) => { - let order_by = order_by - .map(|order_by| transform_vec(order_by, &mut transform)) - .transpose()?; - Expr::AggregateFunction(AggregateFunction::new_udf( - fun, - transform_vec(args, &mut transform)?, - false, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } - AggregateFunctionDefinition::Name(_) => { - return internal_err!("Function `Expr` with name should be resolved.") - } - }, + }) => transform_vec(args, &mut f)? + .map_data(|new_args| (new_args, filter, order_by)) + .and_then_transform_sibling(|(new_args, filter, order_by)| { + Ok(transform_option_box(filter, &mut f)? + .map_data(|new_filter| (new_args, new_filter, order_by))) + })? + .and_then_transform_sibling(|(new_args, new_filter, order_by)| { + Ok(transform_option_vec(order_by, &mut f)? + .map_data(|new_order_by| (new_args, new_filter, new_order_by))) + })? + .flat_map_data(|(new_args, new_filter, new_order_by)| match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun, + new_args, + distinct, + new_filter, + new_order_by, + ))) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + fun, + new_args, + false, + new_filter, + new_order_by, + ))) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + })?, Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( - transform_vec(exprs, &mut transform)?, - )), - GroupingSet::Cube(exprs) => Expr::GroupingSet(GroupingSet::Cube( - transform_vec(exprs, &mut transform)?, - )), - GroupingSet::GroupingSets(lists_of_exprs) => { - Expr::GroupingSet(GroupingSet::GroupingSets( - lists_of_exprs - .into_iter() - .map(|exprs| transform_vec(exprs, &mut transform)) - .collect::>>()?, - )) - } + GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? + .map_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), + GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)? + .map_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), + GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs + .into_iter() + .map_till_continue_and_collect(|exprs| transform_vec(exprs, &mut f))? + .map_data(|new_lists_of_exprs| { + Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs)) + }), }, Expr::InList(InList { expr, list, negated, - }) => Expr::InList(InList::new( - transform_boxed(expr, &mut transform)?, - transform_vec(list, &mut transform)?, - negated, - )), + }) => transform_box(expr, &mut f)? + .map_data(|new_expr| (new_expr, list)) + .and_then_transform_sibling(|(new_expr, list)| { + Ok(transform_vec(list, &mut f)? + .map_data(|new_list| (new_expr, new_list))) + })? + .map_data(|(new_expr, new_list)| { + Expr::InList(InList::new(new_expr, new_list, negated)) + }), Expr::GetIndexedField(GetIndexedField { expr, field }) => { - Expr::GetIndexedField(GetIndexedField::new( - transform_boxed(expr, &mut transform)?, - field, - )) + transform_box(expr, &mut f)? + .map_data(|be| Expr::GetIndexedField(GetIndexedField::new(be, field))) } }) } } -fn transform_boxed Result>( - boxed_expr: Box, - transform: &mut F, -) -> Result> { - // TODO: It might be possible to avoid an allocation (the Box::new) below by reusing the box. - transform(*boxed_expr).map(Box::new) +fn transform_box(be: Box, f: &mut F) -> Result>> +where + F: FnMut(Expr) -> Result>, +{ + Ok(f(*be)?.map_data(Box::new)) } -fn transform_option_box Result>( - option_box: Option>, - transform: &mut F, -) -> Result>> { - option_box - .map(|expr| transform_boxed(expr, transform)) - .transpose() +fn transform_option_box( + obe: Option>, + f: &mut F, +) -> Result>>> +where + F: FnMut(Expr) -> Result>, +{ + obe.map_or(Ok(Transformed::no(None)), |be| { + Ok(transform_box(be, f)?.map_data(Some)) + }) } /// &mut transform a Option<`Vec` of `Expr`s> -fn transform_option_vec Result>( - option_box: Option>, - transform: &mut F, -) -> Result>> { - option_box - .map(|exprs| transform_vec(exprs, transform)) - .transpose() +fn transform_option_vec( + ove: Option>, + f: &mut F, +) -> Result>>> +where + F: FnMut(Expr) -> Result>, +{ + ove.map_or(Ok(Transformed::no(None)), |ve| { + Ok(transform_vec(ve, f)?.map_data(Some)) + }) } /// &mut transform a `Vec` of `Expr`s -fn transform_vec Result>( - v: Vec, - transform: &mut F, -) -> Result> { - v.into_iter().map(transform).collect() +fn transform_vec(ve: Vec, f: &mut F) -> Result>> +where + F: FnMut(Expr) -> Result>, +{ + ve.into_iter().map_till_continue_and_collect(f) } diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 8be24638c1cc8..64e678344ea4e 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -19,7 +19,9 @@ use crate::LogicalPlan; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::tree_node::{ + Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeVisitor, +}; use datafusion_common::{handle_tree_recursion, Result}; impl TreeNode for LogicalPlan { @@ -76,26 +78,28 @@ impl TreeNode for LogicalPlan { Ok(TreeNodeRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(self, f: F) -> Result> where - F: FnMut(Self) -> Result, + F: FnMut(Self) -> Result>, { let old_children = self.inputs(); - let new_children = old_children + let t = old_children .iter() - .map(|&c| c.clone()) - .map(transform) - .collect::>>()?; - - // if any changes made, make a new child + .map(|c| (*c).clone()) + .map_till_continue_and_collect(f)?; + // TODO: once we trust `t.transformed` remove additional check if old_children .iter() - .zip(new_children.iter()) + .zip(t.data.iter()) .any(|(c1, c2)| c1 != &c2) { - self.with_new_exprs(self.expressions(), new_children.as_slice()) + Ok(Transformed::new( + self.with_new_exprs(self.expressions(), t.data.as_slice())?, + true, + t.tnr, + )) } else { - Ok(self) + Ok(Transformed::new(self, false, t.tnr)) } } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 90046ca2aac0e..4b6c355cb7e9a 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -43,7 +43,7 @@ impl CountWildcardRule { impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down(&analyze_internal) + plan.transform_down(&analyze_internal).map(|t| t.data) } fn name(&self) -> &str { @@ -61,7 +61,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes( + Ok(Transformed::yes( LogicalPlanBuilder::from((*window.input).clone()) .window(window_expr)? .build()?, @@ -74,7 +74,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Aggregate( + Ok(Transformed::yes(LogicalPlan::Aggregate( Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?, ))) } @@ -83,7 +83,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .iter() .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Sort(Sort { + Ok(Transformed::yes(LogicalPlan::Sort(Sort { expr: sort_expr, input, fetch, @@ -95,7 +95,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .iter() .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Projection( + Ok(Transformed::yes(LogicalPlan::Projection( Projection::try_new(projection_expr, projection.input)?, ))) } @@ -103,12 +103,12 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { predicate, input, .. }) => { let predicate = rewrite_preserving_name(predicate, &mut rewriter)?; - Ok(Transformed::Yes(LogicalPlan::Filter(Filter::try_new( + Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( predicate, input, )?))) } - _ => Ok(Transformed::No(plan)), + _ => Ok(Transformed::no(plan)), } } @@ -117,8 +117,8 @@ struct CountWildcardRewriter {} impl TreeNodeRewriter for CountWildcardRewriter { type Node = Expr; - fn f_up(&mut self, old_expr: Expr) -> Result { - let new_expr = match old_expr.clone() { + fn f_up(&mut self, old_expr: Expr) -> Result> { + Ok(match old_expr.clone() { Expr::WindowFunction(expr::WindowFunction { fun: expr::WindowFunctionDefinition::AggregateFunction( @@ -130,7 +130,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { window_frame, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { - Expr::WindowFunction(expr::WindowFunction { + Transformed::yes(Expr::WindowFunction(expr::WindowFunction { fun: expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), @@ -138,10 +138,10 @@ impl TreeNodeRewriter for CountWildcardRewriter { partition_by, order_by, window_frame, - }) + })) } - _ => old_expr, + _ => Transformed::no(old_expr), }, Expr::AggregateFunction(AggregateFunction { func_def: @@ -154,68 +154,65 @@ impl TreeNodeRewriter for CountWildcardRewriter { order_by, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { - Expr::AggregateFunction(AggregateFunction::new( + Transformed::yes(Expr::AggregateFunction(AggregateFunction::new( aggregate_function::AggregateFunction::Count, vec![lit(COUNT_STAR_EXPANSION)], distinct, filter, order_by, - )) + ))) } - _ => old_expr, + _ => Transformed::no(old_expr), }, ScalarSubquery(Subquery { subquery, outer_ref_columns, - }) => { - let new_plan = subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - }) - } + }) => subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? + .map_data(|new_plan| { + ScalarSubquery(Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns, + }) + }), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => { - let new_plan = subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - - Expr::InSubquery(InSubquery::new( - expr, - Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - )) - } - Expr::Exists(expr::Exists { subquery, negated }) => { - let new_plan = subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - - Expr::Exists(expr::Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - }) - } - _ => old_expr, - }; - Ok(new_expr) + }) => subquery + .subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? + .map_data(|new_plan| { + Expr::InSubquery(InSubquery::new( + expr, + Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, + negated, + )) + }), + Expr::Exists(expr::Exists { subquery, negated }) => subquery + .subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? + .map_data(|new_plan| { + Expr::Exists(expr::Exists { + subquery: Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, + negated, + }) + }), + _ => Transformed::no(old_expr), + }) } } diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index a418fbf5537be..36f0c33183710 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -42,7 +42,7 @@ impl InlineTableScan { impl AnalyzerRule for InlineTableScan { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_up(&analyze_internal) + plan.transform_up(&analyze_internal).map(|t| t.data) } fn name(&self) -> &str { @@ -71,16 +71,16 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { // that reference this table. .alias(table_name)? .build()?; - Transformed::Yes(plan) + Transformed::yes(plan) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform_up(&rewrite_subquery)?; - Transformed::Yes(LogicalPlan::Filter(Filter::try_new( + let new_expr = filter.predicate.transform_up(&rewrite_subquery)?.data; + Transformed::yes(LogicalPlan::Filter(Filter::try_new( new_expr, filter.input, )?)) } - _ => Transformed::No(plan), + _ => Transformed::no(plan), }) } @@ -88,9 +88,9 @@ fn rewrite_subquery(expr: Expr) -> Result> { match expr { Expr::Exists(Exists { subquery, negated }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal)?.data; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::Exists(Exists { subquery, negated }))) + Ok(Transformed::yes(Expr::Exists(Exists { subquery, negated }))) } Expr::InSubquery(InSubquery { expr, @@ -98,19 +98,19 @@ fn rewrite_subquery(expr: Expr) -> Result> { negated, }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal)?.data; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::InSubquery(InSubquery::new( + Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( expr, subquery, negated, )))) } Expr::ScalarSubquery(subquery) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal)?.data; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::ScalarSubquery(subquery))) + Ok(Transformed::yes(Expr::ScalarSubquery(subquery))) } - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } } diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index 829197b4d9481..8f5f1f4292ca3 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::utils::list_ndims; use datafusion_common::DFSchema; use datafusion_common::DFSchemaRef; @@ -96,8 +96,8 @@ pub(crate) struct OperatorToFunctionRewriter { impl TreeNodeRewriter for OperatorToFunctionRewriter { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { - match expr { + fn f_up(&mut self, expr: Expr) -> Result> { + Ok(match expr { Expr::BinaryExpr(BinaryExpr { ref left, op, @@ -119,16 +119,16 @@ impl TreeNodeRewriter for OperatorToFunctionRewriter { // Convert &Box -> Expr let left = (**left).clone(); let right = (**right).clone(); - return Ok(Expr::ScalarFunction(ScalarFunction { + return Ok(Transformed::yes(Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args: vec![left, right], - })); + }))); } - Ok(expr) + Transformed::no(expr) } - _ => Ok(expr), - } + _ => Transformed::no(expr), + }) } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 14e15f71b18b1..7e5b8de9beae0 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -128,27 +128,27 @@ pub(crate) struct TypeCoercionRewriter { impl TreeNodeRewriter for TypeCoercionRewriter { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match expr { Expr::ScalarSubquery(Subquery { subquery, outer_ref_columns, }) => { let new_plan = analyze_internal(&self.schema, &subquery)?; - Ok(Expr::ScalarSubquery(Subquery { + Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, - })) + }))) } Expr::Exists(Exists { subquery, negated }) => { let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - Ok(Expr::Exists(Exists { + Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }, negated, - })) + }))) } Expr::InSubquery(InSubquery { expr, @@ -166,42 +166,34 @@ impl TreeNodeRewriter for TypeCoercionRewriter { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }; - Ok(Expr::InSubquery(InSubquery::new( + Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( Box::new(expr.cast_to(&common_type, &self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, - ))) - } - Expr::Not(expr) => { - let expr = not(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsTrue(expr) => { - let expr = is_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotTrue(expr) => { - let expr = is_not_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsFalse(expr) => { - let expr = is_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotFalse(expr) => { - let expr = - is_not_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsUnknown(expr) => { - let expr = is_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotUnknown(expr) => { - let expr = - is_not_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) + )))) } + Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( + &expr, + &self.schema, + )?))), + Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), Expr::Like(Like { negated, expr, @@ -223,14 +215,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter { })?; let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); - let expr = Expr::Like(Like::new( + Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, pattern, escape_char, case_insensitive, - )); - Ok(expr) + )))) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left_type, right_type) = get_input_types( @@ -238,12 +229,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &op, &right.get_type(&self.schema)?, )?; - - Ok(Expr::BinaryExpr(BinaryExpr::new( + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( Box::new(left.cast_to(&left_type, &self.schema)?), op, Box::new(right.cast_to(&right_type, &self.schema)?), - ))) + )))) } Expr::Between(Between { expr, @@ -273,13 +263,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) })?; - let expr = Expr::Between(Between::new( + Ok(Transformed::yes(Expr::Between(Between::new( Box::new(expr.cast_to(&coercion_type, &self.schema)?), negated, Box::new(low.cast_to(&coercion_type, &self.schema)?), Box::new(high.cast_to(&coercion_type, &self.schema)?), - )); - Ok(expr) + )))) } Expr::InList(InList { expr, @@ -306,18 +295,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter { list_expr.cast_to(&coerced_type, &self.schema) }) .collect::>>()?; - let expr = Expr::InList(InList ::new( + Ok(Transformed::yes(Expr::InList(InList ::new( Box::new(cast_expr), cast_list_expr, negated, - )); - Ok(expr) + )))) } } } Expr::Case(case) => { let case = coerce_case_expression(case, &self.schema)?; - Ok(Expr::Case(case)) + Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { @@ -331,7 +319,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun, )?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + Ok(Transformed::yes(Expr::ScalarFunction(ScalarFunction::new( + fun, new_args, + )))) } ScalarFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -339,7 +329,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + Ok(Transformed::yes(Expr::ScalarFunction( + ScalarFunction::new_udf(fun, new_expr), + ))) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -359,10 +351,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) + Ok(Transformed::yes(Expr::AggregateFunction( + expr::AggregateFunction::new( + fun, new_expr, distinct, filter, order_by, + ), + ))) } AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -370,10 +363,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fun, new_expr, false, filter, order_by, - )); - Ok(expr) + Ok(Transformed::yes(Expr::AggregateFunction( + expr::AggregateFunction::new_udf( + fun, new_expr, false, filter, order_by, + ), + ))) } AggregateFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -401,14 +395,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter { _ => args, }; - let expr = Expr::WindowFunction(WindowFunction::new( + Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( fun, args, partition_by, order_by, window_frame, - )); - Ok(expr) + )))) } Expr::Alias(_) | Expr::Column(_) @@ -425,7 +418,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::OuterReferenceColumn(_, _) => Ok(expr), + | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), } } } @@ -1283,7 +1276,7 @@ mod test { std::collections::HashMap::new(), )?); let mut rewriter = TypeCoercionRewriter { schema }; - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter)?.data; let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( @@ -1318,7 +1311,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter)?.data; assert_eq!(expected, result); // eq @@ -1329,7 +1322,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter)?.data; assert_eq!(expected, result); // lt @@ -1340,7 +1333,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter)?.data; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index f3b8d4b4842a8..fafc6340f1a19 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -25,7 +25,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{ - TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, @@ -745,24 +745,24 @@ struct CommonSubexprRewriter<'a> { impl TreeNodeRewriter for CommonSubexprRewriter<'_> { type Node = Expr; - fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { + fn f_down(&mut self, expr: Expr) -> Result> { // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate // the `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. if expr.short_circuits() || is_volatile_expression(&expr)? { - return Ok((expr, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Skip)); } if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { - return Ok((expr, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Skip)); } let curr_id = &self.id_array[self.curr_index].1; // skip `Expr`s without identifier (empty identifier). if curr_id.is_empty() { self.curr_index += 1; - return Ok((expr, TreeNodeRecursion::Continue)); + return Ok(Transformed::no(expr)); } match self.expr_set.get(curr_id) { Some((_, counter, _)) => { @@ -771,7 +771,11 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // This expr tree is finished. if self.curr_index >= self.id_array.len() { - return Ok((expr, TreeNodeRecursion::Skip)); + return Ok(Transformed::new( + expr, + false, + TreeNodeRecursion::Skip, + )); } let (series_number, id) = &self.id_array[self.curr_index]; @@ -784,7 +788,11 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { || id.is_empty() || expr_set_item.1 <= 1 { - return Ok((expr, TreeNodeRecursion::Skip)); + return Ok(Transformed::new( + expr, + false, + TreeNodeRecursion::Skip, + )); } self.max_series_number = *series_number; @@ -799,10 +807,14 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // Alias this `Column` expr to it original "expr name", // `projection_push_down` optimizer use "expr name" to eliminate useless // projections. - Ok((col(id).alias(expr_name), TreeNodeRecursion::Skip)) + Ok(Transformed::new( + col(id).alias(expr_name), + true, + TreeNodeRecursion::Skip, + )) } else { self.curr_index += 1; - Ok((expr, TreeNodeRecursion::Continue)) + Ok(Transformed::no(expr)) } } _ => internal_err!("expr_set invalid state"), @@ -823,6 +835,7 @@ fn replace_common_expr( max_series_number: 0, curr_index: 0, }) + .map(|t| t.data) } #[cfg(test)] diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 49d3c322ca2b0..b7119966c41ce 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -58,17 +58,17 @@ pub type ExprResultMap = HashMap; impl TreeNodeRewriter for PullUpCorrelatedExpr { type Node = LogicalPlan; - fn f_down(&mut self, plan: LogicalPlan) -> Result<(LogicalPlan, TreeNodeRecursion)> { + fn f_down(&mut self, plan: LogicalPlan) -> Result> { match plan { - LogicalPlan::Filter(_) => Ok((plan, TreeNodeRecursion::Continue)), + LogicalPlan::Filter(_) => Ok(Transformed::no(plan)), LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); if plan_hold_outer { // the unsupported case self.can_pull_up = false; - Ok((plan, TreeNodeRecursion::Skip)) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Skip)) } else { - Ok((plan, TreeNodeRecursion::Continue)) + Ok(Transformed::no(plan)) } } LogicalPlan::Limit(_) => { @@ -77,21 +77,21 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { (false, true) => { // the unsupported case self.can_pull_up = false; - Ok((plan, TreeNodeRecursion::Skip)) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Skip)) } - _ => Ok((plan, TreeNodeRecursion::Continue)), + _ => Ok(Transformed::no(plan)), } } _ if plan.expressions().iter().any(|expr| expr.contains_outer()) => { // the unsupported cases, the plan expressions contain out reference columns(like window expressions) self.can_pull_up = false; - Ok((plan, TreeNodeRecursion::Skip)) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Skip)) } - _ => Ok((plan, TreeNodeRecursion::Continue)), + _ => Ok(Transformed::no(plan)), } } - fn f_up(&mut self, plan: LogicalPlan) -> Result { + fn f_up(&mut self, plan: LogicalPlan) -> Result> { let subquery_schema = plan.schema().clone(); match &plan { LogicalPlan::Filter(plan_filter) => { @@ -140,7 +140,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { .build()?; self.correlated_subquery_cols_map .insert(new_plan.clone(), correlated_subquery_cols); - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } (None, _) => { // if the subquery still has filter expressions, restore them. @@ -152,7 +152,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = plan.build()?; self.correlated_subquery_cols_map .insert(new_plan.clone(), correlated_subquery_cols); - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } } } @@ -196,7 +196,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(new_plan.clone(), expr_result_map_for_count_bug); } - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } LogicalPlan::Aggregate(aggregate) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => @@ -240,7 +240,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(new_plan.clone(), expr_result_map_for_count_bug); } - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } LogicalPlan::SubqueryAlias(alias) => { let mut local_correlated_cols = BTreeSet::new(); @@ -262,7 +262,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(plan.clone(), input_map.clone()); } - Ok(plan) + Ok(Transformed::no(plan)) } LogicalPlan::Limit(limit) => { let input_expr_map = self @@ -273,7 +273,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) { // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) - (true, false) => { + (true, false) => Transformed::yes( if limit.fetch.filter(|limit_row| *limit_row == 0).is_some() { LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -281,17 +281,17 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { }) } else { LogicalPlanBuilder::from((*limit.input).clone()).build()? - } - } - _ => plan, + }, + ), + _ => Transformed::no(plan), }; if let Some(input_map) = input_expr_map { self.collected_count_expr_map - .insert(new_plan.clone(), input_map); + .insert(new_plan.data.clone(), input_map); } Ok(new_plan) } - _ => Ok(plan), + _ => Ok(Transformed::no(plan)), } } } @@ -370,31 +370,34 @@ fn agg_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for e in agg_expr.iter() { - let result_expr = e.clone().transform_up(&|expr| { - let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { - match func_def { + let result_expr = e + .clone() + .transform_up(&|expr| { + let new_expr = match expr { + Expr::AggregateFunction(expr::AggregateFunction { + func_def, .. + }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( + Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( 0, )))) } else { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } } AggregateFunctionDefinition::UDF { .. } => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } AggregateFunctionDefinition::Name(_) => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } - } - } - _ => Transformed::No(expr), - }; - Ok(new_expr) - })?; + }, + _ => Transformed::no(expr), + }; + Ok(new_expr) + })? + .data; let result_expr = result_expr.unalias(); let props = ExecutionProps::new(); @@ -415,17 +418,22 @@ fn proj_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for expr in proj_expr.iter() { - let result_expr = expr.clone().transform_up(&|expr| { - if let Expr::Column(Column { name, .. }) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { - Ok(Transformed::Yes(result_expr.clone())) + let result_expr = expr + .clone() + .transform_up(&|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = + input_expr_result_map_for_count_bug.get(name) + { + Ok(Transformed::yes(result_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::No(expr)) - } - })?; + })? + .data; if result_expr.ne(expr) { let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema.clone()); @@ -448,17 +456,20 @@ fn filter_exprs_evaluation_result_on_empty_batch( input_expr_result_map_for_count_bug: &ExprResultMap, expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result> { - let result_expr = filter_expr.clone().transform_up(&|expr| { - if let Expr::Column(Column { name, .. }) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { - Ok(Transformed::Yes(result_expr.clone())) + let result_expr = filter_expr + .clone() + .transform_up(&|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { + Ok(Transformed::yes(result_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::No(expr)) - } - })?; + })? + .data; let pull_up_expr = if result_expr.ne(filter_expr) { let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema); diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 450336376a239..4e94bcc2b085d 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -228,7 +228,7 @@ fn build_join( collected_count_expr_map: Default::default(), pull_up_having_expr: None, }; - let new_plan = subquery.clone().rewrite(&mut pull_up)?; + let new_plan = subquery.clone().rewrite(&mut pull_up)?.data; if !pull_up.can_pull_up { return Ok(None); } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 0ae0bc696a352..bb1ff413088f0 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1022,13 +1022,14 @@ pub fn replace_cols_by_name( e.transform_up(&|expr| { Ok(if let Expr::Column(c) = &expr { match replace_map.get(&c.flat_name()) { - Some(new_c) => Transformed::Yes(new_c.clone()), - None => Transformed::No(expr), + Some(new_c) => Transformed::yes(new_c.clone()), + None => Transformed::no(expr), } } else { - Transformed::No(expr) + Transformed::no(expr) }) }) + .map(|t| t.data) } /// check whether the expression uses the columns in `check_map`. diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index e1c35e468f68a..0ac053dacd29b 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -56,7 +56,7 @@ impl ScalarSubqueryToJoin { sub_query_info: vec![], alias_gen, }; - let new_expr = predicate.clone().rewrite(&mut extract)?; + let new_expr = predicate.clone().rewrite(&mut extract)?.data; Ok((extract.sub_query_info, new_expr)) } } @@ -86,20 +86,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { build_join(&subquery, &cur_input, &alias)? { if !expr_check_map.is_empty() { - rewrite_expr = - rewrite_expr.clone().transform_up(&|expr| { + rewrite_expr = rewrite_expr + .clone() + .transform_up(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) { - Ok(Transformed::Yes(map_expr.clone())) + Ok(Transformed::yes(map_expr.clone())) } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - })?; + })? + .data; } cur_input = optimized_subquery; } else { @@ -141,20 +143,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { if let Some(rewrite_expr) = expr_to_rewrite_expr_map.get(expr) { - let new_expr = - rewrite_expr.clone().transform_up(&|expr| { + let new_expr = rewrite_expr + .clone() + .transform_up(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) { - Ok(Transformed::Yes(map_expr.clone())) + Ok(Transformed::yes(map_expr.clone())) } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - })?; + })? + .data; expr_to_rewrite_expr_map.insert(expr, new_expr); } } @@ -203,7 +207,7 @@ struct ExtractScalarSubQuery { impl TreeNodeRewriter for ExtractScalarSubQuery { type Node = Expr; - fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { + fn f_down(&mut self, expr: Expr) -> Result> { match expr { Expr::ScalarSubquery(subquery) => { let subqry_alias = self.alias_gen.next("__scalar_sq"); @@ -213,15 +217,16 @@ impl TreeNodeRewriter for ExtractScalarSubQuery { .subquery .head_output_expr()? .map_or(plan_err!("single expression required."), Ok)?; - Ok(( + Ok(Transformed::new( Expr::Column(create_col_from_scalar_expr( &scalar_expr, subqry_alias, )?), + true, TreeNodeRecursion::Skip, )) } - _ => Ok((expr, TreeNodeRecursion::Continue)), + _ => Ok(Transformed::no(expr)), } } } @@ -278,7 +283,7 @@ fn build_join( collected_count_expr_map: Default::default(), pull_up_having_expr: None, }; - let new_plan = subquery_plan.clone().rewrite(&mut pull_up)?; + let new_plan = subquery_plan.clone().rewrite(&mut pull_up)?.data; if !pull_up.can_pull_up { return Ok(None); } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index fd77071ea7286..9f0e4b82e3a59 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -33,7 +33,7 @@ use arrow::{ datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::tree_node::Transformed; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, tree_node::{TreeNode, TreeNodeRewriter}, @@ -143,18 +143,25 @@ impl ExprSimplifier { // simplifications can enable new constant evaluation) // https://github.com/apache/arrow-datafusion/issues/1160 expr.rewrite(&mut const_evaluator)? + .data .rewrite(&mut simplifier)? + .data .rewrite(&mut or_in_list_simplifier)? + .data .rewrite(&mut inlist_simplifier)? + .data .rewrite(&mut guarantee_rewriter)? + .data // run both passes twice to try an minimize simplifications that we missed .rewrite(&mut const_evaluator)? + .data .rewrite(&mut simplifier) + .map(|t| t.data) } pub fn canonicalize(&self, expr: Expr) -> Result { let mut canonicalizer = Canonicalizer::new(); - expr.rewrite(&mut canonicalizer) + expr.rewrite(&mut canonicalizer).map(|t| t.data) } /// Apply type coercion to an [`Expr`] so that it can be /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr). @@ -169,7 +176,7 @@ impl ExprSimplifier { pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite) + expr.rewrite(&mut expr_rewrite).map(|t| t.data) } /// Input guarantees about the values of columns. @@ -249,30 +256,34 @@ impl Canonicalizer { impl TreeNodeRewriter for Canonicalizer { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else { - return Ok(expr); + return Ok(Transformed::no(expr)); }; match (left.as_ref(), right.as_ref(), op.swap()) { // (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op)) if right_col > left_col => { - Ok(Expr::BinaryExpr(BinaryExpr { + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, right: left, - })) + }))) } // (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => { - Ok(Expr::BinaryExpr(BinaryExpr { + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, right: left, - })) + }))) } - _ => Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })), + _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))), } } } @@ -313,7 +324,7 @@ enum ConstSimplifyResult { impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { type Node = Expr; - fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { + fn f_down(&mut self, expr: Expr) -> Result> { // Default to being able to evaluate this node self.can_evaluate.push(true); @@ -337,10 +348,10 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { // NB: do not short circuit recursion even if we find a non // evaluatable node (so we can fold other children, args to // functions, etc) - Ok((expr, TreeNodeRecursion::Continue)) + Ok(Transformed::no(expr)) } - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match self.can_evaluate.pop() { // Certain expressions such as `CASE` and `COALESCE` are short circuiting // and may not evalute all their sub expressions. Thus if @@ -349,11 +360,15 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { Some(true) => { let result = self.evaluate_to_scalar(expr); match result { - ConstSimplifyResult::Simplified(s) => Ok(Expr::Literal(s)), - ConstSimplifyResult::SimplifyRuntimeError(_, expr) => Ok(expr), + ConstSimplifyResult::Simplified(s) => { + Ok(Transformed::yes(Expr::Literal(s))) + } + ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { + Ok(Transformed::yes(expr)) + } } } - Some(false) => Ok(expr), + Some(false) => Ok(Transformed::no(expr)), _ => internal_err!("Failed to pop can_evaluate"), } } @@ -508,7 +523,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { type Node = Expr; /// rewrite the expression simplifying any constant expressions - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { use datafusion_expr::Operator::{ And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor, Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch, @@ -516,7 +531,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { }; let info = self.info; - let new_expr = match expr { + Ok(match expr { // // Rules for Eq // @@ -529,11 +544,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Eq, right, }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { - match as_bool_lit(*left)? { + Transformed::yes(match as_bool_lit(*left)? { Some(true) => *right, Some(false) => Expr::Not(right), None => lit_bool_null(), - } + }) } // A = true --> A // A = false --> !A @@ -543,11 +558,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Eq, right, }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { - match as_bool_lit(*right)? { + Transformed::yes(match as_bool_lit(*right)? { Some(true) => *left, Some(false) => Expr::Not(left), None => lit_bool_null(), - } + }) } // expr IN () --> false // expr NOT IN () --> true @@ -556,7 +571,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { list, negated, }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { - lit(negated) + Transformed::yes(lit(negated)) } // null in (x, y, z) --> null @@ -565,7 +580,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { expr, list: _, negated: _, - }) if is_null(&expr) => lit_bool_null(), + }) if is_null(&expr) => Transformed::yes(lit_bool_null()), // expr IN ((subquery)) -> expr IN (subquery), see ##5529 Expr::InList(InList { @@ -578,7 +593,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { let Expr::ScalarSubquery(subquery) = list.remove(0) else { unreachable!() }; - Expr::InSubquery(InSubquery::new(expr, subquery, negated)) + Transformed::yes(Expr::InSubquery(InSubquery::new( + expr, subquery, negated, + ))) } // if expr is a single column reference: @@ -599,7 +616,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ) => { let first_val = list[0].clone(); - if negated { + Transformed::yes(if negated { list.into_iter().skip(1).fold( (*expr.clone()).not_eq(first_val), |acc, y| { @@ -631,7 +648,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { acc.or((*expr.clone()).eq(y)) }, ) - } + }) } // // Rules for NotEq @@ -645,11 +662,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: NotEq, right, }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { - match as_bool_lit(*left)? { + Transformed::yes(match as_bool_lit(*left)? { Some(true) => Expr::Not(right), Some(false) => *right, None => lit_bool_null(), - } + }) } // A != true --> !A // A != false --> A @@ -659,11 +676,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: NotEq, right, }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { - match as_bool_lit(*right)? { + Transformed::yes(match as_bool_lit(*right)? { Some(true) => Expr::Not(left), Some(false) => *left, None => lit_bool_null(), - } + }) } // @@ -675,32 +692,32 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Or, right: _, - }) if is_true(&left) => *left, + }) if is_true(&left) => Transformed::yes(*left), // false OR A --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if is_false(&left) => *right, + }) if is_false(&left) => Transformed::yes(*right), // A OR true --> true (even if A is null) Expr::BinaryExpr(BinaryExpr { left: _, op: Or, right, - }) if is_true(&right) => *right, + }) if is_true(&right) => Transformed::yes(*right), // A OR false --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if is_false(&right) => *left, + }) if is_false(&right) => Transformed::yes(*left), // A OR !A ---> true (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, }) if is_not_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::Boolean(Some(true))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true)))) } // !A OR A ---> true (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -708,32 +725,36 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Or, right, }) if is_not_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::Boolean(Some(true))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true)))) } // (..A..) OR A --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if expr_contains(&left, &right, Or) => *left, + }) if expr_contains(&left, &right, Or) => Transformed::yes(*left), // A OR (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if expr_contains(&right, &left, Or) => *right, + }) if expr_contains(&right, &left, Or) => Transformed::yes(*right), // A OR (A AND B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => *left, + }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => { + Transformed::yes(*left) + } // (A AND B) OR A --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&left)? && is_op_with(And, &left, &right) => *right, + }) if !info.nullable(&left)? && is_op_with(And, &left, &right) => { + Transformed::yes(*right) + } // // Rules for AND @@ -744,32 +765,32 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: And, right, - }) if is_true(&left) => *right, + }) if is_true(&left) => Transformed::yes(*right), // false AND A --> false (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: And, right: _, - }) if is_false(&left) => *left, + }) if is_false(&left) => Transformed::yes(*left), // A AND true --> A Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if is_true(&right) => *left, + }) if is_true(&right) => Transformed::yes(*left), // A AND false --> false (even if A is null) Expr::BinaryExpr(BinaryExpr { left: _, op: And, right, - }) if is_false(&right) => *right, + }) if is_false(&right) => Transformed::yes(*right), // A AND !A ---> false (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: And, right, }) if is_not_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::Boolean(Some(false))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false)))) } // !A AND A ---> false (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -777,32 +798,36 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: And, right, }) if is_not_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::Boolean(Some(false))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false)))) } // (..A..) AND A --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if expr_contains(&left, &right, And) => *left, + }) if expr_contains(&left, &right, And) => Transformed::yes(*left), // A AND (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if expr_contains(&right, &left, And) => *right, + }) if expr_contains(&right, &left, And) => Transformed::yes(*right), // A AND (A OR B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => *left, + }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => { + Transformed::yes(*left) + } // (A OR B) AND A --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&left)? && is_op_with(Or, &left, &right) => *right, + }) if !info.nullable(&left)? && is_op_with(Or, &left, &right) => { + Transformed::yes(*right) + } // // Rules for Multiply @@ -813,25 +838,25 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Multiply, right, - }) if is_one(&right) => *left, + }) if is_one(&right) => Transformed::yes(*left), // 1 * A --> A Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right, - }) if is_one(&left) => *right, + }) if is_one(&left) => Transformed::yes(*right), // A * null --> null Expr::BinaryExpr(BinaryExpr { left: _, op: Multiply, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null * A --> null Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN) Expr::BinaryExpr(BinaryExpr { @@ -842,7 +867,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_zero(&right) => { - *right + Transformed::yes(*right) } // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN) Expr::BinaryExpr(BinaryExpr { @@ -853,7 +878,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&right)?.is_floating() && is_zero(&left) => { - *left + Transformed::yes(*left) } // @@ -865,19 +890,19 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Divide, right, - }) if is_one(&right) => *left, + }) if is_one(&right) => Transformed::yes(*left), // null / A --> null Expr::BinaryExpr(BinaryExpr { left, op: Divide, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A / null --> null Expr::BinaryExpr(BinaryExpr { left: _, op: Divide, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // // Rules for Modulo @@ -888,13 +913,13 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: Modulo, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null % A --> null Expr::BinaryExpr(BinaryExpr { left, op: Modulo, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN) Expr::BinaryExpr(BinaryExpr { left, @@ -904,7 +929,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { - lit(0) + Transformed::yes(lit(0)) } // @@ -916,28 +941,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseAnd, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null & A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A & 0 -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if !info.nullable(&left)? && is_zero(&right) => *right, + }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right), // 0 & A -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if !info.nullable(&right)? && is_zero(&left) => *left, + }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left), // !A & A -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -945,7 +970,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?, + )?)) } // A & !A -> 0 (if A not nullable) @@ -954,7 +981,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?, + )?)) } // (..A..) & A --> (..A..) @@ -962,14 +991,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, - }) if expr_contains(&left, &right, BitwiseAnd) => *left, + }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left), // A & (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if expr_contains(&right, &left, BitwiseAnd) => *right, + }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right), // A & (A | B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { @@ -977,7 +1006,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => { - *left + Transformed::yes(*left) } // (A | B) & A --> A (if B not null) @@ -986,7 +1015,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => { - *right + Transformed::yes(*right) } // @@ -998,28 +1027,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseOr, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null | A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A | 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // 0 | A -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if is_zero(&left) => *right, + }) if is_zero(&left) => Transformed::yes(*right), // !A | A -> -1 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -1027,7 +1056,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // A | !A -> -1 (if A not nullable) @@ -1036,7 +1067,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // (..A..) | A --> (..A..) @@ -1044,14 +1077,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, - }) if expr_contains(&left, &right, BitwiseOr) => *left, + }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left), // A | (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if expr_contains(&right, &left, BitwiseOr) => *right, + }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right), // A | (A & B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { @@ -1059,7 +1092,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => { - *left + Transformed::yes(*left) } // (A & B) | A --> A (if B not null) @@ -1068,7 +1101,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => { - *right + Transformed::yes(*right) } // @@ -1080,28 +1113,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseXor, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null ^ A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A ^ 0 -> A (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right, - }) if !info.nullable(&left)? && is_zero(&right) => *left, + }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left), // 0 ^ A -> A (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right, - }) if !info.nullable(&right)? && is_zero(&left) => *right, + }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right), // !A ^ A -> -1 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -1109,7 +1142,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // A ^ !A -> -1 (if A not nullable) @@ -1118,7 +1153,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) @@ -1128,11 +1165,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { right, }) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); - if expr == *right { + Transformed::yes(if expr == *right { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) } else { expr - } + }) } // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A) @@ -1142,11 +1179,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { right, }) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); - if expr == *left { + Transformed::yes(if expr == *left { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) } else { expr - } + }) } // @@ -1158,21 +1195,21 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseShiftRight, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null >> A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftRight, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A >> 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftRight, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // // Rules for BitwiseShiftRight @@ -1183,31 +1220,31 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseShiftLeft, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null << A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftLeft, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A << 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftLeft, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // // Rules for Not // - Expr::Not(inner) => negate_clause(*inner), + Expr::Not(inner) => Transformed::yes(negate_clause(*inner)), // // Rules for Negative // - Expr::Negative(inner) => distribute_negation(*inner), + Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)), // // Rules for Case @@ -1261,19 +1298,19 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), args, - }) => simpl_log(args, <&S>::clone(&info))?, + }) => Transformed::yes(simpl_log(args, <&S>::clone(&info))?), // power Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), args, - }) => simpl_power(args, <&S>::clone(&info))?, + }) => Transformed::yes(simpl_power(args, <&S>::clone(&info))?), // concat Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), args, - }) => simpl_concat(args)?, + }) => Transformed::yes(simpl_concat(args)?), // concat_ws Expr::ScalarFunction(ScalarFunction { @@ -1283,11 +1320,13 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ), args, }) => match &args[..] { - [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, - _ => Expr::ScalarFunction(ScalarFunction::new( + [delimiter, vals @ ..] => { + Transformed::yes(simpl_concat_ws(delimiter, vals)?) + } + _ => Transformed::yes(Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::ConcatWithSeparator, args, - )), + ))), }, // @@ -1296,18 +1335,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // a between 3 and 5 --> a >= 3 AND a <=5 // a not between 3 and 5 --> a < 3 OR a > 5 - Expr::Between(between) => { - if between.negated { - let l = *between.expr.clone(); - let r = *between.expr; - or(l.lt(*between.low), r.gt(*between.high)) - } else { - and( - between.expr.clone().gt_eq(*between.low), - between.expr.lt_eq(*between.high), - ) - } - } + Expr::Between(between) => Transformed::yes(if between.negated { + let l = *between.expr.clone(); + let r = *between.expr; + or(l.lt(*between.low), r.gt(*between.high)) + } else { + and( + between.expr.clone().gt_eq(*between.low), + between.expr.lt_eq(*between.high), + ) + }), // // Rules for regexes @@ -1316,7 +1353,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), right, - }) => simplify_regex_expr(left, op, right)?, + }) => Transformed::yes(simplify_regex_expr(left, op, right)?), // Rules for Like Expr::Like(Like { @@ -1331,25 +1368,24 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::Literal(ScalarValue::Utf8(Some(pattern_str))) if pattern_str == "%" ) => { - lit(!negated) + Transformed::yes(lit(!negated)) } // a is not null/unknown --> true (if a is not nullable) Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr) if !info.nullable(&expr)? => { - lit(true) + Transformed::yes(lit(true)) } // a is null/unknown --> false (if a is not nullable) Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => { - lit(false) + Transformed::yes(lit(false)) } // no additional rewrites possible - expr => expr, - }; - Ok(new_expr) + expr => Transformed::no(expr), + }) } } @@ -1473,6 +1509,7 @@ mod tests { let evaluated_expr = input_expr .clone() .rewrite(&mut const_evaluator) + .map(|t| t.data) .expect("successfully evaluated"); assert_eq!( diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index e7c619c046de8..8b243f82c714d 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -21,6 +21,7 @@ use std::{borrow::Cow, collections::HashMap}; +use datafusion_common::tree_node::Transformed; use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; @@ -59,21 +60,23 @@ impl<'a> GuaranteeRewriter<'a> { impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { if self.guarantees.is_empty() { - return Ok(expr); + return Ok(Transformed::no(expr)); } match &expr { Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(lit(true)), - Some(NullableInterval::NotNull { .. }) => Ok(lit(false)), - _ => Ok(expr), + Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))), + Some(NullableInterval::NotNull { .. }) => { + Ok(Transformed::yes(lit(false))) + } + _ => Ok(Transformed::no(expr)), }, Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(lit(false)), - Some(NullableInterval::NotNull { .. }) => Ok(lit(true)), - _ => Ok(expr), + Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(false))), + Some(NullableInterval::NotNull { .. }) => Ok(Transformed::yes(lit(true))), + _ => Ok(Transformed::no(expr)), }, Expr::Between(Between { expr: inner, @@ -93,14 +96,14 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { let contains = expr_interval.contains(*interval)?; if contains.is_certainly_true() { - Ok(lit(!negated)) + Ok(Transformed::yes(lit(!negated))) } else if contains.is_certainly_false() { - Ok(lit(*negated)) + Ok(Transformed::yes(lit(*negated))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } else { - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -135,23 +138,23 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { let result = left_interval.apply_operator(op, right_interval.as_ref())?; if result.is_certainly_true() { - Ok(lit(true)) + Ok(Transformed::yes(lit(true))) } else if result.is_certainly_false() { - Ok(lit(false)) + Ok(Transformed::yes(lit(false))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } // Columns (if interval is collapsed to a single value) Expr::Column(_) => { if let Some(interval) = self.guarantees.get(&expr) { - Ok(interval.single_value().map_or(expr, lit)) + Ok(Transformed::yes(interval.single_value().map_or(expr, lit))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -181,17 +184,17 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { }) .collect::>()?; - Ok(Expr::InList(InList { + Ok(Transformed::yes(Expr::InList(InList { expr: inner.clone(), list: new_list, negated: *negated, - })) + }))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } } @@ -221,12 +224,12 @@ mod tests { // x IS NULL => guaranteed false let expr = col("x").is_null(); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).map(|t| t.data).unwrap(); assert_eq!(output, lit(false)); // x IS NOT NULL => guaranteed true let expr = col("x").is_not_null(); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).map(|t| t.data).unwrap(); assert_eq!(output, lit(true)); } @@ -236,7 +239,7 @@ mod tests { T: Clone, { for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).unwrap(); + let output = expr.clone().rewrite(rewriter).map(|t| t.data).unwrap(); let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, @@ -248,7 +251,7 @@ mod tests { fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { for expr in cases { - let output = expr.clone().rewrite(rewriter).unwrap(); + let output = expr.clone().rewrite(rewriter).map(|t| t.data).unwrap(); assert_eq!( &output, expr, "{} was simplified to {}, but expected it to be unchanged", @@ -478,7 +481,7 @@ mod tests { let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - let output = col("x").rewrite(&mut rewriter).unwrap(); + let output = col("x").rewrite(&mut rewriter).map(|t| t.data).unwrap(); assert_eq!(output, Expr::Literal(scalar.clone())); } } @@ -522,7 +525,7 @@ mod tests { .collect(), *negated, ); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).map(|t| t.data).unwrap(); let expected_list = expected_list .iter() .map(|v| lit(ScalarValue::Int32(Some(*v)))) diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 867e96d213d99..c9d9c00c335ec 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -19,7 +19,7 @@ use std::collections::HashSet; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_expr::expr::InList; use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; @@ -51,30 +51,30 @@ impl InListSimplifier { impl TreeNodeRewriter for InListSimplifier { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { if let (Expr::InList(l1), Operator::And, Expr::InList(l2)) = (left.as_ref(), op, right.as_ref()) { if l1.expr == l2.expr && !l1.negated && !l2.negated { - return inlist_intersection(l1, l2, false); + return Ok(Transformed::yes(inlist_intersection(l1, l2, false)?)); } else if l1.expr == l2.expr && l1.negated && l2.negated { - return inlist_union(l1, l2, true); + return Ok(Transformed::yes(inlist_union(l1, l2, true)?)); } else if l1.expr == l2.expr && !l1.negated && l2.negated { - return inlist_except(l1, l2); + return Ok(Transformed::yes(inlist_except(l1, l2)?)); } else if l1.expr == l2.expr && l1.negated && !l2.negated { - return inlist_except(l2, l1); + return Ok(Transformed::yes(inlist_except(l2, l1)?)); } } else if let (Expr::InList(l1), Operator::Or, Expr::InList(l2)) = (left.as_ref(), op, right.as_ref()) { if l1.expr == l2.expr && l1.negated && l2.negated { - return inlist_intersection(l1, l2, true); + return Ok(Transformed::yes(inlist_intersection(l1, l2, true)?)); } } } - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs index ea02c1f3af8a2..ff50b337e1580 100644 --- a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs @@ -20,7 +20,7 @@ use std::borrow::Cow; use std::collections::HashSet; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_expr::expr::InList; use datafusion_expr::{BinaryExpr, Expr, Operator}; @@ -39,7 +39,7 @@ impl OrInListSimplifier { impl TreeNodeRewriter for OrInListSimplifier { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { if *op == Operator::Or { let left = as_inlist(left); @@ -66,13 +66,13 @@ impl TreeNodeRewriter for OrInListSimplifier { list, negated: false, }; - return Ok(Expr::InList(merged_inlist)); + return Ok(Transformed::yes(Expr::InList(merged_inlist))); } } } } - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 0232a28c722a6..52c9eefb9bab8 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -24,7 +24,7 @@ use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; @@ -129,7 +129,7 @@ struct UnwrapCastExprRewriter { impl TreeNodeRewriter for UnwrapCastExprRewriter { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match &expr { // For case: // try_cast/cast(expr as data_type) op literal @@ -157,11 +157,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(left_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the right expr - return Ok(binary_expr( + return Ok(Transformed::yes(binary_expr( lit(value), *op, expr.as_ref().clone(), - )); + ))); } } ( @@ -176,11 +176,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(right_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the left expr - return Ok(binary_expr( + return Ok(Transformed::yes(binary_expr( expr.as_ref().clone(), *op, lit(value), - )); + ))); } } (_, _) => { @@ -189,7 +189,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }; } // return the new binary op - Ok(binary_expr(left, *op, right)) + Ok(Transformed::yes(binary_expr(left, *op, right))) } // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) @@ -213,12 +213,12 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { let internal_left_type = internal_left.get_type(&self.schema); if internal_left_type.is_err() { // error data type - return Ok(expr); + return Ok(Transformed::no(expr)); } let internal_left_type = internal_left_type?; if !is_support_data_type(&internal_left_type) { // not supported data type - return Ok(expr); + return Ok(Transformed::no(expr)); } let right_exprs = list .iter() @@ -253,17 +253,19 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }) .collect::>>(); match right_exprs { - Ok(right_exprs) => { - Ok(in_list(internal_left, right_exprs, *negated)) - } - Err(_) => Ok(expr), + Ok(right_exprs) => Ok(Transformed::yes(in_list( + internal_left, + right_exprs, + *negated, + ))), + Err(_) => Ok(Transformed::no(expr)), } } else { - Ok(expr) + Ok(Transformed::no(expr)) } } // TODO: handle other expr type and dfs visit them - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } } @@ -730,7 +732,7 @@ mod tests { let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), }; - expr.rewrite(&mut expr_rewriter).unwrap() + expr.rewrite(&mut expr_rewriter).map(|t| t.data).unwrap() } fn expr_test_schema() -> DFSchemaRef { diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 9ee9be94a5f25..87e71e3458cd5 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -263,11 +263,12 @@ impl EquivalenceGroup { .transform_up(&|expr| { for cls in self.iter() { if cls.contains(&expr) { - return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); + return Ok(Transformed::yes(cls.canonical_expr().unwrap())); } } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .map(|t| t.data) .unwrap_or(expr) } @@ -458,11 +459,12 @@ impl EquivalenceGroup { column.index() + left_size, )) as _; - return Ok(Transformed::Yes(new_column)); + return Ok(Transformed::yes(new_column)); } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .map(|t| t.data) .unwrap(); result.add_equal_conditions(&new_lhs, &new_rhs); } diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 387dce2cdc8b2..43cb90e72f5fb 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -48,12 +48,13 @@ pub fn add_offset_to_expr( offset: usize, ) -> Arc { expr.transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( + Some(col) => Ok(Transformed::yes(Arc::new(Column::new( col.name(), offset + col.index(), )))), - None => Ok(Transformed::No(e)), + None => Ok(Transformed::no(e)), }) + .map(|t| t.data) .unwrap() // Note that we can safely unwrap here since our transform always returns // an `Ok` value. diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 0f92b2c2f431d..a96fbb6e484b2 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -68,10 +68,11 @@ impl ProjectionMapping { let matching_input_field = input_schema.field(idx); let matching_input_column = Column::new(matching_input_field.name(), idx); - Ok(Transformed::Yes(Arc::new(matching_input_column))) + Ok(Transformed::yes(Arc::new(matching_input_column))) } - None => Ok(Transformed::No(e)), + None => Ok(Transformed::no(e)), }) + .map(|t| t.data) .map(|source_expr| (source_expr, target_expr)) }) .collect::>>() diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 2471d9249e163..cf05a97e21dd9 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -778,6 +778,7 @@ impl EquivalenceProperties { pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { ExprOrdering::new_default(expr.clone()) .transform_up(&|expr| Ok(update_ordering(expr, self))) + .map(|t| t.data) // Guaranteed to always return `Ok`. .unwrap() } @@ -816,9 +817,9 @@ fn update_ordering( // We have a Literal, which is the other possible leaf node type: node.data = node.expr.get_ordering(&[]); } else { - return Transformed::No(node); + return Transformed::no(node); } - Transformed::Yes(node) + Transformed::yes(node) } /// This function determines whether the provided expression is constant diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index b04c66b237289..59c6886d0c0ef 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -972,11 +972,12 @@ mod tests { _ => None, }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(e) + Transformed::no(e) }) }) + .map(|t| t.data) .unwrap(); let expr3 = expr @@ -993,11 +994,12 @@ mod tests { _ => None, }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(e) + Transformed::no(e) }) }) + .map(|t| t.data) .unwrap(); assert!(expr.ne(&expr2)); diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index a8d1e3638a177..253ed8da695b7 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -32,6 +32,7 @@ use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; +use datafusion_common::tree_node::Transformed; use itertools::izip; /// Expression that can be evaluated against a RecordBatch @@ -185,7 +186,7 @@ pub type PhysicalExprRef = Arc; pub fn with_new_children_if_necessary( expr: Arc, children: Vec>, -) -> Result> { +) -> Result>> { let old_children = expr.children(); if children.len() != old_children.len() { internal_err!("PhysicalExpr: Wrong number of children") @@ -195,9 +196,9 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) { - expr.with_new_children(children) + Ok(Transformed::yes(expr.with_new_children(children)?)) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/physical-expr/src/tree_node.rs b/datafusion/physical-expr/src/tree_node.rs index 42dc6673af6ab..8f21ffb824570 100644 --- a/datafusion/physical-expr/src/tree_node.rs +++ b/datafusion/physical-expr/src/tree_node.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use crate::physical_expr::{with_new_children_if_necessary, PhysicalExpr}; -use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; +use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode, Transformed}; use datafusion_common::Result; impl DynTreeNode for dyn PhysicalExpr { @@ -34,7 +34,7 @@ impl DynTreeNode for dyn PhysicalExpr { &self, arc_self: Arc, new_children: Vec>, - ) -> Result> { + ) -> Result>> { with_new_children_if_necessary(arc_self, new_children) } } @@ -63,7 +63,7 @@ impl ExprContext { pub fn update_expr_from_children(mut self) -> Result { let children_expr = self.children.iter().map(|c| c.expr.clone()).collect(); - self.expr = with_new_children_if_necessary(self.expr, children_expr)?; + self.expr = with_new_children_if_necessary(self.expr, children_expr)?.data; Ok(self) } } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 8d4f4cad4afaa..694a18e147d35 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -172,7 +172,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> // Set the data field of the input expression node to the corresponding node index. node.data = Some(node_idx); // Return the mutated expression node. - Ok(Transformed::Yes(node)) + Ok(Transformed::yes(node)) } } @@ -193,7 +193,9 @@ where constructor, }; // Use the builder to transform the expression tree node into a DAG. - let root = init.transform_up_mut(&mut |node| builder.mutate(node))?; + let root = init + .transform_up_mut(&mut |node| builder.mutate(node))? + .data; // Return a tuple containing the root node index and the DAG. Ok((root.data.unwrap(), builder.graph)) } @@ -230,13 +232,14 @@ pub fn reassign_predicate_columns( Err(_) if ignore_not_found => usize::MAX, Err(e) => return Err(e.into()), }; - return Ok(Transformed::Yes(Arc::new(Column::new( + return Ok(Transformed::yes(Arc::new(Column::new( column.name(), index, )))); } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .map(|t| t.data) } /// Reverses the ORDER BY expression, which is useful during equivalent window diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 41c8dbed14536..bdb8234be7918 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -165,7 +165,7 @@ mod tests { let schema = test::aggr_test_schema(); let empty = Arc::new(EmptyExec::new(schema.clone())); - let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.into(); + let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.data; assert_eq!(empty.schema(), empty2.schema()); let too_many_kids = vec![empty2]; diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 9a4c98927683d..3484ee45ba6ae 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -284,14 +284,16 @@ pub fn convert_sort_expr_with_filter_schema( if all_columns_are_included { // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. - let converted_filter_expr = expr.transform_up(&|p| { - convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { - match transformed { - Some(transformed) => Transformed::Yes(transformed), - None => Transformed::No(p), - } - }) - })?; + let converted_filter_expr = expr + .transform_up(&|p| { + convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { + match transformed { + Some(transformed) => Transformed::yes(transformed), + None => Transformed::no(p), + } + }) + })? + .data; // Search the converted `PhysicalExpr` in filter expression; if an exact // match is found, use this sorted expression in graph traversals. if check_filter_expr_contains_sort_information( diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 70f315917351e..073d5a035e0c9 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -478,13 +478,17 @@ fn replace_on_columns_of_right_ordering( ) -> Result<()> { for (left_col, right_col) in on_columns { for item in right_ordering.iter_mut() { - let new_expr = item.expr.clone().transform_up(&|e| { - if e.eq(right_col) { - Ok(Transformed::Yes(left_col.clone())) - } else { - Ok(Transformed::No(e)) - } - })?; + let new_expr = item + .expr + .clone() + .transform_up(&|e| { + if e.eq(right_col) { + Ok(Transformed::yes(left_col.clone())) + } else { + Ok(Transformed::no(e)) + } + })? + .data; item.expr = new_expr; } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 0a9eab5c8633a..0a147a29e1a83 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -489,9 +489,9 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) { - Ok(Transformed::Yes(plan.with_new_children(children)?)) + Ok(Transformed::yes(plan.with_new_children(children)?)) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 3ab3de62f37a7..04482d7c1cc17 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -172,7 +172,7 @@ mod tests { let placeholder = Arc::new(PlaceholderRowExec::new(schema)); let placeholder_2 = - with_new_children_if_necessary(placeholder.clone(), vec![])?.into(); + with_new_children_if_necessary(placeholder.clone(), vec![])?.data; assert_eq!(placeholder.schema(), placeholder_2.schema()); let too_many_kids = vec![placeholder_2]; diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 614ab990ac49a..1683159f3cee1 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -317,16 +317,17 @@ fn assign_work_table( ) } else { work_table_refs += 1; - Ok(Transformed::Yes(Arc::new( + Ok(Transformed::yes(Arc::new( exec.with_work_table(work_table.clone()), ))) } } else if plan.as_any().is::() { not_impl_err!("Recursive queries cannot be nested") } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } }) + .map(|t| t.data) } impl Stream for RecursiveQueryStream { diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index b8a5f95c53252..c4223cb734304 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -34,8 +34,8 @@ impl DynTreeNode for dyn ExecutionPlan { &self, arc_self: Arc, new_children: Vec>, - ) -> Result> { - with_new_children_if_necessary(arc_self, new_children).map(Transformed::into) + ) -> Result>> { + with_new_children_if_necessary(arc_self, new_children) } } @@ -63,7 +63,7 @@ impl PlanContext { pub fn update_plan_from_children(mut self) -> Result { let children_plans = self.children.iter().map(|c| c.plan.clone()).collect(); - self.plan = with_new_children_if_necessary(self.plan, children_plans)?.into(); + self.plan = with_new_children_if_necessary(self.plan, children_plans)?.data; Ok(self) } } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 0dc1258ebabea..3f6f3aa483ab4 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -33,18 +33,20 @@ use std::collections::HashMap; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { - expr.clone().transform_up(&|nested_expr| { - match nested_expr { - Expr::Column(col) => { - let field = plan.schema().field_from_column(&col)?; - Ok(Transformed::Yes(Expr::Column(field.qualified_column()))) - } - _ => { - // keep recursing - Ok(Transformed::No(nested_expr)) + expr.clone() + .transform_up(&|nested_expr| { + match nested_expr { + Expr::Column(col) => { + let field = plan.schema().field_from_column(&col)?; + Ok(Transformed::yes(Expr::Column(field.qualified_column()))) + } + _ => { + // keep recursing + Ok(Transformed::no(nested_expr)) + } } - } - }) + }) + .map(|t| t.data) } /// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s. @@ -66,13 +68,15 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { - expr.clone().transform_down(&|nested_expr| { - if base_exprs.contains(&nested_expr) { - Ok(Transformed::Yes(expr_as_column_expr(&nested_expr, plan)?)) - } else { - Ok(Transformed::No(nested_expr)) - } - }) + expr.clone() + .transform_down(&|nested_expr| { + if base_exprs.contains(&nested_expr) { + Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?)) + } else { + Ok(Transformed::no(nested_expr)) + } + }) + .map(|t| t.data) } /// Determines if the set of `Expr`'s are a valid projection on the input @@ -170,16 +174,18 @@ pub(crate) fn resolve_aliases_to_exprs( expr: &Expr, aliases: &HashMap, ) -> Result { - expr.clone().transform_up(&|nested_expr| match nested_expr { - Expr::Column(c) if c.relation.is_none() => { - if let Some(aliased_expr) = aliases.get(&c.name) { - Ok(Transformed::Yes(aliased_expr.clone())) - } else { - Ok(Transformed::No(Expr::Column(c))) + expr.clone() + .transform_up(&|nested_expr| match nested_expr { + Expr::Column(c) if c.relation.is_none() => { + if let Some(aliased_expr) = aliases.get(&c.name) { + Ok(Transformed::yes(aliased_expr.clone())) + } else { + Ok(Transformed::no(Expr::Column(c))) + } } - } - _ => Ok(Transformed::No(nested_expr)), - }) + _ => Ok(Transformed::no(nested_expr)), + }) + .map(|t| t.data) } /// given a slice of window expressions sharing the same sort key, find their common partition diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index b128d661f31a9..ab2b0a2ce960d 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -92,7 +92,7 @@ In our example, we'll use rewriting to update our `add_one` UDF, to be rewritten ### Rewriting with `transform` -To implement the inlining, we'll need to write a function that takes an `Expr` and returns a `Result`. If the expression is _not_ to be rewritten `Transformed::No` is used to wrap the original `Expr`. If the expression _is_ to be rewritten, `Transformed::Yes` is used to wrap the new `Expr`. +To implement the inlining, we'll need to write a function that takes an `Expr` and returns a `Result`. If the expression is _not_ to be rewritten `Transformed::no` is used to wrap the original `Expr`. If the expression _is_ to be rewritten, `Transformed::yes` is used to wrap the new `Expr`. ```rust fn rewrite_add_one(expr: Expr) -> Result { @@ -102,9 +102,9 @@ fn rewrite_add_one(expr: Expr) -> Result { let input_arg = scalar_fun.args[0].clone(); let new_expression = input_arg + lit(1i64); - Transformed::Yes(new_expression) + Transformed::yes(new_expression) } - _ => Transformed::No(expr), + _ => Transformed::no(expr), }) }) }