From 1bedfec3b1b5692fde11297403251e8608131927 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 24 Mar 2024 08:58:55 -0400 Subject: [PATCH] Add TreeNodeMutator API Use TreeNode API in Optimizer --- datafusion-examples/examples/rewrite_expr.rs | 2 +- datafusion/common/src/tree_node.rs | 162 +++++++- datafusion/core/src/execution/context/mod.rs | 4 +- .../core/tests/optimizer_integration.rs | 2 +- datafusion/expr/src/logical_plan/ddl.rs | 18 + datafusion/expr/src/logical_plan/mod.rs | 1 + datafusion/expr/src/logical_plan/mutate.rs | 346 ++++++++++++++++++ datafusion/expr/src/tree_node/plan.rs | 7 + .../optimizer/src/analyzer/type_coercion.rs | 3 +- .../src/decorrelate_predicate_subquery.rs | 96 ++--- .../src/eliminate_duplicated_expr.rs | 6 +- datafusion/optimizer/src/eliminate_filter.rs | 14 +- datafusion/optimizer/src/eliminate_join.rs | 6 +- datafusion/optimizer/src/eliminate_limit.rs | 32 +- .../optimizer/src/eliminate_nested_union.rs | 22 +- .../optimizer/src/eliminate_one_union.rs | 6 +- .../optimizer/src/eliminate_outer_join.rs | 12 +- .../src/extract_equijoin_predicate.rs | 18 +- .../optimizer/src/filter_null_join_keys.rs | 14 +- .../optimizer/src/optimize_projections.rs | 50 +-- datafusion/optimizer/src/optimizer.rs | 282 +++++++------- .../optimizer/src/propagate_empty_relation.rs | 22 +- datafusion/optimizer/src/push_down_filter.rs | 147 ++++---- datafusion/optimizer/src/push_down_limit.rs | 70 ++-- .../optimizer/src/push_down_projection.rs | 76 ++-- .../src/replace_distinct_aggregate.rs | 4 +- .../optimizer/src/scalar_subquery_to_join.rs | 28 +- .../src/single_distinct_to_groupby.rs | 40 +- datafusion/optimizer/src/test/mod.rs | 69 ++-- .../optimizer/tests/optimizer_integration.rs | 2 +- datafusion/sqllogictest/test_files/join.slt | 2 +- .../sqllogictest/test_files/predicates.slt | 2 +- datafusion/sqllogictest/test_files/scalar.slt | 2 - .../sqllogictest/test_files/subquery.slt | 16 +- .../sqllogictest/test_files/timestamps.slt | 2 +- 35 files changed, 1049 insertions(+), 536 deletions(-) create mode 100644 datafusion/expr/src/logical_plan/mutate.rs diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 541448ebf149..dcebbb55fb66 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -59,7 +59,7 @@ pub fn main() -> Result<()> { // then run the optimizer with our custom rule let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]); - let optimized_plan = optimizer.optimize(&analyzed_plan, &config, observe)?; + let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?; println!( "Optimized Logical Plan:\n\n{}\n", optimized_plan.display_indent() diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 2d653a27c47b..3555b0551e81 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -20,7 +20,7 @@ use std::sync::Arc; -use crate::Result; +use crate::{error::_not_impl_err, Result}; /// This macro is used to control continuation behaviors during tree traversals /// based on the specified direction. Depending on `$DIRECTION` and the value of @@ -100,6 +100,10 @@ pub trait TreeNode: Sized { /// Visit the tree node using the given [`TreeNodeVisitor`], performing a /// depth-first walk of the node and its children. /// + /// See also: + /// * [`Self::mutate`] to rewrite `TreeNode`s in place + /// * [`Self::rewrite`] to rewrite owned `TreeNode`s + /// /// Consider the following tree structure: /// ```text /// ParentNode @@ -144,6 +148,10 @@ pub trait TreeNode: Sized { /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for /// recursively transforming [`TreeNode`]s. /// + /// See also: + /// * [`Self::mutate`] to rewrite `TreeNode`s in place + /// * [`Self::visit`] for inspecting (without modification) `TreeNode`s + /// /// Consider the following tree structure: /// ```text /// ParentNode @@ -174,6 +182,70 @@ pub trait TreeNode: Sized { }) } + /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for + /// recursively mutating / rewriting [`TreeNode`]s in place. + /// + /// See also: + /// * [`Self::rewrite`] to rewrite owned `TreeNode`s + /// * [`Self::visit`] for inspecting (without modification) `TreeNode`s + /// + /// Consider the following tree structure: + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// Here, the nodes would be mutataed in the following order: + /// ```text + /// TreeNodeMutator::f_down(ParentNode) + /// TreeNodeMutator::f_down(ChildNode1) + /// TreeNodeMutator::f_up(ChildNode1) + /// TreeNodeMutator::f_down(ChildNode2) + /// TreeNodeMutator::f_up(ChildNode2) + /// TreeNodeMutator::f_up(ParentNode) + /// ``` + /// + /// See [`TreeNodeRecursion`] for more details on controlling the traversal. + /// + /// # Error Handling + /// + /// If [`TreeNodeVisitor::f_down()`] or [`TreeNodeVisitor::f_up()`] returns [`Err`], + /// the recursion stops immediately and the tree may be left partially changed + /// + /// # Changing Children During Traversal + /// + /// If `f_down` changes the nodes children, the new children are visited + /// (not the old children prior to rewrite) + fn mutate>( + &mut self, + mutator: &mut M, + ) -> Result> { + // Note this is an inlined version of handle_transform_recursion! + let pre_visited = mutator.f_down(self)?; + + // Traverse children and then call f_up on self if necessary + match pre_visited.tnr { + TreeNodeRecursion::Continue => { + // rewrite children recursively with mutator + self.mutate_children(|c| c.mutate(mutator))? + .try_transform_node_with( + |_: ()| mutator.f_up(self), + TreeNodeRecursion::Jump, + ) + } + TreeNodeRecursion::Jump => { + // skip other children and start back up + mutator.f_up(self) + } + TreeNodeRecursion::Stop => return Ok(pre_visited), + } + .map(|mut post_visited| { + post_visited.transformed |= pre_visited.transformed; + post_visited + }) + } + /// Applies `f` to the node and its children. `f` is applied in a pre-order /// way, and it is controlled by [`TreeNodeRecursion`], which means result /// of the `f` on a node can cause an early return. @@ -353,13 +425,38 @@ pub trait TreeNode: Sized { } /// Apply the closure `F` to the node's children. + /// + /// See `mutate_children` for rewriting in place fn apply_children Result>( &self, f: &mut F, ) -> Result; - /// Apply transform `F` to the node's children. Note that the transform `F` - /// might have a direction (pre-order or post-order). + /// Rewrite the node's children in place using `F`. + /// + /// On error, `self` is left partially rewritten. + /// + /// # Notes + /// + /// Using [`Self::map_children`], the owned API, has clearer semantics on + /// error (the node is consumed). However, it requires copying the interior + /// fields of the tree node during rewrite. + /// + /// This API writes the nodes in place, which can be faster as it avoids + /// copying, but leaves the tree node in an partially rewritten state when + /// an error occurs. + fn mutate_children Result>>( + &mut self, + _f: F, + ) -> Result> { + _not_impl_err!( + "mutate_children not implemented for {} yet", + std::any::type_name::() + ) + } + + /// Apply transform `F` to potentially rewrite the node's children. Note + /// that the transform `F` might have a direction (pre-order or post-order). fn map_children Result>>( self, f: F, @@ -411,6 +508,41 @@ pub trait TreeNodeRewriter: Sized { } } +/// Trait for mutating (rewriting in place) [`TreeNode`]s +/// +/// # See Also: +/// * [`TreeNodeRewriter`] for rewriting owned `TreeNode`e +/// * [`TreeNodeVisitor`] for visiting, but not changing, `TreeNode`s +pub trait TreeNodeMutator: Sized { + /// The node type to mutating. + type Node: TreeNode; + + /// Invoked while traversing down the tree before any children are mutated. + /// Default implementation does nothing to the node and continues recursion. + /// + /// # Notes + /// + /// As the node maybe mutated in place, the returned [`Transformed`] object + /// returns `()` (no data). + /// + /// If the node's children are changed by `f_down`, the *new* children are + /// visited, not the original children. + fn f_down(&mut self, _node: &mut Self::Node) -> Result> { + Ok(Transformed::no(())) + } + + /// Invoked while traversing up the tree after all children have been mutated. + /// Default implementation does nothing to the node and continues recursion. + /// + /// # Notes + /// + /// As the node maybe mutated in place, the returned [`Transformed`] object + /// returns `()` (no data). + fn f_up(&mut self, _node: &mut Self::Node) -> Result> { + Ok(Transformed::no(())) + } +} + /// Controls how [`TreeNode`] recursions should proceed. #[derive(Debug, PartialEq, Clone, Copy)] pub enum TreeNodeRecursion { @@ -489,6 +621,11 @@ impl Transformed { f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr)) } + /// Invokes f(), depending on the value of self.tnr. + /// + /// This is used to conditionally apply a function during a f_up tree + /// traversal, if the result of children traversal was `[`TreeNodeRecursion::Continue`]. + /// /// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`] /// is straightforward, but [`TreeNodeRecursion::Jump`] can behave differently /// when we are traversing down or up on a tree. If [`TreeNodeRecursion`] of @@ -532,6 +669,25 @@ impl Transformed { } } +impl Transformed<()> { + /// Invoke the given function `f` and combine the transformed state with + /// the current state: + /// + /// * if `f` returns an Err, returns that err + /// + /// * If `f` returns Ok, sets `self.transformed` to `true` if either self or + /// the result of `f` were transformed. + pub fn and_then(self, f: F) -> Result> + where + F: FnOnce() -> Result>, + { + f().map(|mut t| { + t.transformed |= self.transformed; + t + }) + } +} + /// Transformation helper to process tree nodes that are siblings. pub trait TransformedIterator: Iterator { fn map_until_stop_and_collect< diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 31f390607f04..c78fba3e5635 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1877,7 +1877,7 @@ impl SessionState { // optimize the child plan, capturing the output of each optimizer let optimized_plan = self.optimizer.optimize( - &analyzed_plan, + analyzed_plan, self, |optimized_plan, optimizer| { let optimizer_name = optimizer.name().to_string(); @@ -1907,7 +1907,7 @@ impl SessionState { let analyzed_plan = self.analyzer .execute_and_check(plan, self.options(), |_, _| {})?; - self.optimizer.optimize(&analyzed_plan, self, |_, _| {}) + self.optimizer.optimize(analyzed_plan, self, |_, _| {}) } } diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer_integration.rs index 60010bdddfb8..6e938361ddb4 100644 --- a/datafusion/core/tests/optimizer_integration.rs +++ b/datafusion/core/tests/optimizer_integration.rs @@ -110,7 +110,7 @@ fn test_sql(sql: &str) -> Result { let optimizer = Optimizer::new(); // analyze and optimize the logical plan let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - optimizer.optimize(&plan, &config, |_, _| {}) + optimizer.optimize(plan, &config, |_, _| {}) } #[derive(Default)] diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 968c40c8bf62..73b98565774d 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -112,6 +112,24 @@ impl DdlStatement { } } + /// Return a mutable reference to the input `LogicalPlan`, if any + pub fn input_mut(&mut self) -> Option<&mut Arc> { + match self { + DdlStatement::CreateMemoryTable(CreateMemoryTable { input, .. }) => { + Some(input) + } + DdlStatement::CreateExternalTable(_) => None, + DdlStatement::CreateView(CreateView { input, .. }) => Some(input), + DdlStatement::CreateCatalogSchema(_) => None, + DdlStatement::CreateCatalog(_) => None, + DdlStatement::DropTable(_) => None, + DdlStatement::DropView(_) => None, + DdlStatement::DropCatalogSchema(_) => None, + DdlStatement::CreateFunction(_) => None, + DdlStatement::DropFunction(_) => None, + } + } + /// Return a `format`able structure with the a human readable /// description of this LogicalPlan node per node, not including /// children. diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 84781cb2e9ec..ef7a4a20f218 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -20,6 +20,7 @@ mod ddl; pub mod display; pub mod dml; mod extension; +mod mutate; mod plan; mod statement; diff --git a/datafusion/expr/src/logical_plan/mutate.rs b/datafusion/expr/src/logical_plan/mutate.rs new file mode 100644 index 000000000000..da2a4d6d6b65 --- /dev/null +++ b/datafusion/expr/src/logical_plan/mutate.rs @@ -0,0 +1,346 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::plan::*; +use crate::expr::{Exists, InSubquery}; +use crate::{Expr, UserDefinedLogicalNode}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{internal_err, Result}; +use datafusion_common::{Column, DFSchema, DFSchemaRef}; +use std::sync::{Arc, OnceLock}; + +impl LogicalPlan { + /// applies `f` to each expression of this node, potentially rewriting it in + /// place + /// + /// If `f` returns an error, the error is returned and the expressions are + /// left in a partially modified state + pub fn rewrite_exprs(&mut self, mut f: F) -> Result> + where + F: FnMut(&mut Expr) -> Result>, + { + match self { + LogicalPlan::Projection(Projection { expr, .. }) => { + rewrite_expr_iter_mut(expr.iter_mut(), f) + } + LogicalPlan::Values(Values { values, .. }) => { + rewrite_expr_iter_mut(values.iter_mut().flatten(), f) + } + LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), + LogicalPlan::Repartition(Repartition { + partitioning_scheme, + .. + }) => match partitioning_scheme { + Partitioning::Hash(expr, _) => rewrite_expr_iter_mut(expr.iter_mut(), f), + Partitioning::DistributeBy(expr) => { + rewrite_expr_iter_mut(expr.iter_mut(), f) + } + Partitioning::RoundRobinBatch(_) => Ok(Transformed::no(())), + }, + LogicalPlan::Window(Window { window_expr, .. }) => { + rewrite_expr_iter_mut(window_expr.iter_mut(), f) + } + LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + .. + }) => { + let exprs = group_expr.iter_mut().chain(aggr_expr.iter_mut()); + rewrite_expr_iter_mut(exprs, f) + } + // There are two part of expression for join, equijoin(on) and non-equijoin(filter). + // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. + // 2. the second part is non-equijoin(filter). + LogicalPlan::Join(Join { on, filter, .. }) => { + let exprs = on + .iter_mut() + .flat_map(|(e1, e2)| std::iter::once(e1).chain(std::iter::once(e2))); + + let result = rewrite_expr_iter_mut(exprs, &mut f)?; + + if let Some(filter) = filter.as_mut() { + result.and_then(|| f(filter)) + } else { + Ok(result) + } + } + LogicalPlan::Sort(Sort { expr, .. }) => { + rewrite_expr_iter_mut(expr.iter_mut(), f) + } + LogicalPlan::Extension(extension) => { + rewrite_extension_exprs(&mut extension.node, f) + } + LogicalPlan::TableScan(TableScan { filters, .. }) => { + rewrite_expr_iter_mut(filters.iter_mut(), f) + } + LogicalPlan::Unnest(Unnest { column, .. }) => rewrite_column(column, f), + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + })) => { + let exprs = on_expr + .iter_mut() + .chain(select_expr.iter_mut()) + .chain(sort_expr.iter_mut().flat_map(|x| x.iter_mut())); + + rewrite_expr_iter_mut(exprs, f) + } + // plans without expressions + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Statement(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Union(_) + | LogicalPlan::Distinct(Distinct::All(_)) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Prepare(_) => Ok(Transformed::no(())), + } + } + + /// applies `f` to each input of this node, rewriting them in place. + /// + /// # Notes + /// Inputs include both direct children as well as any embedded subquery + /// `LogicalPlan`s, for example such as are in [`Expr::Exists`]. + /// + /// If `f` returns an `Err`, that Err is returned, and the inputs are left + /// in a partially modified state + pub fn rewrite_inputs(&mut self, mut f: F) -> Result> + where + F: FnMut(&mut LogicalPlan) -> Result>, + { + let children_result = match self { + LogicalPlan::Projection(Projection { input, .. }) => { + rewrite_arc(input, &mut f) + } + LogicalPlan::Filter(Filter { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Repartition(Repartition { input, .. }) => { + rewrite_arc(input, &mut f) + } + LogicalPlan::Window(Window { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Aggregate(Aggregate { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Sort(Sort { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Join(Join { left, right, .. }) => { + rewrite_arc(left, &mut f)?.and_then(|| rewrite_arc(right, &mut f)) + } + LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { + rewrite_arc(left, &mut f)?.and_then(|| rewrite_arc(right, &mut f)) + } + LogicalPlan::Limit(Limit { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Subquery(Subquery { subquery, .. }) => { + rewrite_arc(subquery, &mut f) + } + LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { + rewrite_arc(input, &mut f) + } + LogicalPlan::Extension(extension) => { + rewrite_extension_inputs(&mut extension.node, &mut f) + } + LogicalPlan::Union(Union { inputs, .. }) => inputs + .iter_mut() + .try_fold(Transformed::no(()), |acc, input| { + acc.and_then(|| rewrite_arc(input, &mut f)) + }), + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => rewrite_arc(input, &mut f), + LogicalPlan::Explain(explain) => rewrite_arc(&mut explain.plan, &mut f), + LogicalPlan::Analyze(analyze) => rewrite_arc(&mut analyze.input, &mut f), + LogicalPlan::Dml(write) => rewrite_arc(&mut write.input, &mut f), + LogicalPlan::Copy(copy) => rewrite_arc(&mut copy.input, &mut f), + LogicalPlan::Ddl(ddl) => { + if let Some(input) = ddl.input_mut() { + rewrite_arc(input, &mut f) + } else { + Ok(Transformed::no(())) + } + } + LogicalPlan::Unnest(Unnest { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Prepare(Prepare { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::RecursiveQuery(RecursiveQuery { + static_term, + recursive_term, + .. + }) => rewrite_arc(static_term, &mut f)? + .and_then(|| rewrite_arc(recursive_term, &mut f)), + // plans without inputs + LogicalPlan::TableScan { .. } + | LogicalPlan::Statement { .. } + | LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Values { .. } + | LogicalPlan::DescribeTable(_) => Ok(Transformed::no(())), + }?; + + // after visiting the actual children we we need to visit any subqueries + // that are inside the expressions + children_result.and_then(|| self.rewrite_subqueries(&mut f)) + } + + /// applies `f` to LogicalPlans in any subquery expressions + /// + /// If Err is returned, the plan may be left in a partially modified state + fn rewrite_subqueries(&mut self, mut f: F) -> Result> + where + F: FnMut(&mut LogicalPlan) -> Result>, + { + self.rewrite_exprs(|expr| match expr { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + rewrite_arc(&mut subquery.subquery, &mut f) + } + _ => Ok(Transformed::no(())), + }) + } +} + +/// writes each `&mut Expr` in the iterator using `f` +fn rewrite_expr_iter_mut<'a, F>( + i: impl IntoIterator, + mut f: F, +) -> Result> +where + F: FnMut(&mut Expr) -> Result>, +{ + i.into_iter() + .try_fold(Transformed::no(()), |acc, expr| acc.and_then(|| f(expr))) +} + +/// A temporary node that is left in place while rewriting the children of a +/// [`LogicalPlan`]. This is necessary to ensure that the `LogicalPlan` is +/// always in a valid state (from the Rust perspective) +static PLACEHOLDER: OnceLock> = OnceLock::new(); + +/// Applies `f` to rewrite the existing node, while avoiding `clone`'ing as much +/// as possiblw. +/// +/// TODO eventually remove `Arc` from `LogicalPlan` and have it own +/// its inputs, so this code would not be needed. However, for now we try and +/// unwrap the `Arc` which avoids `clone`ing in most cases. +/// +/// On error, node be left with a placeholder logical plan +fn rewrite_arc(node: &mut Arc, mut f: F) -> Result> +where + F: FnMut(&mut LogicalPlan) -> Result>, +{ + // We need to leave a valid node in the Arc, while we rewrite the existing + // one, so use a single global static placeholder node + let mut new_node = PLACEHOLDER + .get_or_init(|| { + Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(DFSchema::empty()), + })) + }) + .clone(); + + // take the old value out of the Arc + std::mem::swap(node, &mut new_node); + + // try to update existing node, if it isn't shared with others + let mut new_node = Arc::try_unwrap(new_node) + // if None is returned, there is another reference to this + // LogicalPlan, so we must clone instead + .unwrap_or_else(|node| node.as_ref().clone()); + + // apply the actual transform + let result = f(&mut new_node)?; + + // put the new value back into the Arc + let mut new_node = Arc::new(new_node); + std::mem::swap(node, &mut new_node); + + Ok(result) +} + +/// Rewrites a [`Column`] in place using the provided closure +fn rewrite_column(column: &mut Column, mut f: F) -> Result> +where + F: FnMut(&mut Expr) -> Result>, +{ + // Since `Column`'s isn't an `Expr`, but the closure in terms of Exprs, + // we make a temporary Expr to rewrite and then put it back + + let mut swap_column = Column::new_unqualified("TEMP_unnest_column"); + std::mem::swap(column, &mut swap_column); + + let mut expr = Expr::Column(swap_column); + let result = f(&mut expr)?; + // Get the rewritten column + let Expr::Column(mut swap_column) = expr else { + return internal_err!( + "Rewrite of Column Expr must return Column, returned {expr:?}" + ); + }; + // put the rewritten column back + std::mem::swap(column, &mut swap_column); + Ok(result) +} + +/// Rewrites all expressions for an Extension node "in place" +/// (it currently has to copy values because there are no APIs for in place modification) +/// TODO file ticket for inplace modificiation of Extension nodes +/// +/// Should be removed when we have an API for in place modifications of the +/// extension to avoid these copies +fn rewrite_extension_exprs( + node: &mut Arc, + f: F, +) -> Result> +where + F: FnMut(&mut Expr) -> Result>, +{ + let mut exprs = node.expressions(); + let result = rewrite_expr_iter_mut(exprs.iter_mut(), f)?; + let inputs: Vec<_> = node.inputs().into_iter().cloned().collect(); + let mut new_node = node.from_template(&exprs, &inputs); + std::mem::swap(node, &mut new_node); + Ok(result) +} + +/// Rewrties all inputs for an Extension node "in place" +/// (it currently has to copy values because there are no APIs for in place modification) +/// +/// Should be removed when we have an API for in place modifications of the +/// extension to avoid these copies +fn rewrite_extension_inputs( + node: &mut Arc, + mut f: F, +) -> Result> +where + F: FnMut(&mut LogicalPlan) -> Result>, +{ + let mut inputs: Vec<_> = node.inputs().into_iter().cloned().collect(); + + let result = inputs + .iter_mut() + .try_fold(Transformed::no(()), |acc, input| acc.and_then(|| f(input)))?; + let exprs = node.expressions(); + let mut new_node = node.from_template(&exprs, &inputs); + std::mem::swap(node, &mut new_node); + Ok(result) +} diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 02d5d1851289..07b47a4efa02 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -110,4 +110,11 @@ impl TreeNode for LogicalPlan { Ok(new_children.update_data(|_| self)) } } + + fn mutate_children(&mut self, f: F) -> Result> + where + F: FnMut(&mut Self) -> Result>, + { + self.rewrite_inputs(f) + } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c76c1c8a7bd0..79f95a177540 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1297,7 +1297,8 @@ mod test { let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); dbg!(&plan); let expected = - "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8(\"1998-03-18\") AS Date32) AS Timestamp(Nanosecond, None))\n EmptyRelation"; + "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8(\"1998-03-18\") AS Date32) AS Timestamp(Nanosecond, None))\ + \n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; Ok(()) } diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index b94cf37c5c12..744feeac0914 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -337,7 +337,7 @@ mod tests { Operator, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), plan, @@ -377,7 +377,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ \n Projection: sq_2.c [c:UInt32]\ \n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for IN subquery with additional AND filter @@ -403,7 +403,7 @@ mod tests { \n Projection: sq.c [c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for IN subquery with additional OR filter @@ -429,7 +429,7 @@ mod tests { \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -457,7 +457,7 @@ mod tests { \n Projection: sq2.c [c:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for nested IN subqueries @@ -486,7 +486,7 @@ mod tests { \n Projection: sq_nested.c [c:UInt32]\ \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for filter input modification in case filter not supported @@ -518,7 +518,7 @@ mod tests { \n Projection: sq_inner.c [c:UInt32]\ \n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test multiple correlated subqueries @@ -556,7 +556,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -606,7 +606,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -641,7 +641,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -674,7 +674,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -705,7 +705,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -738,7 +738,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -771,7 +771,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -805,7 +805,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); @@ -862,7 +862,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -895,7 +895,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -961,7 +961,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -999,7 +999,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1029,7 +1029,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1053,7 +1053,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1077,7 +1077,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1106,7 +1106,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1141,7 +1141,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1177,7 +1177,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1223,7 +1223,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1254,7 +1254,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1288,7 +1288,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test recursive correlated subqueries @@ -1331,7 +1331,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with additional subquery filters @@ -1361,7 +1361,7 @@ mod tests { \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1386,7 +1386,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for exists subquery with both columns in schema @@ -1404,7 +1404,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } /// Test for correlated exists subquery not equal @@ -1432,7 +1432,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery less than @@ -1460,7 +1460,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with subquery disjunction @@ -1489,7 +1489,7 @@ mod tests { \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists without projection @@ -1515,7 +1515,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists expressions @@ -1543,7 +1543,7 @@ mod tests { \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with additional filters @@ -1571,7 +1571,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with disjustions @@ -1598,7 +1598,7 @@ mod tests { TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated EXISTS subquery filter @@ -1623,7 +1623,7 @@ mod tests { \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for single exists subquery filter @@ -1635,7 +1635,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } /// Test for single NOT exists subquery filter @@ -1647,7 +1647,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } #[test] @@ -1686,7 +1686,7 @@ mod tests { \n Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1712,7 +1712,7 @@ mod tests { \n Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1738,7 +1738,7 @@ mod tests { \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1766,7 +1766,7 @@ mod tests { \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1794,7 +1794,7 @@ mod tests { \n Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1822,6 +1822,6 @@ mod tests { \n Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index de05717a72e2..fae0eb5c8b1d 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -114,7 +114,7 @@ mod tests { use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( Arc::new(EliminateDuplicatedExpr::new()), plan, @@ -132,7 +132,7 @@ mod tests { let expected = "Limit: skip=5, fetch=10\ \n Sort: test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -151,6 +151,6 @@ mod tests { let expected = "Limit: skip=5, fetch=10\ \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index fea14342ca77..9287752a3f99 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -88,7 +88,7 @@ mod tests { use crate::test::*; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) } @@ -104,7 +104,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -119,7 +119,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -141,7 +141,7 @@ mod tests { \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -156,7 +156,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -179,7 +179,7 @@ mod tests { \n TableScan: test\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -202,6 +202,6 @@ mod tests { // Filter is removed let expected = "Projection: test.a\ \n EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 0dbebcc8a051..f4123c6503e8 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -82,7 +82,7 @@ mod tests { use datafusion_expr::{logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan}; use std::sync::Arc; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan, expected) } @@ -97,7 +97,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -114,6 +114,6 @@ mod tests { CrossJoin:\ \n EmptyRelation\ \n EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 4386253740aa..bfd660ce884f 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -93,24 +93,19 @@ mod tests { use crate::push_down_limit::PushDownLimit; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } fn assert_optimized_plan_eq_with_pushdown( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} @@ -124,7 +119,6 @@ mod tests { .expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } @@ -137,7 +131,7 @@ mod tests { .build()?; // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -157,7 +151,7 @@ mod tests { \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -171,7 +165,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_eq_with_pushdown(&plan, expected) + assert_optimized_plan_eq_with_pushdown(plan, expected) } #[test] @@ -191,7 +185,7 @@ mod tests { \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq_with_pushdown(&plan, expected) + assert_optimized_plan_eq_with_pushdown(plan, expected) } #[test] @@ -209,7 +203,7 @@ mod tests { \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -227,7 +221,7 @@ mod tests { \n Limit: skip=2, fetch=1\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -249,7 +243,7 @@ mod tests { \n Limit: skip=2, fetch=1\ \n TableScan: test\ \n TableScan: test1"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -262,6 +256,6 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 5771ea2e19a2..d27766b33543 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -114,7 +114,7 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected) } @@ -131,7 +131,7 @@ mod tests { Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -147,7 +147,7 @@ mod tests { \n Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -167,7 +167,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -188,7 +188,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -210,7 +210,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -230,7 +230,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // We don't need to use project_with_column_index in logical optimizer, @@ -261,7 +261,7 @@ mod tests { \n TableScan: table\ \n Projection: table.id AS id, table.key, table.value\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -291,7 +291,7 @@ mod tests { \n TableScan: table\ \n Projection: table.id AS id, table.key, table.value\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -337,7 +337,7 @@ mod tests { \n TableScan: table_1\ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ \n TableScan: table_1"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -384,6 +384,6 @@ mod tests { \n TableScan: table_1\ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ \n TableScan: table_1"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 70ee490346ff..cb79cd88bd03 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -76,7 +76,7 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_with_rules( vec![Arc::new(EliminateOneUnion::new())], plan, @@ -97,7 +97,7 @@ mod tests { Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -113,6 +113,6 @@ mod tests { }); let expected = "TableScan: table"; - assert_optimized_plan_equal(&single_union_plan, expected) + assert_optimized_plan_equal(single_union_plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 56a4a76987f7..edc2131564b7 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -306,7 +306,7 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected) } @@ -330,7 +330,7 @@ mod tests { \n Left Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -353,7 +353,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -380,7 +380,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -407,7 +407,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -434,6 +434,6 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 24664d57c38d..efe92e2702b3 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -164,7 +164,7 @@ mod tests { col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType, }; - fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(ExtractEquijoinPredicate {}), plan, @@ -186,7 +186,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -205,7 +205,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -228,7 +228,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -255,7 +255,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -281,7 +281,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -318,7 +318,7 @@ mod tests { \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -351,7 +351,7 @@ mod tests { \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -375,6 +375,6 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 95cd8a9fd36c..a91768312fcf 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -119,7 +119,7 @@ mod tests { use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{col, lit, logical_plan::JoinType, LogicalPlanBuilder}; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) } @@ -131,7 +131,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -142,7 +142,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -179,7 +179,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -200,7 +200,7 @@ mod tests { \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -221,7 +221,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -244,7 +244,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } fn build_plan( diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index b942f187c331..7be10ade7b6a 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -942,7 +942,7 @@ mod tests { UserDefinedLogicalNodeCore, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } @@ -1091,7 +1091,7 @@ mod tests { let expected = "Projection: Int32(1) + test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1105,7 +1105,7 @@ mod tests { let expected = "Projection: Int32(1) + test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1118,7 +1118,7 @@ mod tests { let expected = "Projection: test.a AS alias\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1131,7 +1131,7 @@ mod tests { let expected = "Projection: test.a AS alias\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1153,7 +1153,7 @@ mod tests { \n Projection: \ \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ \n TableScan: ?table? projection=[]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1176,7 +1176,7 @@ mod tests { .build()?; let expected = "Projection: (?table?.s)[x]\ \n TableScan: ?table? projection=[s]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1188,7 +1188,7 @@ mod tests { let expected = "Projection: (- test.a)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1200,7 +1200,7 @@ mod tests { let expected = "Projection: test.a IS NULL\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1212,7 +1212,7 @@ mod tests { let expected = "Projection: test.a IS NOT NULL\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1224,7 +1224,7 @@ mod tests { let expected = "Projection: test.a IS TRUE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1236,7 +1236,7 @@ mod tests { let expected = "Projection: test.a IS NOT TRUE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1248,7 +1248,7 @@ mod tests { let expected = "Projection: test.a IS FALSE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1260,7 +1260,7 @@ mod tests { let expected = "Projection: test.a IS NOT FALSE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1272,7 +1272,7 @@ mod tests { let expected = "Projection: test.a IS UNKNOWN\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1284,7 +1284,7 @@ mod tests { let expected = "Projection: test.a IS NOT UNKNOWN\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1296,7 +1296,7 @@ mod tests { let expected = "Projection: NOT test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1308,7 +1308,7 @@ mod tests { let expected = "Projection: TRY_CAST(test.a AS Float64)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1324,7 +1324,7 @@ mod tests { let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1336,7 +1336,7 @@ mod tests { let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Test outer projection isn't discarded despite the same schema as inner @@ -1357,7 +1357,7 @@ mod tests { let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE d END AS d\ \n Projection: test.a, Int32(0) AS d\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Since only column `a` is referred at the output. Scan should only contain projection=[a]. @@ -1378,7 +1378,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Only column `a` is referred at the output. However, User defined node itself uses column `b` @@ -1405,7 +1405,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Only column `a` is referred at the output. However, User defined node itself uses expression `b+c` @@ -1440,7 +1440,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Columns `l.a`, `l.c`, `r.a` is referred at the output. @@ -1465,6 +1465,6 @@ mod tests { \n UserDefinedCrossJoin\ \n TableScan: l projection=[a, c]\ \n TableScan: r projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index fe63766fc265..16e4a0f9bde8 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -48,10 +48,11 @@ use crate::utils::log_plan; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; use chrono::{DateTime, Utc}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeMutator}; use log::{debug, warn}; /// `OptimizerRule` transforms one [`LogicalPlan`] into another which @@ -184,11 +185,11 @@ pub struct Optimizer { pub rules: Vec>, } -/// If a rule is with `ApplyOrder`, it means the optimizer will derive to handle children instead of -/// recursively handling in rule. -/// We just need handle a subtree pattern itself. +/// If a rule is with `ApplyOrder`, it means the optimizer will handle +/// recursion. If it is `None` the rule must handle any required recursion /// -/// Notice: **sometime** result after optimize still can be optimized, we need apply again. +/// Notice: **sometime** result after optimize still can be optimized, we need +/// apply again. /// /// Usage Example: Merge Limit (subtree pattern is: Limit-Limit) /// ```rust @@ -217,6 +218,7 @@ pub struct Optimizer { /// } /// } /// ``` +#[derive(Debug, Clone, Copy, PartialEq)] pub enum ApplyOrder { TopDown, BottomUp, @@ -274,12 +276,90 @@ impl Optimizer { pub fn with_rules(rules: Vec>) -> Self { Self { rules } } +} + +struct Mutator<'a> { + apply_order: ApplyOrder, + rule: &'a dyn OptimizerRule, + config: &'a dyn OptimizerConfig, +} + +impl<'a> TreeNodeMutator for Mutator<'a> { + type Node = LogicalPlan; + fn f_down(&mut self, node: &mut Self::Node) -> Result> { + if self.apply_order == ApplyOrder::TopDown { + optimize_in_place(node, self.rule, self.config) + } else { + Ok(Transformed::no(())) + } + } + + fn f_up(&mut self, node: &mut Self::Node) -> Result> { + if self.apply_order == ApplyOrder::BottomUp { + optimize_in_place(node, self.rule, self.config) + } else { + Ok(Transformed::no(())) + } + } +} + +/// Applies rule to `plan` in place, returning `Transformed` with the rewritten +/// plan +fn rewrite_in_place( + mut plan: LogicalPlan, + rule: &dyn OptimizerRule, + config: &dyn OptimizerConfig, +) -> Result> { + let transformed = match rule.apply_order() { + Some(apply_order) => { + // use &mut to rewrite plan in place + plan.mutate(&mut Mutator { + apply_order, + rule, + config, + }) + } + None => optimize_in_place(&mut plan, rule, config), + } + // convert to bool to drop mut borrow on plan + .map(|tnr| tnr.transformed); + + // take back ownership + transformed.map(|transformed| { + if transformed { + Transformed::yes(plan) + } else { + Transformed::no(plan) + } + }) +} + +/// Invokes the Optimizer rule to rewrite the LogicalPlan in place. +fn optimize_in_place( + plan: &mut LogicalPlan, + rule: &dyn OptimizerRule, + config: &dyn OptimizerConfig, +) -> Result> { + // TODO: introduce a better API to OptimizerRule to allow rewriting in place + rule.try_optimize(plan, config).map(|maybe_plan| { + match maybe_plan { + Some(new_plan) => { + // if the node was rewritten by the optimizer, replace the node + *plan = new_plan; // TODO avoid this copy with better OptimizerRule::try_optimize + Transformed::yes(()) + } + None => Transformed::no(()), + } + }) +} + +impl Optimizer { /// Optimizes the logical plan by applying optimizer rules, and /// invoking observer function after each call pub fn optimize( &self, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, mut observer: F, ) -> Result @@ -287,7 +367,7 @@ impl Optimizer { F: FnMut(&LogicalPlan, &dyn OptimizerRule), { let options = config.options(); - let mut new_plan = plan.clone(); + let mut new_plan = plan; let start_time = Instant::now(); @@ -299,44 +379,65 @@ impl Optimizer { log_plan(&format!("Optimizer input (pass {i})"), &new_plan); for rule in &self.rules { - let result = - self.optimize_recursively(rule, &new_plan, config) - .and_then(|plan| { - if let Some(plan) = &plan { - assert_schema_is_the_same(rule.name(), plan, &new_plan)?; - } - Ok(plan) - }); - match result { - Ok(Some(plan)) => { + // If we need to skip failed rules, must copy plan before attempting to rewrite + // as rewriting is destructive + let prev_plan = options + .optimizer + .skip_failed_rules + .then(|| new_plan.clone()); + + let starting_schema = new_plan.schema().clone(); + + let result = rewrite_in_place(new_plan, rule.as_ref(), config) + // verify the rule didn't change the schema + .and_then(|tnr| { + if tnr.transformed { + assert_only_schema_is_the_same( + rule.name(), + &starting_schema, + &tnr.data, + )?; + } + Ok(tnr) + }); + + match (result, prev_plan) { + ( + Ok(Transformed { + data: plan, + transformed, + .. + }), + _, + ) => { new_plan = plan; observer(&new_plan, rule.as_ref()); - log_plan(rule.name(), &new_plan); - } - Ok(None) => { - observer(&new_plan, rule.as_ref()); - debug!( - "Plan unchanged by optimizer rule '{}' (pass {})", - rule.name(), - i - ); + if transformed { + log_plan(rule.name(), &new_plan); + } else { + debug!( + "Plan unchanged by optimizer rule '{}' (pass {})", + rule.name(), + i + ); + } } - Err(e) => { - if options.optimizer.skip_failed_rules { - // Note to future readers: if you see this warning it signals a - // bug in the DataFusion optimizer. Please consider filing a ticket - // https://github.com/apache/arrow-datafusion - warn!( + (Err(e), Some(orig_plan)) => { + // Note to future readers: if you see this warning it signals a + // bug in the DataFusion optimizer. Please consider filing a ticket + // https://github.com/apache/arrow-datafusion + warn!( "Skipping optimizer rule '{}' due to unexpected error: {}", rule.name(), e ); - } else { - return Err(DataFusionError::Context( - format!("Optimizer rule '{}' failed", rule.name(),), - Box::new(e), - )); - } + new_plan = orig_plan; + } + (Err(e), None) => { + return Err(e.context(format!( + "Optimizer rule '{}' failed", + rule.name() + ))); } } } @@ -356,97 +457,22 @@ impl Optimizer { debug!("Optimizer took {} ms", start_time.elapsed().as_millis()); Ok(new_plan) } - - fn optimize_node( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - // TODO: future feature: We can do Batch optimize - rule.try_optimize(plan, config) - } - - fn optimize_inputs( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - let inputs = plan.inputs(); - let result = inputs - .iter() - .map(|sub_plan| self.optimize_recursively(rule, sub_plan, config)) - .collect::>>()?; - if result.is_empty() || result.iter().all(|o| o.is_none()) { - return Ok(None); - } - - let new_inputs = result - .into_iter() - .zip(inputs) - .map(|(new_plan, old_plan)| match new_plan { - Some(plan) => plan, - None => old_plan.clone(), - }) - .collect(); - - let exprs = plan.expressions(); - plan.with_new_exprs(exprs, new_inputs).map(Some) - } - - /// Use a rule to optimize the whole plan. - /// If the rule with `ApplyOrder`, we don't need to recursively handle children in rule. - pub fn optimize_recursively( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - match rule.apply_order() { - Some(order) => match order { - ApplyOrder::TopDown => { - let optimize_self_opt = self.optimize_node(rule, plan, config)?; - let optimize_inputs_opt = match &optimize_self_opt { - Some(optimized_plan) => { - self.optimize_inputs(rule, optimized_plan, config)? - } - _ => self.optimize_inputs(rule, plan, config)?, - }; - Ok(optimize_inputs_opt.or(optimize_self_opt)) - } - ApplyOrder::BottomUp => { - let optimize_inputs_opt = self.optimize_inputs(rule, plan, config)?; - let optimize_self_opt = match &optimize_inputs_opt { - Some(optimized_plan) => { - self.optimize_node(rule, optimized_plan, config)? - } - _ => self.optimize_node(rule, plan, config)?, - }; - Ok(optimize_self_opt.or(optimize_inputs_opt)) - } - }, - _ => rule.try_optimize(plan, config), - } - } } -/// Returns an error if plans have different schemas. +/// Returns an error if the plan has a different schema than `prev_schema` /// /// It ignores metadata and nullability. -pub(crate) fn assert_schema_is_the_same( +pub(crate) fn assert_only_schema_is_the_same( rule_name: &str, - prev_plan: &LogicalPlan, + prev_schema: &DFSchema, new_plan: &LogicalPlan, ) -> Result<()> { - let equivalent = new_plan - .schema() - .equivalent_names_and_types(prev_plan.schema()); + let equivalent = new_plan.schema().equivalent_names_and_types(prev_schema); if !equivalent { let e = DataFusionError::Internal(format!( "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", - prev_plan.schema(), + prev_schema, new_plan.schema() )); Err(DataFusionError::Context( @@ -479,7 +505,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - opt.optimize(&plan, &config, &observe).unwrap(); + opt.optimize(plan, &config, &observe).unwrap(); } #[test] @@ -490,7 +516,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - let err = opt.optimize(&plan, &config, &observe).unwrap_err(); + let err = opt.optimize(plan, &config, &observe).unwrap_err(); assert_eq!( "Optimizer rule 'bad rule' failed\ncaused by\n\ Error during planning: rule failed", @@ -506,16 +532,16 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - let err = opt.optimize(&plan, &config, &observe).unwrap_err(); + let err = opt.optimize(plan, &config, &observe).unwrap_err(); assert_eq!( "Optimizer rule 'get table_scan rule' failed\ncaused by\nget table_scan rule\ncaused by\n\ Internal error: Failed due to a difference in schemas, \ - original schema: DFSchema { fields: [\ + original schema: DFSchema { fields: [], metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }, \ + new schema: DFSchema { fields: [\ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }], \ - metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }, \ - new schema: DFSchema { fields: [], metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }.\ + metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }.\ \nThis was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker", err.strip_backtrace() ); @@ -529,7 +555,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - opt.optimize(&plan, &config, &observe).unwrap(); + opt.optimize(plan, &config, &observe).unwrap(); } #[test] @@ -550,7 +576,7 @@ mod tests { // optimizing should be ok, but the schema will have changed (no metadata) assert_ne!(plan.schema().as_ref(), input_schema.as_ref()); - let optimized_plan = opt.optimize(&plan, &config, &observe)?; + let optimized_plan = opt.optimize(plan, &config, &observe)?; // metadata was removed assert_eq!(optimized_plan.schema().as_ref(), input_schema.as_ref()); Ok(()) @@ -571,7 +597,7 @@ mod tests { let mut plans: Vec = Vec::new(); let final_plan = - opt.optimize(&initial_plan, &config, |p, _| plans.push(p.clone()))?; + opt.optimize(initial_plan.clone(), &config, |p, _| plans.push(p.clone()))?; // initial_plan is not observed, so we have 3 plans assert_eq!(3, plans.len()); @@ -597,7 +623,7 @@ mod tests { let mut plans: Vec = Vec::new(); let final_plan = - opt.optimize(&initial_plan, &config, |p, _| plans.push(p.clone()))?; + opt.optimize(initial_plan, &config, |p, _| plans.push(p.clone()))?; // initial_plan is not observed, so we have 4 plans assert_eq!(4, plans.len()); diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index d1f9f87a32a3..d28cdc2158d8 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -197,12 +197,12 @@ mod tests { use super::*; - fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected) } fn assert_together_optimized_plan_eq( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { assert_optimized_plan_eq_with_rules( @@ -225,7 +225,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_eq(&plan, expected) + assert_eq(plan, expected) } #[test] @@ -248,7 +248,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -261,7 +261,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -286,7 +286,7 @@ mod tests { let expected = "Union\ \n TableScan: test1\ \n TableScan: test4"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -311,7 +311,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -338,7 +338,7 @@ mod tests { let expected = "Union\ \n TableScan: test2\ \n TableScan: test3"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -351,7 +351,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -366,7 +366,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -399,6 +399,6 @@ mod tests { let expected = "Projection: a, b, c\ \n TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index e93e171e0324..60ecd2522c8c 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1052,8 +1052,9 @@ mod tests { }; use async_trait::async_trait; + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( Arc::new(PushDownFilter::new()), plan, @@ -1062,29 +1063,17 @@ mod tests { } fn assert_optimized_plan_eq_with_rewrite_predicate( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![ Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(PushDownFilter::new()), ]); - let mut optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); - optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.get(1).unwrap(), - &optimized_plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; + let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(plan.schema(), optimized_plan.schema()); assert_eq!(expected, formatted_plan); Ok(()) } @@ -1100,7 +1089,7 @@ mod tests { let expected = "\ Projection: test.a, test.b\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1117,7 +1106,7 @@ mod tests { \n Limit: skip=0, fetch=10\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1127,7 +1116,7 @@ mod tests { .filter(lit(0i64).eq(lit(1i64)))? .build()?; let expected = "TableScan: test, full_filters=[Int64(0) = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1143,7 +1132,7 @@ mod tests { Projection: test.c, test.b\ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1157,7 +1146,7 @@ mod tests { let expected = "\ Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1170,7 +1159,7 @@ mod tests { let expected = "Filter: test.b > Int64(10)\ \n Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1182,7 +1171,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ \n TableScan: test, full_filters=[test.b + test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1197,7 +1186,7 @@ mod tests { Filter: b > Int64(10)\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written @@ -1212,7 +1201,7 @@ mod tests { let expected = "\ Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } fn add(left: Expr, right: Expr) -> Expr { @@ -1256,7 +1245,7 @@ mod tests { let expected = "\ Projection: test.a * Int32(2) + test.c AS b, test.c\ \n TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written @@ -1288,7 +1277,7 @@ mod tests { Projection: b * Int32(3) AS a, test.c\ \n Projection: test.a * Int32(2) + test.c AS b, test.c\ \n TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[derive(Debug, PartialEq, Eq, Hash)] @@ -1351,7 +1340,7 @@ mod tests { let expected = "\ NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1368,7 +1357,7 @@ mod tests { Filter: test.c = Int64(2)\ \n NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1385,7 +1374,7 @@ mod tests { NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1403,7 +1392,7 @@ mod tests { \n NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed @@ -1436,7 +1425,7 @@ mod tests { \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed @@ -1470,7 +1459,7 @@ mod tests { \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when two limits are in place, we jump neither @@ -1492,7 +1481,7 @@ mod tests { \n Limit: skip=0, fetch=20\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1507,7 +1496,7 @@ mod tests { let expected = "Union\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test2, full_filters=[test2.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1530,7 +1519,7 @@ mod tests { \n SubqueryAlias: test2\ \n Projection: test.a AS b\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1561,7 +1550,7 @@ mod tests { \n Projection: test1.d, test1.e, test1.f\ \n TableScan: test1, full_filters=[test1.d > Int32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1587,7 +1576,7 @@ mod tests { \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.a, test1.b, test1.c\ \n TableScan: test1, full_filters=[test1.a > Int32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters with the same columns are correctly placed @@ -1621,7 +1610,7 @@ mod tests { \n Projection: test.a\ \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters to be placed on the same depth are ANDed @@ -1651,7 +1640,7 @@ mod tests { \n Limit: skip=0, fetch=1\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters on a plan with user nodes are not lost @@ -1677,7 +1666,7 @@ mod tests { TestUserDefined\ \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-on-join predicates on a column common to both sides is pushed to both sides @@ -1715,7 +1704,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-using-join predicates on a column common to both sides is pushed to both sides @@ -1752,7 +1741,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates with columns from both sides are converted to join filterss @@ -1794,7 +1783,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.b\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates with columns from one side of a join are pushed only to that side @@ -1836,7 +1825,7 @@ mod tests { \n TableScan: test, full_filters=[test.b <= Int64(1)]\ \n Projection: test2.a, test2.c\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates on the right side of a left join are not duplicated @@ -1875,7 +1864,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates on the left side of a right join are not duplicated @@ -1913,7 +1902,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-left-join predicate on a column common to both sides is only pushed to the left side @@ -1951,7 +1940,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-right-join predicate on a column common to both sides is only pushed to the right side @@ -1989,7 +1978,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to both inputs @@ -2032,7 +2021,7 @@ mod tests { \n TableScan: test, full_filters=[test.c > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// join filter should be completely removed after pushdown @@ -2074,7 +2063,7 @@ mod tests { \n TableScan: test, full_filters=[test.b > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2114,7 +2103,7 @@ mod tests { \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to right input @@ -2157,7 +2146,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to left input @@ -2200,7 +2189,7 @@ mod tests { \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should not be pushed @@ -2238,7 +2227,7 @@ mod tests { ); let expected = &format!("{plan:?}"); - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } struct PushDownProvider { @@ -2297,7 +2286,7 @@ mod tests { let expected = "\ TableScan: test, full_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2308,7 +2297,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test, partial_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2316,7 +2305,7 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?; - let optimised_plan = PushDownFilter::new() + let optimized_plan = PushDownFilter::new() .try_optimize(&plan, &OptimizerContext::new()) .expect("failed to optimize plan") .unwrap(); @@ -2327,7 +2316,7 @@ mod tests { // Optimizing the same plan multiple times should produce the same plan // each time. - assert_optimized_plan_eq(&optimised_plan, expected) + assert_optimized_plan_eq(optimized_plan, expected) } #[test] @@ -2338,7 +2327,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2367,7 +2356,7 @@ mod tests { \n Filter: a = Int64(10) AND b > Int64(11)\ \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2398,7 +2387,7 @@ Projection: a, b "# .trim(); - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2426,7 +2415,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2458,7 +2447,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2483,7 +2472,7 @@ Projection: a, b Projection: test.a AS b, test.c AS d\ \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2523,7 +2512,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b AS d\ \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2552,7 +2541,7 @@ Projection: a, b Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2584,7 +2573,7 @@ Projection: a, b \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2620,7 +2609,7 @@ Projection: a, b \n Subquery:\ \n Projection: sq.c\ \n TableScan: sq"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2653,7 +2642,7 @@ Projection: a, b \n Projection: Int64(0) AS a\ \n Filter: Int64(0) = Int64(1)\ \n EmptyRelation"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2681,14 +2670,14 @@ Projection: a, b \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ \n Projection: test1.a AS d, test1.a AS e\ \n TableScan: test1"; - assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?; + assert_optimized_plan_eq_with_rewrite_predicate(plan.clone(), expected)?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. // Now the global state is removed. Need to double confirm that avoid duplicate Filters. let optimized_plan = PushDownFilter::new() .try_optimize(&plan, &OptimizerContext::new())? .expect("failed to optimize plan"); - assert_optimized_plan_eq(&optimized_plan, expected) + assert_optimized_plan_eq(optimized_plan, expected) } #[test] @@ -2729,7 +2718,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2770,7 +2759,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2816,7 +2805,7 @@ Projection: a, b \n TableScan: test1\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2861,7 +2850,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2894,7 +2883,7 @@ Projection: a, b \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\ \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2936,6 +2925,6 @@ Projection: a, b \n Inner Join: test1.a = test2.a\ \n TableScan: test1\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 33d02d5c5628..da445c7f4cb4 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -284,7 +284,7 @@ mod test { max, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) } @@ -303,7 +303,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -321,7 +321,7 @@ mod test { let expected = "Limit: skip=0, fetch=10\ \n TableScan: test, fetch=10"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -338,7 +338,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -358,7 +358,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -375,7 +375,7 @@ mod test { \n Sort: test.a, fetch=10\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -392,7 +392,7 @@ mod test { \n Sort: test.a, fetch=15\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -411,7 +411,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -426,7 +426,7 @@ mod test { let expected = "Limit: skip=10, fetch=None\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -444,7 +444,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -461,7 +461,7 @@ mod test { \n Limit: skip=10, fetch=990\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -478,7 +478,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -494,7 +494,7 @@ mod test { let expected = "Limit: skip=10, fetch=10\ \n TableScan: test, fetch=20"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -511,7 +511,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -531,7 +531,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -555,7 +555,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -579,7 +579,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -608,7 +608,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_equal(&outer_query, expected) + assert_optimized_plan_equal(outer_query, expected) } #[test] @@ -637,7 +637,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_equal(&outer_query, expected) + assert_optimized_plan_equal(outer_query, expected) } #[test] @@ -663,7 +663,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -682,7 +682,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -701,7 +701,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -719,7 +719,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -737,7 +737,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -755,7 +755,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1) .join( @@ -773,7 +773,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -798,7 +798,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -823,7 +823,7 @@ mod test { \n TableScan: test, fetch=1010\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -848,7 +848,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -873,7 +873,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test2, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -893,7 +893,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -913,7 +913,7 @@ mod test { \n Limit: skip=0, fetch=2000\ \n TableScan: test2, fetch=2000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -928,7 +928,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -943,7 +943,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -960,6 +960,6 @@ mod test { \n Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 28b3ff090fe6..ef6c5c18912f 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -27,7 +27,7 @@ mod tests { use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::*; - use crate::OptimizerContext; + use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Column, DFField, DFSchema, Result}; use datafusion_expr::builder::table_scan_with_filters; @@ -51,7 +51,7 @@ mod tests { let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -65,7 +65,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -81,7 +81,7 @@ mod tests { \n SubqueryAlias: a\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -98,7 +98,7 @@ mod tests { \n Filter: test.c > Int32(1)\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -123,7 +123,7 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ \n TableScan: m4 projection=[tag.one]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -137,7 +137,7 @@ mod tests { let expected = "Projection: test.a, test.c, test.b\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -147,7 +147,7 @@ mod tests { let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?; let expected = "TableScan: test projection=[b, a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -160,7 +160,7 @@ mod tests { let expected = "Projection: test.a, test.b\ \n TableScan: test projection=[b, a]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -173,7 +173,7 @@ mod tests { let expected = "Projection: test.c, test.b, test.a\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -195,7 +195,7 @@ mod tests { \n Filter: test.c > Int32(1)\ \n Projection: test.c, test.b, test.a\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -215,7 +215,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -258,7 +258,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -299,7 +299,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[a]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -334,7 +334,7 @@ mod tests { let expected = "Projection: CAST(test.c AS Float64)\ \n TableScan: test projection=[c]"; - assert_optimized_plan_eq(&projection, expected) + assert_optimized_plan_eq(projection, expected) } #[test] @@ -350,7 +350,7 @@ mod tests { let expected = "TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -371,7 +371,7 @@ mod tests { let expected = "TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -391,7 +391,7 @@ mod tests { \n Projection: test.c, test.a\ \n TableScan: test projection=[a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -400,7 +400,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan).build()?; // should expand projection to all columns without projection let expected = "TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -411,7 +411,7 @@ mod tests { .build()?; let expected = "Projection: Int64(1), Int64(2)\ \n TableScan: test projection=[]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that it removes unused columns in projections @@ -430,14 +430,14 @@ mod tests { assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]); - let plan = optimize(&plan).expect("failed to optimize plan"); + let plan = optimize(plan).expect("failed to optimize plan"); let expected = "\ Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.a)]]\ \n Filter: test.c > Int32(1)\ \n Projection: test.c, test.a\ \n TableScan: test projection=[a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that it removes un-needed projections @@ -459,7 +459,7 @@ mod tests { Projection: Int32(1) AS a\ \n TableScan: test projection=[]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -488,7 +488,7 @@ mod tests { Projection: Int32(1) AS a\ \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that optimizing twice yields same plan @@ -501,9 +501,9 @@ mod tests { .project(vec![lit(1).alias("a")])? .build()?; - let optimized_plan1 = optimize(&plan).expect("failed to optimize plan"); + let optimized_plan1 = optimize(plan).expect("failed to optimize plan"); let optimized_plan2 = - optimize(&optimized_plan1).expect("failed to optimize plan"); + optimize(optimized_plan1.clone()).expect("failed to optimize plan"); let formatted_plan1 = format!("{optimized_plan1:?}"); let formatted_plan2 = format!("{optimized_plan2:?}"); @@ -532,7 +532,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -558,7 +558,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b), COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -575,7 +575,7 @@ mod tests { \n Distinct:\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -614,25 +614,23 @@ mod tests { \n WindowAggr: windowExpr=[[MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { let optimized_plan = optimize(plan).expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); Ok(()) } - fn optimize(plan: &LogicalPlan) -> Result { + fn optimize(plan: LogicalPlan) -> Result { let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; + Ok(optimized_plan) } + + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 0666c324d12c..c4b70e929831 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -174,7 +174,7 @@ mod tests { assert_optimized_plan_eq( Arc::new(ReplaceDistinctWithAggregate::new()), - &plan, + plan, expected, ) } @@ -197,7 +197,7 @@ mod tests { assert_optimized_plan_eq( Arc::new(ReplaceDistinctWithAggregate::new()), - &plan, + plan, expected, ) } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 8acc36e479ca..85dcec63ab58 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -427,7 +427,7 @@ mod tests { \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -483,7 +483,7 @@ mod tests { \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -521,7 +521,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -557,7 +557,7 @@ mod tests { \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -591,7 +591,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -730,7 +730,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -796,7 +796,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -835,7 +835,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -875,7 +875,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -908,7 +908,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -940,7 +940,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -971,7 +971,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -1028,7 +1028,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -1077,7 +1077,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 07a9d84f7d48..538a732e3aa6 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -309,7 +309,7 @@ mod tests { min, sum, AggregateFunction, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(SingleDistinctToGroupBy::new()), plan, @@ -331,7 +331,7 @@ mod tests { "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]] [MAX(test.b):UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -348,7 +348,7 @@ mod tests { \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -369,7 +369,7 @@ mod tests { let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -387,7 +387,7 @@ mod tests { let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -406,7 +406,7 @@ mod tests { let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -422,7 +422,7 @@ mod tests { \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -439,7 +439,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -457,7 +457,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(DISTINCT test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -486,7 +486,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -504,7 +504,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -521,7 +521,7 @@ mod tests { \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -551,7 +551,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -570,7 +570,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -589,7 +589,7 @@ mod tests { \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -612,7 +612,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -635,7 +635,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -658,7 +658,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -681,7 +681,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -704,6 +704,6 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index e691fe9a5351..b8e9c66bc2b4 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -16,7 +16,7 @@ // under the License. use crate::analyzer::{Analyzer, AnalyzerRule}; -use crate::optimizer::{assert_schema_is_the_same, Optimizer}; +use crate::optimizer::Optimizer; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; @@ -152,20 +152,16 @@ pub fn assert_analyzer_check_err( } pub fn assert_optimized_plan_eq( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule.clone()]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + + // in tests we are applying only one rule once + let opt_context = OptimizerContext::new().with_max_passes(1); - // Ensure schemas always match after an optimization - assert_schema_is_the_same(rule.name(), plan, &optimized_plan)?; + let optimizer = Optimizer::with_rules(vec![rule.clone()]); + let optimized_plan = optimizer.optimize(plan, &opt_context, observe)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -174,7 +170,7 @@ pub fn assert_optimized_plan_eq( pub fn assert_optimized_plan_eq_with_rules( rules: Vec>, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} @@ -187,58 +183,46 @@ pub fn assert_optimized_plan_eq_with_rules( .expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } +fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + pub fn assert_optimized_plan_eq_display_indent( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(vec![rule]); let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - ) - .expect("failed to optimize plan") - .unwrap_or_else(|| plan.clone()); + .optimize(plan, &OptimizerContext::new(), observe) + .expect("failed to optimize plan"); let formatted_plan = optimized_plan.display_indent_schema().to_string(); assert_eq!(formatted_plan, expected); } pub fn assert_multi_rules_optimized_plan_eq_display_indent( rules: Vec>, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(rules); - let mut optimized_plan = plan.clone(); - for rule in &optimizer.rules { - optimized_plan = optimizer - .optimize_recursively(rule, &optimized_plan, &OptimizerContext::new()) - .expect("failed to optimize plan") - .unwrap_or_else(|| optimized_plan.clone()); - } + let optimized_plan = optimizer + .optimize(plan, &OptimizerContext::new(), observe) + .expect("failed to optimize plan"); let formatted_plan = optimized_plan.display_indent_schema().to_string(); assert_eq!(formatted_plan, expected); } pub fn assert_optimizer_err( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(vec![rule]); - let res = optimizer.optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - ); + let res = optimizer.optimize(plan, &OptimizerContext::new(), observe); match res { - Ok(plan) => assert_eq!(format!("{}", plan.unwrap().display_indent()), "An error"), + Ok(plan) => assert_eq!(format!("{}", plan.display_indent()), "An error"), Err(ref e) => { let actual = format!("{e}"); if expected.is_empty() || !actual.contains(expected) { @@ -250,16 +234,11 @@ pub fn assert_optimizer_err( pub fn assert_optimization_skipped( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![rule]); - let new_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let new_plan = optimizer.optimize(plan.clone(), &OptimizerContext::new(), observe)?; + assert_eq!( format!("{}", plan.display_indent()), format!("{}", new_plan.display_indent()) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index acafc0bafaf4..c28349447dbb 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -315,7 +315,7 @@ fn test_sql(sql: &str) -> Result { let optimizer = Optimizer::new(); // analyze and optimize the logical plan let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - optimizer.optimize(&plan, &config, |_, _| {}) + optimizer.optimize(plan, &config, |_, _| {}) } #[derive(Default)] diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index da9b4168e7e0..135ab8075425 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -587,7 +587,7 @@ FROM t1 ---- 11 11 11 -# subsequent inner join +# subsequent inner join query III rowsort SELECT t1.t1_id, t2.t2_id, t3.t3_id FROM t1 diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index 33c9ff7c3eed..4c9254beef6b 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -781,4 +781,4 @@ logical_plan EmptyRelation physical_plan EmptyExec statement ok -drop table t; \ No newline at end of file +drop table t; diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 20c8b3d25fdd..95878e0c433d 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -2162,5 +2162,3 @@ query I select strpos('joséésoj', arrow_cast(null, 'Utf8')); ---- NULL - - diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 4fb94cfab523..d1c2450488b8 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -531,13 +531,13 @@ query TT explain SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN t2 ON(t1.t1_id = t2.t2_id and t1.t1_name = t0.t0_name)) ---- logical_plan -Filter: EXISTS () ---Subquery: -----Projection: Int64(1) -------Inner Join: Filter: t1.t1_id = t2.t2_id AND t1.t1_name = outer_ref(t0.t0_name) ---------TableScan: t1 ---------TableScan: t2 +LeftSemi Join: t0.t0_name = __correlated_sq_2.t1_name --TableScan: t0 projection=[t0_id, t0_name] +--SubqueryAlias: __correlated_sq_2 +----Projection: t1.t1_name +------Inner Join: t1.t1_id = t2.t2_id +--------TableScan: t1 projection=[t1_id, t1_name] +--------TableScan: t2 projection=[t2_id] #subquery_contains_join_contains_correlated_columns query TT @@ -651,8 +651,8 @@ explain SELECT t1_id, t1_name FROM t1 WHERE t1_id in (SELECT t2_id FROM t2 where logical_plan Filter: t1.t1_id IN () --Subquery: -----Limit: skip=0, fetch=10 -------Projection: t2.t2_id +----Projection: t2.t2_id +------Limit: skip=0, fetch=10 --------Filter: outer_ref(t1.t1_name) = t2.t2_name ----------TableScan: t2 --TableScan: t1 projection=[t1_id, t1_name] diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index f0e04b522a78..491b9b810687 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -2794,4 +2794,4 @@ SELECT '2000-12-01 04:04:12' AT TIME ZONE 'America/New York'; # abbreviated timezone is not supported statement error -SELECT '2023-03-12 02:00:00' AT TIME ZONE 'EDT'; \ No newline at end of file +SELECT '2023-03-12 02:00:00' AT TIME ZONE 'EDT';