diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index 36e446d17f05..bdc4afe901f4 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -92,13 +92,19 @@ impl OptimizerRule for DecorrelateWhereExists { // iterate through all exists clauses in predicate, turning each into a join let mut cur_input = filter.input.as_ref().clone(); for subquery in subqueries { - if let Some(x) = optimize_exists(&subquery, &cur_input, &other_exprs)? - { + if let Some(x) = optimize_exists(&subquery, &cur_input)? { cur_input = x; } else { return Ok(None); } } + + let expr = conjunction(other_exprs); + if let Some(expr) = expr { + let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; + cur_input = LogicalPlan::Filter(new_filter); + } + Ok(Some(cur_input)) } _ => Ok(None), @@ -116,23 +122,27 @@ impl OptimizerRule for DecorrelateWhereExists { /// Takes a query like: /// -/// ```select c.id from customers c where exists (select * from orders o where o.c_id = c.id)``` +/// SELECT t1.id +/// FROM t1 +/// WHERE exists +/// ( +/// SELECT t2.id FROM t2 WHERE t1.id = t2.id +/// ) /// /// and optimizes it into: /// -/// ```select c.id from customers c -/// inner join (select o.c_id from orders o group by o.c_id) o on o.c_id = c.c_id``` +/// SELECT t1.id +/// FROM t1 LEFT SEMI +/// JOIN t2 +/// ON t1.id = t2.id /// /// # Arguments /// -/// * subqry - The subquery portion of the `where exists` (select * from orders) -/// * negated - True if the subquery is a `where not exists` -/// * filter_input - The non-subquery portion (from customers) -/// * outer_exprs - Any additional parts to the `where` expression (and c.x = y) +/// * query_info - The subquery and negated(exists/not exists) info. +/// * outer_input - The non-subquery portion (relation t1) fn optimize_exists( query_info: &SubqueryInfo, outer_input: &LogicalPlan, - outer_other_exprs: &[Expr], ) -> Result> { let subqry_filter = match query_info.query.subquery.as_ref() { LogicalPlan::Distinct(subqry_distinct) => match subqry_distinct.input.as_ref() { @@ -180,18 +190,10 @@ fn optimize_exists( true => JoinType::LeftAnti, false => JoinType::LeftSemi, }; - let mut new_plan = LogicalPlanBuilder::from(outer_input.clone()).join( - subqry_plan, - join_type, - join_keys, - join_filters, - )?; - if let Some(expr) = conjunction(outer_other_exprs.to_vec()) { - new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them - } - - let result = new_plan.build()?; - Ok(Some(result)) + let new_plan = LogicalPlanBuilder::from(outer_input.clone()) + .join(subqry_plan, join_type, join_keys, join_filters)? + .build()?; + Ok(Some(new_plan)) } struct SubqueryInfo { @@ -555,4 +557,44 @@ mod tests { assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); Ok(()) } + + #[test] + fn two_exists_subquery_with_outer_filter() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan1 = test_table_scan_with_name("sq1")?; + let subquery_scan2 = test_table_scan_with_name("sq2")?; + + let subquery1 = LogicalPlanBuilder::from(subquery_scan1) + .filter(col("test.a").eq(col("sq1.a")))? + .project(vec![col("c")])? + .build()?; + + let subquery2 = LogicalPlanBuilder::from(subquery_scan2) + .filter(col("test.a").eq(col("sq2.a")))? + .project(vec![col("c")])? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + exists(Arc::new(subquery1)) + .and(exists(Arc::new(subquery2)).and(col("test.c").gt(lit(1u32)))), + )? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereExists::new()), + &plan, + expected, + ); + Ok(()) + } } diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 13e3acf78876..b3222d279e47 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -23,7 +23,7 @@ use datafusion_common::{context, Column, Result}; use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col}; use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; use datafusion_expr::utils::check_all_column_from_schema; -use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder}; use log::debug; use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; @@ -96,17 +96,18 @@ impl OptimizerRule for DecorrelateWhereIn { return Ok(None); } - // iterate through all exists clauses in predicate, turning each into a join // iterate through all exists clauses in predicate, turning each into a join let mut cur_input = filter.input.as_ref().clone(); for subquery in subqueries { - cur_input = optimize_where_in( - &subquery, - &cur_input, - &other_exprs, - &self.alias, - )?; + cur_input = optimize_where_in(&subquery, &cur_input, &self.alias)?; + } + + let expr = conjunction(other_exprs); + if let Some(expr) = expr { + let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; + cur_input = LogicalPlan::Filter(new_filter); } + Ok(Some(cur_input)) } _ => Ok(None), @@ -141,7 +142,6 @@ impl OptimizerRule for DecorrelateWhereIn { fn optimize_where_in( query_info: &SubqueryInfo, left: &LogicalPlan, - outer_other_exprs: &[Expr], alias: &AliasGenerator, ) -> Result { let projection = Projection::try_from_plan(&query_info.query.subquery) @@ -207,17 +207,14 @@ fn optimize_where_in( .map(|filter| in_predicate.clone().and(filter)) .unwrap_or_else(|| in_predicate); - let mut new_plan = LogicalPlanBuilder::from(left.clone()).join( - right, - join_type, - (Vec::::new(), Vec::::new()), - Some(join_filter), - )?; - - if let Some(expr) = conjunction(outer_other_exprs.to_vec()) { - new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them - } - let new_plan = new_plan.build()?; + let new_plan = LogicalPlanBuilder::from(left.clone()) + .join( + right, + join_type, + (Vec::::new(), Vec::::new()), + Some(join_filter), + )? + .build()?; debug!("where in optimized:\n{}", new_plan.display_indent()); Ok(new_plan) @@ -1162,17 +1159,14 @@ mod tests { .project(vec![col("test.b")])? .build()?; - // Filter: test.c > UInt32(1) happen twice. - // issue: https://github.com/apache/arrow-datafusion/issues/4914 let expected = "Projection: test.b [b:UInt32]\ \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ \n LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq1.c * UInt32(2) AS c * UInt32(2), sq1.a [c * UInt32(2):UInt32, a:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\ + \n Projection: sq1.c * UInt32(2) AS c * UInt32(2), sq1.a [c * UInt32(2):UInt32, a:UInt32]\ + \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_2 [c * UInt32(2):UInt32, a:UInt32]\ \n Projection: sq2.c * UInt32(2) AS c * UInt32(2), sq2.a [c * UInt32(2):UInt32, a:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";