Skip to content

Commit

Permalink
pushdown support for JOIN ON predicates
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed May 29, 2022
1 parent df2094f commit 7af4483
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 141 deletions.
239 changes: 200 additions & 39 deletions datafusion/core/src/optimizer/filter_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use datafusion_expr::{
Expr, TableProviderFilterPushDown,
};
use std::collections::{HashMap, HashSet};
use std::iter::once;

/// Filter Push Down optimizer rule pushes filter clauses down the plan
/// # Introduction
Expand Down Expand Up @@ -65,6 +66,16 @@ struct State {
filters: Vec<(Expr, HashSet<Column>)>,
}

impl State {
    /// Appends an owned copy of every predicate in `predicates` to this state's
    /// filter list.
    ///
    /// The expression vector (`.0`) and the column-set vector (`.1`) are paired
    /// by position; pairing stops at the end of the shorter vector.
    fn append_predicates(&mut self, predicates: Predicates) {
        let (exprs, column_sets) = predicates;
        self.filters.extend(
            exprs
                .into_iter()
                .zip(column_sets)
                .map(|(expr, columns)| (expr.clone(), columns.clone())),
        );
    }
}

/// Borrowed predicates stored as two parallel vectors: `.0` holds the predicate
/// expressions and `.1` holds the set of columns paired with the expression at
/// the same index.
type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<Column>>);

/// returns all predicates in `state` that depend on any of `used_columns`
Expand Down Expand Up @@ -109,18 +120,6 @@ fn remove_filters(
.collect::<Vec<_>>()
}

// Returns owned copies of the entries of `filters` whose expression appears among
// the expressions of `relevant_predicates`, in their original order.
fn keep_filters(
    filters: &[(Expr, HashSet<Column>)],
    relevant_predicates: &Predicates,
) -> Vec<(Expr, HashSet<Column>)> {
    let wanted_exprs = &relevant_predicates.0;
    let mut kept = Vec::new();
    for entry in filters {
        if wanted_exprs.contains(&&entry.0) {
            kept.push(entry.clone());
        }
    }
    kept
}

/// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters
/// in `state` depend on the columns `used_columns`.
fn issue_filters(
Expand Down Expand Up @@ -178,13 +177,35 @@ fn lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) {
}
}

// For a given JOIN logical plan, determine whether each side of the join is preserved
// in terms of join filtering. The result is `(left_preserved, right_preserved)`.
// Predicates from the join filter can only be pushed to a preserved join side.
//
// Panics (via `unreachable!`) when called on a SEMI/ANTI join, on a CROSS JOIN,
// or on any plan node that is not a join.
fn on_lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) {
    match plan {
        LogicalPlan::Join(Join { join_type, .. }) => match join_type {
            JoinType::Inner => (true, true),
            JoinType::Left => (false, true),
            JoinType::Right => (true, false),
            JoinType::Full => (false, false),
            // Semi/Anti joins cannot have a join filter.
            JoinType::Semi | JoinType::Anti => unreachable!(
                "on_lr_is_preserved cannot be applied to SEMI/ANTI-JOIN nodes"
            ),
        },
        LogicalPlan::CrossJoin(_) => {
            unreachable!("on_lr_is_preserved cannot be applied to CROSSJOIN nodes")
        }
        _ => unreachable!("on_lr_is_preserved only valid for JOIN nodes"),
    }
}

// Determine which predicates in `filters` can be pushed down to a given side of a join.
// To determine this, we need to know the schema of the relevant join side and whether
// or not the side's rows are preserved when joining. If the side is not preserved, we
// do not push down anything. Otherwise we can push down predicates where all of the
// relevant columns are contained on the relevant join side's schema.
fn get_pushable_join_predicates<'a>(
state: &'a State,
filters: &'a [(Expr, HashSet<Column>)],
schema: &DFSchema,
preserved: bool,
) -> Predicates<'a> {
Expand All @@ -204,8 +225,7 @@ fn get_pushable_join_predicates<'a>(
})
.collect::<HashSet<_>>();

