Skip to content

Commit

Permalink
Only add outer filter once when transforming exists/in subquery to jo…
Browse files Browse the repository at this point in the history
…in (#4944)

* Avoding add outer filter multiple time when transforming exists/in subquery to join

* fix comment of optimize_exists

* fix comment of optimize_exists

* fix comment
  • Loading branch information
ygf11 authored Jan 20, 2023
1 parent 22d106a commit e566bfc
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 50 deletions.
86 changes: 64 additions & 22 deletions datafusion/optimizer/src/decorrelate_where_exists.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,19 @@ impl OptimizerRule for DecorrelateWhereExists {
// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = filter.input.as_ref().clone();
for subquery in subqueries {
if let Some(x) = optimize_exists(&subquery, &cur_input, &other_exprs)?
{
if let Some(x) = optimize_exists(&subquery, &cur_input)? {
cur_input = x;
} else {
return Ok(None);
}
}

let expr = conjunction(other_exprs);
if let Some(expr) = expr {
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
cur_input = LogicalPlan::Filter(new_filter);
}

Ok(Some(cur_input))
}
_ => Ok(None),
Expand All @@ -116,23 +122,27 @@ impl OptimizerRule for DecorrelateWhereExists {

/// Takes a query like:
///
/// ```select c.id from customers c where exists (select * from orders o where o.c_id = c.id)```
/// SELECT t1.id
/// FROM t1
/// WHERE exists
/// (
/// SELECT t2.id FROM t2 WHERE t1.id = t2.id
/// )
///
/// and optimizes it into:
///
/// ```select c.id from customers c
/// inner join (select o.c_id from orders o group by o.c_id) o on o.c_id = c.c_id```
/// SELECT t1.id
/// FROM t1 LEFT SEMI
/// JOIN t2
/// ON t1.id = t2.id
///
/// # Arguments
///
/// * subqry - The subquery portion of the `where exists` (select * from orders)
/// * negated - True if the subquery is a `where not exists`
/// * filter_input - The non-subquery portion (from customers)
/// * outer_exprs - Any additional parts to the `where` expression (and c.x = y)
/// * query_info - The subquery and negated(exists/not exists) info.
/// * outer_input - The non-subquery portion (relation t1)
fn optimize_exists(
query_info: &SubqueryInfo,
outer_input: &LogicalPlan,
outer_other_exprs: &[Expr],
) -> Result<Option<LogicalPlan>> {
let subqry_filter = match query_info.query.subquery.as_ref() {
LogicalPlan::Distinct(subqry_distinct) => match subqry_distinct.input.as_ref() {
Expand Down Expand Up @@ -180,18 +190,10 @@ fn optimize_exists(
true => JoinType::LeftAnti,
false => JoinType::LeftSemi,
};
let mut new_plan = LogicalPlanBuilder::from(outer_input.clone()).join(
subqry_plan,
join_type,
join_keys,
join_filters,
)?;
if let Some(expr) = conjunction(outer_other_exprs.to_vec()) {
new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them
}

let result = new_plan.build()?;
Ok(Some(result))
let new_plan = LogicalPlanBuilder::from(outer_input.clone())
.join(subqry_plan, join_type, join_keys, join_filters)?
.build()?;
Ok(Some(new_plan))
}

struct SubqueryInfo {
Expand Down Expand Up @@ -555,4 +557,44 @@ mod tests {
assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected);
Ok(())
}

#[test]
fn two_exists_subquery_with_outer_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let subquery_scan1 = test_table_scan_with_name("sq1")?;
let subquery_scan2 = test_table_scan_with_name("sq2")?;

let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
.filter(col("test.a").eq(col("sq1.a")))?
.project(vec![col("c")])?
.build()?;

let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
.filter(col("test.a").eq(col("sq2.a")))?
.project(vec![col("c")])?
.build()?;

let plan = LogicalPlanBuilder::from(table_scan)
.filter(
exists(Arc::new(subquery1))
.and(exists(Arc::new(subquery2)).and(col("test.c").gt(lit(1u32)))),
)?
.project(vec![col("test.b")])?
.build()?;

let expected = "Projection: test.b [b:UInt32]\
\n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
\n LeftSemi Join: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\
\n LeftSemi Join: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelateWhereExists::new()),
&plan,
expected,
);
Ok(())
}
}
50 changes: 22 additions & 28 deletions datafusion/optimizer/src/decorrelate_where_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use datafusion_common::{context, Column, Result};
use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col};
use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
use datafusion_expr::utils::check_all_column_from_schema;
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
use datafusion_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder};
use log::debug;
use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;
Expand Down Expand Up @@ -96,17 +96,18 @@ impl OptimizerRule for DecorrelateWhereIn {
return Ok(None);
}

// iterate through all exists clauses in predicate, turning each into a join
// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = filter.input.as_ref().clone();
for subquery in subqueries {
cur_input = optimize_where_in(
&subquery,
&cur_input,
&other_exprs,
&self.alias,
)?;
cur_input = optimize_where_in(&subquery, &cur_input, &self.alias)?;
}

let expr = conjunction(other_exprs);
if let Some(expr) = expr {
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
cur_input = LogicalPlan::Filter(new_filter);
}

Ok(Some(cur_input))
}
_ => Ok(None),
Expand Down Expand Up @@ -141,7 +142,6 @@ impl OptimizerRule for DecorrelateWhereIn {
fn optimize_where_in(
query_info: &SubqueryInfo,
left: &LogicalPlan,
outer_other_exprs: &[Expr],
alias: &AliasGenerator,
) -> Result<LogicalPlan> {
let projection = Projection::try_from_plan(&query_info.query.subquery)
Expand Down Expand Up @@ -207,17 +207,14 @@ fn optimize_where_in(
.map(|filter| in_predicate.clone().and(filter))
.unwrap_or_else(|| in_predicate);

let mut new_plan = LogicalPlanBuilder::from(left.clone()).join(
right,
join_type,
(Vec::<Column>::new(), Vec::<Column>::new()),
Some(join_filter),
)?;

if let Some(expr) = conjunction(outer_other_exprs.to_vec()) {
new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them
}
let new_plan = new_plan.build()?;
let new_plan = LogicalPlanBuilder::from(left.clone())
.join(
right,
join_type,
(Vec::<Column>::new(), Vec::<Column>::new()),
Some(join_filter),
)?
.build()?;

debug!("where in optimized:\n{}", new_plan.display_indent());
Ok(new_plan)
Expand Down Expand Up @@ -1162,17 +1159,14 @@ mod tests {
.project(vec![col("test.b")])?
.build()?;

// Filter: test.c > UInt32(1) happen twice.
// issue: https://github.com/apache/arrow-datafusion/issues/4914
let expected = "Projection: test.b [b:UInt32]\
\n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
\n LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\
\n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
\n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\
\n Projection: sq1.c * UInt32(2) AS c * UInt32(2), sq1.a [c * UInt32(2):UInt32, a:UInt32]\
\n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
\n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\
\n Projection: sq1.c * UInt32(2) AS c * UInt32(2), sq1.a [c * UInt32(2):UInt32, a:UInt32]\
\n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_2 [c * UInt32(2):UInt32, a:UInt32]\
\n Projection: sq2.c * UInt32(2) AS c * UInt32(2), sq2.a [c * UInt32(2):UInt32, a:UInt32]\
\n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";
Expand Down

0 comments on commit e566bfc

Please sign in to comment.