diff --git a/datafusion/core/src/optimizer/filter_push_down.rs b/datafusion/core/src/optimizer/filter_push_down.rs index abfb15dd1ba2c..ba4f9f07f96b3 100644 --- a/datafusion/core/src/optimizer/filter_push_down.rs +++ b/datafusion/core/src/optimizer/filter_push_down.rs @@ -30,6 +30,7 @@ use datafusion_expr::{ Expr, TableProviderFilterPushDown, }; use std::collections::{HashMap, HashSet}; +use std::iter::once; /// Filter Push Down optimizer rule pushes filter clauses down the plan /// # Introduction @@ -65,6 +66,16 @@ struct State { filters: Vec<(Expr, HashSet)>, } +impl State { + fn append_predicates(&mut self, predicates: Predicates) { + predicates + .0 + .into_iter() + .zip(predicates.1) + .for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone()))) + } +} + type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet>); /// returns all predicates in `state` that depend on any of `used_columns` @@ -109,18 +120,6 @@ fn remove_filters( .collect::>() } -// keeps all filters from `filters` that are in `predicate_columns` -fn keep_filters( - filters: &[(Expr, HashSet)], - relevant_predicates: &Predicates, -) -> Vec<(Expr, HashSet)> { - filters - .iter() - .filter(|(expr, _)| relevant_predicates.0.contains(&expr)) - .cloned() - .collect::>() -} - /// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters /// in `state` depend on the columns `used_columns`. fn issue_filters( @@ -178,13 +177,35 @@ fn lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) { } } +// For a given JOIN logical plan, determine whether each side of the join is preserved +// in terms on join filtering. +// Predicates from join filter can only be pushed to preserved join side. +fn on_lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) { + match plan { + LogicalPlan::Join(Join { join_type, .. }) => match join_type { + JoinType::Inner => (true, true), + JoinType::Left => (false, true), + JoinType::Right => (true, false), + JoinType::Full => (false, false), + // Semi/Anti joins can not have join filter. + JoinType::Semi | JoinType::Anti => unreachable!( + "on_lr_is_preserved cannot be appplied to SEMI/ANTI-JOIN nodes" + ), + }, + LogicalPlan::CrossJoin(_) => { + unreachable!("on_lr_is_preserved cannot be applied to CROSSJOIN nodes") + } + _ => unreachable!("on_lr_is_preserved only valid for JOIN nodes"), + } +} + // Determine which predicates in state can be pushed down to a given side of a join. // To determine this, we need to know the schema of the relevant join side and whether // or not the side's rows are preserved when joining. If the side is not preserved, we // do not push down anything. Otherwise we can push down predicates where all of the // relevant columns are contained on the relevant join side's schema. fn get_pushable_join_predicates<'a>( - state: &'a State, + filters: &'a [(Expr, HashSet)], schema: &DFSchema, preserved: bool, ) -> Predicates<'a> { @@ -204,8 +225,7 @@ fn get_pushable_join_predicates<'a>( }) .collect::>(); - state - .filters + filters .iter() .filter(|(_, columns)| { let all_columns_in_schema = schema_columns @@ -224,32 +244,67 @@ fn optimize_join( plan: &LogicalPlan, left: &LogicalPlan, right: &LogicalPlan, + on_filter: Vec<(Expr, HashSet)>, ) -> Result { + // Get pushable predicates from current optimizer state let (left_preserved, right_preserved) = lr_is_preserved(plan); - let to_left = get_pushable_join_predicates(&state, left.schema(), left_preserved); - let to_right = get_pushable_join_predicates(&state, right.schema(), right_preserved); - + let to_left = + get_pushable_join_predicates(&state.filters, left.schema(), left_preserved); + let to_right = + get_pushable_join_predicates(&state.filters, right.schema(), right_preserved); let to_keep: Predicates = state .filters .iter() - .filter(|(expr, _)| { - let pushed_to_left = to_left.0.contains(&expr); - let pushed_to_right = to_right.0.contains(&expr); - !pushed_to_left && !pushed_to_right - }) + .filter(|(e, _)| !to_left.0.contains(&e) && !to_right.0.contains(&e)) .map(|(a, b)| (a, b)) .unzip(); - let mut left_state = state.clone(); - left_state.filters = keep_filters(&left_state.filters, &to_left); + // Get pushable predicates from join filter + let (on_to_left, on_to_right, on_to_keep) = if on_filter.is_empty() { + ((vec![], vec![]), (vec![], vec![]), vec![]) + } else { + let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(plan); + let on_to_left = + get_pushable_join_predicates(&on_filter, left.schema(), on_left_preserved); + let on_to_right = + get_pushable_join_predicates(&on_filter, right.schema(), on_right_preserved); + let on_to_keep = on_filter + .iter() + .filter(|(e, _)| !on_to_left.0.contains(&e) && !on_to_right.0.contains(&e)) + .map(|(a, _)| a.clone()) + .collect::>(); + + (on_to_left, on_to_right, on_to_keep) + }; + + // Find pushable predicates in current state and + // append pushable predicates from JOIN ON. + // Then recursively call optimization for both join inputs + let mut left_state = State { filters: vec![] }; + left_state.append_predicates(to_left); + left_state.append_predicates(on_to_left); let left = optimize(left, left_state)?; - let mut right_state = state.clone(); - right_state.filters = keep_filters(&right_state.filters, &to_right); + let mut right_state = State { filters: vec![] }; + right_state.append_predicates(to_right); + right_state.append_predicates(on_to_right); let right = optimize(right, right_state)?; // create a new Join with the new `left` and `right` let expr = plan.expressions(); + let expr = if !on_filter.is_empty() && on_to_keep.is_empty() { + // New filter expression is None - should remove last element + expr[..expr.len() - 1].to_vec() + } else if !on_to_keep.is_empty() { + // Replace last element with new filter expression + expr[..expr.len() - 1] + .iter() + .cloned() + .chain(once(on_to_keep.into_iter().reduce(Expr::and).unwrap())) + .collect() + } else { + plan.expressions() + }; let plan = from_plan(plan, &expr, &[left, right])?; if to_keep.0.is_empty() { @@ -399,15 +454,34 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { issue_filters(state, used_columns, plan) } LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - optimize_join(state, plan, left, right) + optimize_join(state, plan, left, right, vec![]) } LogicalPlan::Join(Join { left, right, on, + filter, join_type, .. }) => { + // Convert JOIN ON predicate to Predicates + let on_filters = filter + .as_ref() + .map(|e| { + let mut predicates = vec![]; + utils::split_conjunction(e, &mut predicates); + + predicates + .into_iter() + .map(|e| { + let mut accum = HashSet::new(); + expr_to_columns(e, &mut accum)?; + Ok((e.clone(), accum)) + }) + .collect::>>() + }) + .unwrap_or_else(|| Ok(vec![]))?; + if *join_type == JoinType::Inner { // For inner joins, duplicate filters for joined columns so filters can be pushed down // to both sides. Take the following query as an example: @@ -421,9 +495,11 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { // // Join clauses with `Using` constraints also take advantage of this logic to make sure // predicates reference the shared join columns are pushed to both sides. + // This logic should also been applied to conditions in JOIN ON clause let join_side_filters = state .filters .iter() + .chain(on_filters.iter()) .filter_map(|(predicate, columns)| { let mut join_cols_to_replace = HashMap::new(); for col in columns.iter() { @@ -464,7 +540,8 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { .collect::>>()?; state.filters.extend(join_side_filters); } - optimize_join(state, plan, left, right) + + optimize_join(state, plan, left, right, on_filters) } LogicalPlan::TableScan(TableScan { source, @@ -1340,7 +1417,6 @@ mod tests { } /// single table predicate parts of ON condition should be pushed to both inputs - #[ignore] #[test] fn join_on_with_filter() -> Result<()> { let table_scan = test_table_scan()?; @@ -1351,7 +1427,7 @@ mod tests { let right = LogicalPlanBuilder::from(right_table_scan) .project(vec![col("a"), col("b"), col("c")])? .build()?; - let filter = col("test.a") + let filter = col("test.c") .gt(lit(1u32)) .and(col("test.b").lt(col("test2.b"))) .and(col("test2.c").gt(lit(4u32))); @@ -1368,7 +1444,7 @@ mod tests { assert_eq!( format!("{:?}", plan), "\ - Inner Join: #test.a = #test2.a Filter: #test.a > UInt32(1) AND #test.b < #test2.b AND #test2.c > UInt32(4)\ + Inner Join: #test.a = #test2.a Filter: #test.c > UInt32(1) AND #test.b < #test2.b AND #test2.c > UInt32(4)\ \n Projection: #test.a, #test.b, #test.c\ \n TableScan: test projection=None\ \n Projection: #test2.a, #test2.b, #test2.c\ @@ -1378,7 +1454,7 @@ mod tests { let expected = "\ Inner Join: #test.a = #test2.a Filter: #test.b < #test2.b\ \n Projection: #test.a, #test.b, #test.c\ - \n Filter: #test.a > UInt32(1)\ + \n Filter: #test.c > UInt32(1)\ \n TableScan: test projection=None\ \n Projection: #test2.a, #test2.b, #test2.c\ \n Filter: #test2.c > UInt32(4)\ @@ -1387,9 +1463,97 @@ mod tests { Ok(()) } + /// join filter should be completely removed after pushdown + #[test] + fn join_filter_removed() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b"), col("c")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b"), col("c")])? + .build()?; + let filter = col("test.b") + .gt(lit(1u32)) + .and(col("test2.c").gt(lit(4u32))); + let plan = LogicalPlanBuilder::from(left) + .join( + &right, + JoinType::Inner, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + Some(filter), + )? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{:?}", plan), + "\ + Inner Join: #test.a = #test2.a Filter: #test.b > UInt32(1) AND #test2.c > UInt32(4)\ + \n Projection: #test.a, #test.b, #test.c\ + \n TableScan: test projection=None\ + \n Projection: #test2.a, #test2.b, #test2.c\ + \n TableScan: test2 projection=None" + ); + + let expected = "\ + Inner Join: #test.a = #test2.a\ + \n Projection: #test.a, #test.b, #test.c\ + \n Filter: #test.b > UInt32(1)\ + \n TableScan: test projection=None\ + \n Projection: #test2.a, #test2.b, #test2.c\ + \n Filter: #test2.c > UInt32(4)\ + \n TableScan: test2 projection=None"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + /// predicate on join key in filter expression should be pushed down to both inputs + #[test] + fn join_filter_on_common() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a")])? + .build()?; + let filter = col("test.a").gt(lit(1u32)); + let plan = LogicalPlanBuilder::from(left) + .join( + &right, + JoinType::Inner, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + Some(filter), + )? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{:?}", plan), + "\ + Inner Join: #test.a = #test2.a Filter: #test.a > UInt32(1)\ + \n Projection: #test.a\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n TableScan: test2 projection=None" + ); + + let expected = "\ + Inner Join: #test.a = #test2.a\ + \n Projection: #test.a\ + \n Filter: #test.a > UInt32(1)\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n Filter: #test2.a > UInt32(1)\ + \n TableScan: test2 projection=None"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + /// single table predicate parts of ON condition should be pushed to right input - /// https://github.com/apache/arrow-datafusion/issues/2619 - #[ignore] #[test] fn left_join_on_with_filter() -> Result<()> { let table_scan = test_table_scan()?; @@ -1436,8 +1600,6 @@ mod tests { } /// single table predicate parts of ON condition should be pushed to left input - /// https://github.com/apache/arrow-datafusion/issues/2619 - #[ignore] #[test] fn right_join_on_with_filter() -> Result<()> { let table_scan = test_table_scan()?; @@ -1478,13 +1640,12 @@ mod tests { \n Filter: #test.a > UInt32(1)\ \n TableScan: test projection=None\ \n Projection: #test2.a, #test2.b, #test2.c\ - \n TableScan: test2 projection=None"; + \n TableScan: test2 projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) } /// single table predicate parts of ON condition should not be pushed - /// https://github.com/apache/arrow-datafusion/issues/2619 #[test] fn full_join_on_with_filter() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index ea8075bf2be0a..bb3434a57da26 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -227,9 +227,15 @@ impl LogicalPlan { aggr_expr, .. }) => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(), - LogicalPlan::Join(Join { on, .. }) => on + LogicalPlan::Join(Join { on, filter, .. }) => on .iter() .flat_map(|(l, r)| vec![Expr::Column(l.clone()), Expr::Column(r.clone())]) + .chain( + filter + .as_ref() + .map(|expr| vec![expr.clone()]) + .unwrap_or_default(), + ) .collect(), LogicalPlan::Sort(Sort { expr, .. }) => expr.clone(), LogicalPlan::Extension(extension) => extension.node.expressions(), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 483b12b49a5e8..9085796cfd143 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -378,19 +378,24 @@ pub fn from_plan( join_type, join_constraint, on, - filter, null_equals_null, .. }) => { let schema = build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?; + let filter_expr = if on.len() * 2 == expr.len() { + None + } else { + Some(expr[expr.len() - 1].clone()) + }; + Ok(LogicalPlan::Join(Join { left: Arc::new(inputs[0].clone()), right: Arc::new(inputs[1].clone()), join_type: *join_type, join_constraint: *join_constraint, on: on.clone(), - filter: filter.clone(), + filter: filter_expr, schema: DFSchemaRef::new(schema), null_equals_null: *null_equals_null, })) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 3feb870baa1c2..84ac0fc1b261d 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -577,7 +577,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let (left_keys, right_keys): (Vec, Vec) = keys.into_iter().unzip(); - let normalized_filters = filter + let join_filter = filter .into_iter() .map(|expr| { let mut using_columns = HashSet::new(); @@ -589,98 +589,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[using_columns], ) }) - .collect::>>()?; + .collect::>>()? + .into_iter() + .reduce(Expr::and); if left_keys.is_empty() { // When we don't have join keys, use cross join let join = LogicalPlanBuilder::from(left).cross_join(&right)?; - normalized_filters - .into_iter() - .reduce(Expr::and) + join_filter .map(|filter| join.filter(filter)) .unwrap_or(Ok(join))? .build() - } else if join_type == JoinType::Inner && !normalized_filters.is_empty() { - let join = LogicalPlanBuilder::from(left).join( - &right, - join_type, - (left_keys, right_keys), - None, - )?; - join.filter( - normalized_filters.into_iter().reduce(Expr::and).unwrap(), - )? - .build() - } else if join_type == JoinType::Left { - // Inner filters - predicates based only on right input columns - // Outer filters - predicates using left input columns - // - // Inner filters are safe to push to right input and exclude from ON - let (inner_filters, outer_filters): (Vec<_>, Vec<_>) = - normalized_filters.into_iter().partition(|e| { - find_column_exprs(&[e.clone()]) - .iter() - .filter_map(|e| match e { - Expr::Column(column) => Some(column), - _ => None, - }) - .all(|c| right.schema().index_of_column(c).is_ok()) - }); - - let right_input = if inner_filters.is_empty() { - right - } else { - LogicalPlanBuilder::from(right) - .filter(inner_filters.into_iter().reduce(Expr::and).unwrap())? - .build()? - }; - - let join = LogicalPlanBuilder::from(left).join( - &right_input, - join_type, - (left_keys, right_keys), - outer_filters.into_iter().reduce(Expr::and), - )?; - join.build() - } else if join_type == JoinType::Right && !normalized_filters.is_empty() { - // Inner filters - predicates based only on left input columns - // Outer filters - predicates using right input columns - // - // Inner filters are safe to push to left input and exclude from ON - let (inner_filters, outer_filters): (Vec<_>, Vec<_>) = - normalized_filters.into_iter().partition(|e| { - find_column_exprs(&[e.clone()]) - .iter() - .filter_map(|e| match e { - Expr::Column(column) => Some(column), - _ => None, - }) - .all(|c| left.schema().index_of_column(c).is_ok()) - }); - - let left_input = if inner_filters.is_empty() { - left - } else { - LogicalPlanBuilder::from(left) - .filter(inner_filters.into_iter().reduce(Expr::and).unwrap())? - .build()? - }; - - let join = LogicalPlanBuilder::from(left_input).join( - &right, - join_type, - (left_keys, right_keys), - outer_filters.into_iter().reduce(Expr::and), - )?; - join.build() } else { - let join = LogicalPlanBuilder::from(left).join( - &right, - join_type, - (left_keys, right_keys), - normalized_filters.into_iter().reduce(Expr::and), - )?; - join.build() + LogicalPlanBuilder::from(left) + .join(&right, join_type, (left_keys, right_keys), join_filter)? + .build() } } JoinConstraint::Using(idents) => { @@ -3823,10 +3746,9 @@ mod tests { JOIN orders \ ON id = customer_id AND order_id > 1 "; let expected = "Projection: #person.id, #orders.order_id\ - \n Filter: #orders.order_id > Int64(1)\ - \n Inner Join: #person.id = #orders.customer_id\ - \n TableScan: person projection=None\ - \n TableScan: orders projection=None"; + \n Inner Join: #person.id = #orders.customer_id Filter: #orders.order_id > Int64(1)\ + \n TableScan: person projection=None\ + \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -3837,10 +3759,9 @@ mod tests { LEFT JOIN orders \ ON id = customer_id AND order_id > 1 AND age < 30"; let expected = "Projection: #person.id, #orders.order_id\ - \n Left Join: #person.id = #orders.customer_id Filter: #person.age < Int64(30)\ + \n Left Join: #person.id = #orders.customer_id Filter: #orders.order_id > Int64(1) AND #person.age < Int64(30)\ \n TableScan: person projection=None\ - \n Filter: #orders.order_id > Int64(1)\ - \n TableScan: orders projection=None"; + \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -3851,9 +3772,8 @@ mod tests { RIGHT JOIN orders \ ON id = customer_id AND id > 1 AND order_id < 100"; let expected = "Projection: #person.id, #orders.order_id\ - \n Right Join: #person.id = #orders.customer_id Filter: #orders.order_id < Int64(100)\ - \n Filter: #person.id > Int64(1)\ - \n TableScan: person projection=None\ + \n Right Join: #person.id = #orders.customer_id Filter: #person.id > Int64(1) AND #orders.order_id < Int64(100)\ + \n TableScan: person projection=None\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -4842,10 +4762,9 @@ mod tests { FROM person \ JOIN orders ON id = customer_id AND (person.age > 30 OR person.last_name = 'X')"; let expected = "Projection: #person.id, #orders.order_id\ - \n Filter: #person.age > Int64(30) OR #person.last_name = Utf8(\"X\")\ - \n Inner Join: #person.id = #orders.customer_id\ - \n TableScan: person projection=None\ - \n TableScan: orders projection=None"; + \n Inner Join: #person.id = #orders.customer_id Filter: #person.age > Int64(30) OR #person.last_name = Utf8(\"X\")\ + \n TableScan: person projection=None\ + \n TableScan: orders projection=None"; quick_test(sql, expected); }