diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index a026fdbf0226..007d7366e8b9 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -771,6 +771,18 @@ pub fn can_hash(data_type: &DataType) -> bool { } } +/// Check whether all columns are from the schema. +pub fn check_all_column_from_schema( + columns: &HashSet, + schema: DFSchemaRef, +) -> Result { + let result = columns + .iter() + .all(|column| schema.index_of_column(column).is_ok()); + + Ok(result) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 227025773d2d..5d4c47faf7e3 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -54,8 +54,9 @@ use datafusion_expr::logical_plan::{ }; use datafusion_expr::logical_plan::{Filter, Subquery}; use datafusion_expr::utils::{ - can_hash, expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, - find_aggregate_exprs, find_column_exprs, find_window_exprs, COUNT_STAR_EXPANSION, + can_hash, check_all_column_from_schema, expand_qualified_wildcard, expand_wildcard, + expr_as_column_expr, find_aggregate_exprs, find_column_exprs, find_window_exprs, + COUNT_STAR_EXPANSION, }; use datafusion_expr::Expr::Alias; use datafusion_expr::{ @@ -711,41 +712,36 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match constraint { JoinConstraint::On(sql_expr) => { - let mut keys: Vec<(Column, Column)> = vec![]; + let mut keys: Vec<(Expr, Expr)> = vec![]; let join_schema = left.schema().join(right.schema())?; // parse ON expression let expr = self.sql_to_rex(sql_expr, &join_schema, ctes)?; + // normalize all columns in expression + let using_columns = expr.to_columns()?; + let normalized_expr = normalize_col_with_schemas( + expr, + &[left.schema(), right.schema()], + &[using_columns], + )?; + // expression that didn't match equi-join pattern let mut filter = vec![]; // extract join keys extract_join_keys( - expr, + normalized_expr, &mut keys, &mut filter, left.schema(), right.schema(), - ); + )?; - let (left_keys, right_keys): (Vec, Vec) = + let (left_keys, right_keys): (Vec, Vec) = keys.into_iter().unzip(); - let join_filter = filter - .into_iter() - .map(|expr| { - let using_columns = expr.to_columns()?; - - normalize_col_with_schemas( - expr, - &[left.schema(), right.schema()], - &[using_columns], - ) - }) - .collect::>>()? - .into_iter() - .reduce(Expr::and); + let join_filter = filter.into_iter().reduce(Expr::and); if left_keys.is_empty() { // When we don't have join keys, use cross join @@ -755,9 +751,46 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .unwrap_or(Ok(join))? .build() } else { - LogicalPlanBuilder::from(left) - .join(&right, join_type, (left_keys, right_keys), join_filter)? - .build() + // Wrap projection for left input if left join keys contain normal expression. + let (left_child, left_projected) = + wrap_projection_for_join_if_necessary(&left_keys, left)?; + let left_join_keys = left_keys + .iter() + .map(|key| { + key.try_into_col() + .or_else(|_| Ok(Column::from_name(key.display_name()?))) + }) + .collect::>>()?; + + // Wrap projection for right input if right join keys contains normal expression. + let (right_child, right_projected) = + wrap_projection_for_join_if_necessary(&right_keys, right)?; + let right_join_keys = right_keys + .iter() + .map(|key| { + key.try_into_col() + .or_else(|_| Ok(Column::from_name(key.display_name()?))) + }) + .collect::>>()?; + + let join_plan_builder = LogicalPlanBuilder::from(left_child).join( + &right_child, + join_type, + (left_join_keys, right_join_keys), + join_filter, + )?; + + // Remove temporary projected columns if necessary. + if left_projected || right_projected { + let final_join_result = join_schema + .fields() + .iter() + .map(|field| Expr::Column(field.qualified_column())) + .collect::>(); + join_plan_builder.project(final_join_result)?.build() + } else { + join_plan_builder.build() + } } } JoinConstraint::Using(idents) => { @@ -2983,36 +3016,75 @@ fn remove_join_expressions( /// foo = bar => accum=[(foo, bar)] accum_filter=[] /// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[] /// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1] +/// +/// For equijoin join key, assume we have tables -- a(c0, c1 c2) and b(c0, c1, c2): +/// (a.c0 = 10) => accum=[], accum_filter=[a.c0 = 10] +/// (a.c0 + 1 = b.c0 * 2) => accum=[(a.c0 + 1, b.c0 * 2)], accum_filter=[] +/// (a.c0 + b.c0 = 10) => accum=[], accum_filter=[a.c0 + b.c0 = 10] /// ``` fn extract_join_keys( expr: Expr, - accum: &mut Vec<(Column, Column)>, + accum: &mut Vec<(Expr, Expr)>, accum_filter: &mut Vec, left_schema: &Arc, right_schema: &Arc, -) { +) -> Result<()> { match &expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => { - if left_schema.field_from_column(l).is_ok() - && right_schema.field_from_column(r).is_ok() - && can_hash(left_schema.field_from_column(l).unwrap().data_type()) - { - accum.push((l.clone(), r.clone())); - } else if left_schema.field_from_column(r).is_ok() - && right_schema.field_from_column(l).is_ok() - && can_hash(left_schema.field_from_column(r).unwrap().data_type()) - { - accum.push((r.clone(), l.clone())); + Operator::Eq => { + let left = *left.clone(); + let right = *right.clone(); + let left_using_columns = left.to_columns()?; + let right_using_columns = right.to_columns()?; + + // When one side key does not contain columns, we need move this expression to filter. + // For example: a = 1, a = now() + 10. + if left_using_columns.is_empty() || right_using_columns.is_empty() { + accum_filter.push(expr); + return Ok(()); + } + + // Checking left join key is from left schema, right join key is from right schema, or the opposite. + let l_is_left = check_all_column_from_schema( + &left_using_columns, + left_schema.clone(), + )?; + let r_is_right = check_all_column_from_schema( + &right_using_columns, + right_schema.clone(), + )?; + + let r_is_left_and_l_is_right = || { + let result = check_all_column_from_schema( + &right_using_columns, + left_schema.clone(), + )? && check_all_column_from_schema( + &left_using_columns, + right_schema.clone(), + )?; + + Result::Ok(result) + }; + + let join_key_pair = match (l_is_left, r_is_right) { + (true, true) => Some((left, right)), + (_, _) if r_is_left_and_l_is_right()? => Some((right, left)), + _ => None, + }; + + if let Some((left_expr, right_expr)) = join_key_pair { + let left_expr_type = left_expr.get_type(left_schema)?; + let right_expr_type = right_expr.get_type(right_schema)?; + + if can_hash(&left_expr_type) && can_hash(&right_expr_type) { + accum.push((left_expr, right_expr)); } else { accum_filter.push(expr); } - } - _other => { + } else { accum_filter.push(expr); } - }, + } Operator::And => { if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = expr { extract_join_keys( @@ -3021,14 +3093,14 @@ fn extract_join_keys( accum_filter, left_schema, right_schema, - ); + )?; extract_join_keys( *right, accum, accum_filter, left_schema, right_schema, - ); + )?; } } _other => { @@ -3039,6 +3111,8 @@ fn extract_join_keys( accum_filter.push(expr); } } + + Ok(()) } /// Extract join keys from a WHERE clause @@ -3065,6 +3139,32 @@ fn extract_possible_join_keys( } } +/// Wrap projection for a plan, if the join keys contains normal expression. +fn wrap_projection_for_join_if_necessary( + join_keys: &[Expr], + input: LogicalPlan, +) -> Result<(LogicalPlan, bool)> { + let expr_join_keys = join_keys + .iter() + .flat_map(|expr| expr.try_into_col().is_err().then_some(expr)) + .cloned() + .collect::>(); + + let need_project = !expr_join_keys.is_empty(); + let plan = if need_project { + let mut projection = vec![Expr::Wildcard]; + projection.extend(expr_join_keys.into_iter()); + + LogicalPlanBuilder::from(input) + .project(projection)? + .build()? + } else { + input + }; + + Ok((plan, need_project)) +} + #[cfg(test)] mod tests { use std::any::Any; @@ -5693,6 +5793,205 @@ mod tests { assert!(logical_plan("SELECT \"1\"").is_err()); } + #[test] + fn test_constant_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON person.id = 10"; + + let expected = "Projection: person.id, orders.order_id\ + \n Filter: person.id = Int64(10)\ + \n CrossJoin:\ + \n TableScan: person\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_right_left_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON orders.customer_id * 2 = person.id + 10"; + + let expected = "Projection: person.id, orders.order_id\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_single_column_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON person.id + 10 = orders.customer_id * 2"; + + let expected = "Projection: person.id, orders.order_id\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_multiple_column_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON person.id + person.age + 10 = orders.customer_id * 2 - orders.price"; + + let expected = "Projection: person.id, orders.order_id\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + person.age + Int64(10) = orders.customer_id * Int64(2) - orders.price\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + person.age + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2) - orders.price\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_left_projection_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON person.id + person.age + 10 = orders.customer_id"; + + let expected = "Projection: person.id, orders.order_id\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + person.age + Int64(10) = orders.customer_id\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + person.age + Int64(10)\ + \n TableScan: person\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_right_projection_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON person.id = orders.customer_id * 2 - orders.price"; + + let expected = "Projection: person.id, orders.order_id\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id = orders.customer_id * Int64(2) - orders.price\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2) - orders.price\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_one_side_constant_full_join() { + // TODO: this sql should be parsed as join after + // https://github.com/apache/arrow-datafusion/issues/2877 is resolved. + let sql = "SELECT id, order_id \ + FROM person \ + FULL OUTER JOIN orders \ + ON person.id = 10"; + + let expected = "Projection: person.id, orders.order_id\ + \n Filter: person.id = Int64(10)\ + \n CrossJoin:\ + \n TableScan: person\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_select_all_inner_join() { + let sql = "SELECT * + FROM person \ + INNER JOIN orders \ + ON orders.customer_id * 2 = person.id + 10"; + + let expected = "Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_select_join_key_inner_join() { + let sql = "SELECT orders.customer_id * 2, person.id + 10 + FROM person + INNER JOIN orders + ON orders.customer_id * 2 = person.id + 10"; + + let expected = "Projection: orders.customer_id * Int64(2), person.id + Int64(10)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_non_projetion_after_inner_join() { + // There's no need to add projection for left and right, so does adding projection after join. + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders + ON orders.customer_id = person.id"; + + let expected = "Projection: person.id, person.age\ + \n Inner Join: person.id = orders.customer_id\ + \n TableScan: person\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_duplicated_left_join_key_inner_join() { + // person.id * 2 happen twice in left side. + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders + ON person.id * 2 = orders.customer_id + 10 and person.id * 2 = orders.order_id"; + + let expected = "Projection: person.id, person.age\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10), person.id * Int64(2) = orders.order_id\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id * Int64(2)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id + Int64(10)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_duplicated_right_join_key_inner_join() { + // orders.customer_id + 10 happen twice in right side. + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders + ON person.id * 2 = orders.customer_id + 10 and person.id = orders.customer_id + 10"; + + let expected = "Projection: person.id, person.age\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\n Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10), person.id = orders.customer_id + Int64(10)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id * Int64(2)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id + Int64(10)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. } => {