state
.filters
filters
.iter()
.filter(|(_, columns)| {
let all_columns_in_schema = schema_columns
Expand All @@ -224,32 +244,67 @@ fn optimize_join(
plan: &LogicalPlan,
left: &LogicalPlan,
right: &LogicalPlan,
on_filter: Vec<(Expr, HashSet<Column>)>,
) -> Result<LogicalPlan> {
// Get pushable predicates from current optimizer state
let (left_preserved, right_preserved) = lr_is_preserved(plan);
let to_left = get_pushable_join_predicates(&state, left.schema(), left_preserved);
let to_right = get_pushable_join_predicates(&state, right.schema(), right_preserved);

let to_left =
get_pushable_join_predicates(&state.filters, left.schema(), left_preserved);
let to_right =
get_pushable_join_predicates(&state.filters, right.schema(), right_preserved);
let to_keep: Predicates = state
.filters
.iter()
.filter(|(expr, _)| {
let pushed_to_left = to_left.0.contains(&expr);
let pushed_to_right = to_right.0.contains(&expr);
!pushed_to_left && !pushed_to_right
})
.filter(|(e, _)| !to_left.0.contains(&e) && !to_right.0.contains(&e))
.map(|(a, b)| (a, b))
.unzip();

let mut left_state = state.clone();
left_state.filters = keep_filters(&left_state.filters, &to_left);
// Get pushable predicates from join filter
let (on_to_left, on_to_right, on_to_keep) = if on_filter.is_empty() {
((vec![], vec![]), (vec![], vec![]), vec![])
} else {
let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(plan);
let on_to_left =
get_pushable_join_predicates(&on_filter, left.schema(), on_left_preserved);
let on_to_right =
get_pushable_join_predicates(&on_filter, right.schema(), on_right_preserved);
let on_to_keep = on_filter
.iter()
.filter(|(e, _)| !on_to_left.0.contains(&e) && !on_to_right.0.contains(&e))
.map(|(a, _)| a.clone())
.collect::<Vec<_>>();

(on_to_left, on_to_right, on_to_keep)
};

// Find pushable predicates in current state and
// append pushable predicates from JOIN ON.
// Then recursively call optimization for both join inputs
let mut left_state = State { filters: vec![] };
left_state.append_predicates(to_left);
left_state.append_predicates(on_to_left);
let left = optimize(left, left_state)?;

let mut right_state = state.clone();
right_state.filters = keep_filters(&right_state.filters, &to_right);
let mut right_state = State { filters: vec![] };
right_state.append_predicates(to_right);
right_state.append_predicates(on_to_right);
let right = optimize(right, right_state)?;

// create a new Join with the new `left` and `right`
let expr = plan.expressions();
let expr = if !on_filter.is_empty() && on_to_keep.is_empty() {
// New filter expression is None - should remove last element
expr[..expr.len() - 1].to_vec()
} else if !on_to_keep.is_empty() {
// Replace last element with new filter expression
expr[..expr.len() - 1]
.iter()
.cloned()
.chain(once(on_to_keep.into_iter().reduce(Expr::and).unwrap()))
.collect()
} else {
plan.expressions()
};
let plan = from_plan(plan, &expr, &[left, right])?;

if to_keep.0.is_empty() {
Expand Down Expand Up @@ -399,15 +454,34 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
issue_filters(state, used_columns, plan)
}
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
optimize_join(state, plan, left, right)
optimize_join(state, plan, left, right, vec![])
}
LogicalPlan::Join(Join {
left,
right,
on,
filter,
join_type,
..
}) => {
// Convert JOIN ON predicate to Predicates
let on_filters = filter
.as_ref()
.map(|e| {
let mut predicates = vec![];
utils::split_conjunction(e, &mut predicates);

predicates
.into_iter()
.map(|e| {
let mut accum = HashSet::new();
expr_to_columns(e, &mut accum)?;
Ok((e.clone(), accum))
})
.collect::<Result<Vec<_>>>()
})
.unwrap_or_else(|| Ok(vec![]))?;

if *join_type == JoinType::Inner {
// For inner joins, duplicate filters for joined columns so filters can be pushed down
// to both sides. Take the following query as an example:
Expand All @@ -421,9 +495,11 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
//
// Join clauses with `Using` constraints also take advantage of this logic to make sure
// predicates reference the shared join columns are pushed to both sides.
// This logic should also be applied to conditions in the JOIN ON clause
let join_side_filters = state
.filters
.iter()
.chain(on_filters.iter())
.filter_map(|(predicate, columns)| {
let mut join_cols_to_replace = HashMap::new();
for col in columns.iter() {
Expand Down Expand Up @@ -464,7 +540,8 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
.collect::<Result<Vec<_>>>()?;
state.filters.extend(join_side_filters);
}
optimize_join(state, plan, left, right)

optimize_join(state, plan, left, right, on_filters)
}
LogicalPlan::TableScan(TableScan {
source,
Expand Down Expand Up @@ -1340,7 +1417,6 @@ mod tests {
}

/// single table predicate parts of ON condition should be pushed to both inputs
#[ignore]
#[test]
fn join_on_with_filter() -> Result<()> {
let table_scan = test_table_scan()?;
Expand All @@ -1351,7 +1427,7 @@ mod tests {
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let filter = col("test.a")
let filter = col("test.c")
.gt(lit(1u32))
.and(col("test.b").lt(col("test2.b")))
.and(col("test2.c").gt(lit(4u32)));
Expand All @@ -1368,7 +1444,7 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
"\
Inner Join: #test.a = #test2.a Filter: #test.a > UInt32(1) AND #test.b < #test2.b AND #test2.c > UInt32(4)\
Inner Join: #test.a = #test2.a Filter: #test.c > UInt32(1) AND #test.b < #test2.b AND #test2.c > UInt32(4)\
\n Projection: #test.a, #test.b, #test.c\
\n TableScan: test projection=None\
\n Projection: #test2.a, #test2.b, #test2.c\
Expand All @@ -1378,7 +1454,7 @@ mod tests {
let expected = "\
Inner Join: #test.a = #test2.a Filter: #test.b < #test2.b\
\n Projection: #test.a, #test.b, #test.c\
\n Filter: #test.a > UInt32(1)\
\n Filter: #test.c > UInt32(1)\
\n TableScan: test projection=None\
\n Projection: #test2.a, #test2.b, #test2.c\
\n Filter: #test2.c > UInt32(4)\
Expand All @@ -1387,9 +1463,97 @@ mod tests {
Ok(())
}

/// The join filter should be completely removed after pushdown: when every
/// conjunct of the ON filter of an inner join references only one input, each
/// conjunct moves below the join and the join node keeps no `Filter:` part.
#[test]
fn join_filter_removed() -> Result<()> {
// Both inputs are projections of `a`, `b`, `c` over a table scan.
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
// Two single-table conjuncts: one references only `test`, the other only `test2`.
let filter = col("test.b")
.gt(lit(1u32))
.and(col("test2.c").gt(lit(4u32)));
// Inner join on `a = a` carrying the filter above as its ON filter.
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
Some(filter),
)?
.build()?;

// not part of the test, just good to know:
// (this is the plan before optimization, with the whole filter still on the join)
assert_eq!(
format!("{:?}", plan),
"\
Inner Join: #test.a = #test2.a Filter: #test.b > UInt32(1) AND #test2.c > UInt32(4)\
\n Projection: #test.a, #test.b, #test.c\
\n TableScan: test projection=None\
\n Projection: #test2.a, #test2.b, #test2.c\
\n TableScan: test2 projection=None"
);

// After optimization each conjunct appears as a `Filter` under its own side's
// projection, and the join line no longer has a `Filter:` part.
let expected = "\
Inner Join: #test.a = #test2.a\
\n Projection: #test.a, #test.b, #test.c\
\n Filter: #test.b > UInt32(1)\
\n TableScan: test projection=None\
\n Projection: #test2.a, #test2.b, #test2.c\
\n Filter: #test2.c > UInt32(4)\
\n TableScan: test2 projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}

/// A predicate on the join key in the filter expression should be pushed down to
/// both inputs: the ON filter mentions only `test.a`, but because `a` is the
/// equijoin key of an inner join, the expected plan also filters `test2.a`.
#[test]
fn join_filter_on_common() -> Result<()> {
// Both inputs project only the join key column `a`.
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
// The ON filter references the left side's join key only.
let filter = col("test.a").gt(lit(1u32));
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
Some(filter),
)?
.build()?;

// not part of the test, just good to know:
// (this is the plan before optimization, with the filter still on the join)
assert_eq!(
format!("{:?}", plan),
"\
Inner Join: #test.a = #test2.a Filter: #test.a > UInt32(1)\
\n Projection: #test.a\
\n TableScan: test projection=None\
\n Projection: #test2.a\
\n TableScan: test2 projection=None"
);

// After optimization the join has no `Filter:` part; the left input filters on
// `#test.a` and the right input gets the same predicate rewritten to `#test2.a`.
let expected = "\
Inner Join: #test.a = #test2.a\
\n Projection: #test.a\
\n Filter: #test.a > UInt32(1)\
\n TableScan: test projection=None\
\n Projection: #test2.a\
\n Filter: #test2.a > UInt32(1)\
\n TableScan: test2 projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}

/// single table predicate parts of ON condition should be pushed to right input
/// https://github.com/apache/arrow-datafusion/issues/2619
#[ignore]
#[test]
fn left_join_on_with_filter() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down Expand Up @@ -1436,8 +1600,6 @@ mod tests {
}

/// single table predicate parts of ON condition should be pushed to left input
/// https://github.com/apache/arrow-datafusion/issues/2619
#[ignore]
#[test]
fn right_join_on_with_filter() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down Expand Up @@ -1478,13 +1640,12 @@ mod tests {
\n Filter: #test.a > UInt32(1)\
\n TableScan: test projection=None\
\n Projection: #test2.a, #test2.b, #test2.c\
\n TableScan: test2 projection=None";
\n TableScan: test2 projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}

/// single table predicate parts of ON condition should not be pushed
/// https://github.com/apache/arrow-datafusion/issues/2619
#[test]
fn full_join_on_with_filter() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down
Loading

0 comments on commit 7af4483

Please sign in to comment.