From 63e05fa3be53ebd3c357a2c987bd621071028d07 Mon Sep 17 00:00:00 2001 From: Dmitry Bugakov Date: Thu, 2 May 2024 15:38:45 +0200 Subject: [PATCH] Stop copying LogicalPlan and Exprs in PropagateEmptyRelation #10290 (#10332) --- .../optimizer/src/propagate_empty_relation.rs | 141 +++++++++++------- .../optimizer/tests/optimizer_integration.rs | 27 ++++ 2 files changed, 113 insertions(+), 55 deletions(-) diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 4003acaa7d65..d08820c58a05 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -16,11 +16,16 @@ // under the License. //! [`PropagateEmptyRelation`] eliminates nodes fed by `EmptyRelation` -use datafusion_common::{plan_err, Result}; -use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{EmptyRelation, JoinType, Projection, Union}; + use std::sync::Arc; +use datafusion_common::tree_node::Transformed; +use datafusion_common::JoinType::Inner; +use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::{EmptyRelation, Projection, Union}; + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; @@ -38,11 +43,31 @@ impl PropagateEmptyRelation { impl OptimizerRule for PropagateEmptyRelation { fn try_optimize( &self, - plan: &LogicalPlan, + _plan: &LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called PropagateEmptyRelation::rewrite") + } + + fn name(&self) -> &str { + "propagate_empty_relation" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { match plan { - LogicalPlan::EmptyRelation(_) => {} + LogicalPlan::EmptyRelation(_) => Ok(Transformed::no(plan)), LogicalPlan::Projection(_) | LogicalPlan::Filter(_) | LogicalPlan::Window(_) @@ -50,20 +75,26 @@ impl OptimizerRule for PropagateEmptyRelation { | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Repartition(_) | LogicalPlan::Limit(_) => { - if let Some(empty) = empty_child(plan)? { - return Ok(Some(empty)); + let empty = empty_child(&plan)?; + if let Some(empty_plan) = empty { + return Ok(Transformed::yes(empty_plan)); } + Ok(Transformed::no(plan)) } - LogicalPlan::CrossJoin(_) => { - let (left_empty, right_empty) = binary_plan_children_is_empty(plan)?; + LogicalPlan::CrossJoin(ref join) => { + let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?; if left_empty || right_empty { - return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: plan.schema().clone(), - }))); + return Ok(Transformed::yes(LogicalPlan::EmptyRelation( + EmptyRelation { + produce_one_row: false, + schema: plan.schema().clone(), + }, + ))); } + Ok(Transformed::no(LogicalPlan::CrossJoin(join.clone()))) } - LogicalPlan::Join(join) => { + + LogicalPlan::Join(ref join) if join.join_type == Inner => { // TODO: For Join, more join type need to be careful: // For LeftOuter/LeftSemi/LeftAnti Join, only the left side is empty, the Join result is empty. // For LeftSemi Join, if the right side is empty, the Join result is empty. @@ -76,17 +107,26 @@ impl OptimizerRule for PropagateEmptyRelation { // columns + right side columns replaced with null values. // For RightOut/Full Join, if the left side is empty, the Join can be eliminated with a Projection with right side // columns + left side columns replaced with null values. - if join.join_type == JoinType::Inner { - let (left_empty, right_empty) = binary_plan_children_is_empty(plan)?; - if left_empty || right_empty { - return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?; + if left_empty || right_empty { + return Ok(Transformed::yes(LogicalPlan::EmptyRelation( + EmptyRelation { produce_one_row: false, - schema: plan.schema().clone(), - }))); + schema: join.schema.clone(), + }, + ))); + } + Ok(Transformed::no(LogicalPlan::Join(join.clone()))) + } + LogicalPlan::Aggregate(ref agg) => { + if !agg.group_expr.is_empty() { + if let Some(empty_plan) = empty_child(&plan)? { + return Ok(Transformed::yes(empty_plan)); } } + Ok(Transformed::no(LogicalPlan::Aggregate(agg.clone()))) } - LogicalPlan::Union(union) => { + LogicalPlan::Union(ref union) => { let new_inputs = union .inputs .iter() @@ -98,49 +138,36 @@ impl OptimizerRule for PropagateEmptyRelation { .collect::>(); if new_inputs.len() == union.inputs.len() { - return Ok(None); + Ok(Transformed::no(plan)) } else if new_inputs.is_empty() { - return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: plan.schema().clone(), - }))); + Ok(Transformed::yes(LogicalPlan::EmptyRelation( + EmptyRelation { + produce_one_row: false, + schema: plan.schema().clone(), + }, + ))) } else if new_inputs.len() == 1 { - let child = (*new_inputs[0]).clone(); + let child = unwrap_arc(new_inputs[0].clone()); if child.schema().eq(plan.schema()) { - return Ok(Some(child)); + Ok(Transformed::yes(child)) } else { - return Ok(Some(LogicalPlan::Projection( + Ok(Transformed::yes(LogicalPlan::Projection( Projection::new_from_schema( Arc::new(child), plan.schema().clone(), ), - ))); + ))) } } else { - return Ok(Some(LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::Union(Union { inputs: new_inputs, schema: union.schema.clone(), - }))); - } - } - LogicalPlan::Aggregate(agg) => { - if !agg.group_expr.is_empty() { - if let Some(empty) = empty_child(plan)? { - return Ok(Some(empty)); - } + }))) } } - _ => {} - } - Ok(None) - } - fn name(&self) -> &str { - "propagate_empty_relation" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) + _ => Ok(Transformed::no(plan)), + } } } @@ -182,18 +209,22 @@ fn empty_child(plan: &LogicalPlan) -> Result> { #[cfg(test)] mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + + use datafusion_common::{Column, DFSchema, JoinType, ScalarValue}; + use datafusion_expr::logical_plan::table_scan; + use datafusion_expr::{ + binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, Operator, + }; + use crate::eliminate_filter::EliminateFilter; use crate::eliminate_nested_union::EliminateNestedUnion; use crate::test::{ assert_optimized_plan_eq, assert_optimized_plan_eq_with_rules, test_table_scan, test_table_scan_fields, test_table_scan_with_name, }; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{Column, DFSchema, ScalarValue}; - use datafusion_expr::logical_plan::table_scan; - use datafusion_expr::{ - binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, Operator, - }; use super::*; diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 2430d2d52eb3..adf62efa0b67 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -20,6 +20,7 @@ use std::collections::HashMap; use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; + use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; @@ -290,6 +291,32 @@ fn eliminate_nested_filters() { assert_eq!(expected, format!("{plan:?}")); } +#[test] +fn test_propagate_empty_relation_inner_join_and_unions() { + let sql = "\ + SELECT A.col_int32 FROM test AS A \ + INNER JOIN ( \ + SELECT col_int32 FROM test WHERE 1 = 0 \ + ) AS B ON A.col_int32 = B.col_int32 \ + UNION ALL \ + SELECT test.col_int32 FROM test WHERE 1 = 1 \ + UNION ALL \ + SELECT test.col_int32 FROM test WHERE 0 = 0 \ + UNION ALL \ + SELECT test.col_int32 FROM test WHERE test.col_int32 < 0 \ + UNION ALL \ + SELECT test.col_int32 FROM test WHERE 1 = 0"; + + let plan = test_sql(sql).unwrap(); + let expected = "\ + Union\ + \n TableScan: test projection=[col_int32]\ + \n TableScan: test projection=[col_int32]\ + \n Filter: test.col_int32 < Int32(0)\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{plan:?}")); +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...