From 4352bb2b78d7ef8922b7fdea429ce11710ed82a9 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 11 Nov 2022 07:48:56 -0500 Subject: [PATCH 01/10] Support normal expressions in equality join --- datafusion/expr/src/utils.rs | 12 ++ datafusion/sql/src/planner.rs | 199 +++++++++++++++++++++++++++------- 2 files changed, 172 insertions(+), 39 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 2be9a6465df6..8750560fe179 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -774,6 +774,18 @@ pub fn can_hash(data_type: &DataType) -> bool { } } +/// Check whether all columns are from the schema. +pub fn check_all_column_from_schema( + columns: &HashSet, + schema: DFSchemaRef, +) -> Result { + let result = columns + .iter() + .all(|column| schema.index_of_column(column).is_ok()); + + Ok(result) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index dacd4af87248..1a6f398222fd 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -41,6 +41,7 @@ use datafusion_expr::{ window_function::WindowFunction, BuiltinScalarFunction, TableSource, }; use std::collections::{HashMap, HashSet}; +use std::ops::Not; use std::str::FromStr; use std::sync::Arc; use std::{convert::TryInto, vec}; @@ -53,6 +54,7 @@ use datafusion_common::{ use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like}; use datafusion_expr::logical_plan::builder::project_with_alias; use datafusion_expr::logical_plan::{Filter, Subquery}; +use datafusion_expr::utils::check_all_column_from_schema; use datafusion_expr::Expr::Alias; use sqlparser::ast::TimezoneInfo; @@ -665,41 +667,36 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match constraint { JoinConstraint::On(sql_expr) => { - let mut keys: Vec<(Column, Column)> = vec![]; + let mut keys: Vec<(Expr, Expr)> = vec![]; let join_schema = left.schema().join(right.schema())?; // parse ON expression let expr = self.sql_to_rex(sql_expr, &join_schema, ctes)?; + // normalize all columns in expression + let using_columns = expr.to_columns()?; + let normalized_expr = normalize_col_with_schemas( + expr, + &[left.schema(), right.schema()], + &[using_columns], + )?; + // expression that didn't match equi-join pattern let mut filter = vec![]; // extract join keys extract_join_keys( - expr, + normalized_expr, &mut keys, &mut filter, left.schema(), right.schema(), - ); + )?; - let (left_keys, right_keys): (Vec, Vec) = + let (left_keys, right_keys): (Vec, Vec) = keys.into_iter().unzip(); - let join_filter = filter - .into_iter() - .map(|expr| { - let using_columns = expr.to_columns()?; - - normalize_col_with_schemas( - expr, - &[left.schema(), right.schema()], - &[using_columns], - ) - }) - .collect::>>()? - .into_iter() - .reduce(Expr::and); + let join_filter = filter.into_iter().reduce(Expr::and); if left_keys.is_empty() { // When we don't have join keys, use cross join @@ -709,8 +706,39 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .unwrap_or(Ok(join))? .build() } else { - LogicalPlanBuilder::from(left) - .join(&right, join_type, (left_keys, right_keys), join_filter)? + let left_child = + wrap_projection_for_join_if_necessary(&left_keys, left)?; + let left_join_keys = left_keys + .iter() + .map(|key| { + if let Ok(col) = key.try_into_col() { + Ok(col) + } else { + Ok(Column::from_name(key.display_name()?)) + } + }) + .collect::>>()?; + + let right_child = + wrap_projection_for_join_if_necessary(&right_keys, right)?; + let right_join_keys = right_keys + .iter() + .map(|key| { + if let Ok(col) = key.try_into_col() { + Ok(col) + } else { + Ok(Column::from_name(key.display_name()?)) + } + }) + .collect::>>()?; + + LogicalPlanBuilder::from(left_child) + .join( + &right_child, + join_type, + (left_join_keys, right_join_keys), + join_filter, + )? .build() } } @@ -2840,33 +2868,68 @@ fn remove_join_expressions( /// ``` fn extract_join_keys( expr: Expr, - accum: &mut Vec<(Column, Column)>, + accum: &mut Vec<(Expr, Expr)>, accum_filter: &mut Vec, left_schema: &Arc, right_schema: &Arc, -) { +) -> Result<()> { match &expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => { - if left_schema.field_from_column(l).is_ok() - && right_schema.field_from_column(r).is_ok() - && can_hash(left_schema.field_from_column(l).unwrap().data_type()) - { - accum.push((l.clone(), r.clone())); - } else if left_schema.field_from_column(r).is_ok() - && right_schema.field_from_column(l).is_ok() - && can_hash(left_schema.field_from_column(r).unwrap().data_type()) - { - accum.push((r.clone(), l.clone())); + Operator::Eq => { + let left = *left.clone(); + let right = *right.clone(); + + let left_using_columns = left.to_columns()?; + let right_using_columns = right.to_columns()?; + + // When there is one side expression does not contain columns, we need move this expression to filter. + // For example: a = 1, a = now() + 10. + if left_using_columns.is_empty() || right_using_columns.is_empty() { + accum_filter.push(expr); + return Ok(()); + } + + let l_is_left = check_all_column_from_schema( + &left_using_columns, + left_schema.clone(), + )?; + let r_is_right = check_all_column_from_schema( + &right_using_columns, + right_schema.clone(), + )?; + + let r_is_left_and_l_is_right = || { + let result = check_all_column_from_schema( + &right_using_columns, + left_schema.clone(), + )? && check_all_column_from_schema( + &left_using_columns, + right_schema.clone(), + )?; + + Result::Ok(result) + }; + + let join_key_pair = match (l_is_left, r_is_right) { + (true, true) => Some((left, right)), + (_, _) if r_is_left_and_l_is_right()? => Some((right, left)), + _ => None, + }; + + if let Some((left_expr, right_expr)) = join_key_pair { + let left_expr_type = left_expr.get_type(left_schema)?; + let right_expr_type = right_expr.get_type(right_schema)?; + + // TODO: Maybe this check can be done later in optimizer + if can_hash(&left_expr_type) && can_hash(&right_expr_type) { + accum.push((left_expr, right_expr)); } else { accum_filter.push(expr); } - } - _other => { + } else { accum_filter.push(expr); } - }, + } Operator::And => { if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = expr { extract_join_keys( @@ -2875,14 +2938,14 @@ fn extract_join_keys( accum_filter, left_schema, right_schema, - ); + )?; extract_join_keys( *right, accum, accum_filter, left_schema, right_schema, - ); + )?; } } _other => { @@ -2893,6 +2956,8 @@ fn extract_join_keys( accum_filter.push(expr); } } + + Ok(()) } /// Extract join keys from a WHERE clause @@ -2934,6 +2999,30 @@ fn parse_sql_number(n: &str) -> Result { }) } +fn wrap_projection_for_join_if_necessary( + join_keys: &[Expr], + input: LogicalPlan, +) -> Result { + let expr_join_keys = join_keys + .iter() + .flat_map(|expr| expr.try_into_col().is_err().then_some(expr)) + .cloned() + .collect::>(); + + let handled_input = if expr_join_keys.is_empty().not() { + let mut projection = expand_wildcard(input.schema(), &input)?; + projection.extend_from_slice(&expr_join_keys); + + LogicalPlanBuilder::from(input) + .project(projection)? + .build()? + } else { + input + }; + + Ok(handled_input) +} + #[cfg(test)] mod tests { use super::*; @@ -5509,6 +5598,38 @@ mod tests { assert!(logical_plan("SELECT \"1\"").is_err()); } + #[test] + fn test_single_column_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON person.id + 10 = orders.customer_id * 2"; + + let expected = "Projection: person.id, orders.order_id\ + \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_multiple_column_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON person.id + person.age + 10 = orders.customer_id * 2 - orders.price"; + + let expected = "Projection: person.id, orders.order_id\ + \n Inner Join: person.id + person.age + Int64(10) = orders.customer_id * Int64(2) - orders.price\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + person.age + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2) - orders.price\ + \n TableScan: orders"; + quick_test(sql, expected); + } + fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. } => { From 60629ac37963e44f7e052d2494a67d35c0bfcba8 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Mon, 14 Nov 2022 03:19:50 -0500 Subject: [PATCH 02/10] Add some tests --- datafusion/sql/src/planner.rs | 94 +++++++++++++++++++++++++++++------ 1 file changed, 78 insertions(+), 16 deletions(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 1a6f398222fd..2d21b894c850 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -706,29 +706,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .unwrap_or(Ok(join))? .build() } else { + // Wrap projection fro left input if left join keys has normal expression. let left_child = wrap_projection_for_join_if_necessary(&left_keys, left)?; let left_join_keys = left_keys .iter() .map(|key| { - if let Ok(col) = key.try_into_col() { - Ok(col) - } else { - Ok(Column::from_name(key.display_name()?)) - } + key.try_into_col() + .or_else(|_| Ok(Column::from_name(key.display_name()?))) }) .collect::>>()?; + // Wrap projection fro left input if left join keys has normal expression. let right_child = wrap_projection_for_join_if_necessary(&right_keys, right)?; let right_join_keys = right_keys .iter() .map(|key| { - if let Ok(col) = key.try_into_col() { - Ok(col) - } else { - Ok(Column::from_name(key.display_name()?)) - } + key.try_into_col() + .or_else(|_| Ok(Column::from_name(key.display_name()?))) }) .collect::>>()?; @@ -2865,6 +2861,11 @@ fn remove_join_expressions( /// foo = bar => accum=[(foo, bar)] accum_filter=[] /// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[] /// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1] +/// +/// For normal expression join key, assume we have a(c0, c1 c2) and b(c0, c1, c2): +/// (a.c0 = 10) => accum=[], accum_filter=[a.c0 = 10] +/// (a.c0 + 1 = b.c0 * 2) => accum=[(a.c0 + 1, b.c0 * 2)], accum_filter=[] +/// (a.c0 + b.c0 = 10) => accum=[], accum_filter=[a.c0 + b.c0 = 10] /// ``` fn extract_join_keys( expr: Expr, @@ -2878,17 +2879,17 @@ fn extract_join_keys( Operator::Eq => { let left = *left.clone(); let right = *right.clone(); - let left_using_columns = left.to_columns()?; let right_using_columns = right.to_columns()?; - // When there is one side expression does not contain columns, we need move this expression to filter. + // When one side key does not contain columns, we need move this expression to filter. // For example: a = 1, a = now() + 10. if left_using_columns.is_empty() || right_using_columns.is_empty() { accum_filter.push(expr); return Ok(()); } + // Checking left join key is from left schema, right join key is from right schema, or the opposite. let l_is_left = check_all_column_from_schema( &left_using_columns, left_schema.clone(), @@ -2920,7 +2921,6 @@ fn extract_join_keys( let left_expr_type = left_expr.get_type(left_schema)?; let right_expr_type = right_expr.get_type(right_schema)?; - // TODO: Maybe this check can be done later in optimizer if can_hash(&left_expr_type) && can_hash(&right_expr_type) { accum.push((left_expr, right_expr)); } else { @@ -2999,6 +2999,7 @@ fn parse_sql_number(n: &str) -> Result { }) } +/// Wrap projection for a plan, if the join keys contains normal expression. fn wrap_projection_for_join_if_necessary( join_keys: &[Expr], input: LogicalPlan, @@ -3009,8 +3010,8 @@ fn wrap_projection_for_join_if_necessary( .cloned() .collect::>(); - let handled_input = if expr_join_keys.is_empty().not() { - let mut projection = expand_wildcard(input.schema(), &input)?; + let plan = if expr_join_keys.is_empty().not() { + let mut projection = vec![Expr::Wildcard]; projection.extend_from_slice(&expr_join_keys); LogicalPlanBuilder::from(input) @@ -3020,7 +3021,7 @@ fn wrap_projection_for_join_if_necessary( input }; - Ok(handled_input) + Ok(plan) } #[cfg(test)] @@ -5598,6 +5599,37 @@ mod tests { assert!(logical_plan("SELECT \"1\"").is_err()); } + #[test] + fn test_constant_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON person.id = 10"; + + let expected = "Projection: person.id, orders.order_id\ + \n Filter: person.id = Int64(10)\ + \n CrossJoin:\ + \n TableScan: person\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_right_left_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON orders.customer_id * 2 = person.id + 10"; + + let expected = "Projection: person.id, orders.order_id\ + \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + #[test] fn test_single_column_expr_eq_join() { let sql = "SELECT id, order_id \ @@ -5630,6 +5662,36 @@ mod tests { quick_test(sql, expected); } + #[test] + fn test_left_projection_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON person.id + person.age + 10 = orders.customer_id"; + + let expected = "Projection: person.id, orders.order_id\ + \n Inner Join: person.id + person.age + Int64(10) = orders.customer_id\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + person.age + Int64(10)\ + \n TableScan: person\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_right_projection_expr_eq_join() { + let sql = "SELECT id, order_id \ + FROM person \ + INNER JOIN orders \ + ON person.id = orders.customer_id * 2 - orders.price"; + + let expected = "Projection: person.id, orders.order_id\ + \n Inner Join: person.id = orders.customer_id * Int64(2) - orders.price\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2) - orders.price\ + \n TableScan: orders"; + quick_test(sql, expected); + } + fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. } => { From 01562492d21fb05c8f6ebd0e1383bafe8e48cc20 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Mon, 14 Nov 2022 03:22:49 -0500 Subject: [PATCH 03/10] Improve comment --- datafusion/sql/src/planner.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 2d21b894c850..eec6b5a2d16f 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -706,7 +706,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .unwrap_or(Ok(join))? .build() } else { - // Wrap projection fro left input if left join keys has normal expression. + // Wrap projection for left input if left join keys contain normal expression. let left_child = wrap_projection_for_join_if_necessary(&left_keys, left)?; let left_join_keys = left_keys @@ -717,7 +717,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect::>>()?; - // Wrap projection fro left input if left join keys has normal expression. + // Wrap projection for right input if right join keys contains normal expression. let right_child = wrap_projection_for_join_if_necessary(&right_keys, right)?; let right_join_keys = right_keys @@ -2862,7 +2862,7 @@ fn remove_join_expressions( /// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[] /// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1] /// -/// For normal expression join key, assume we have a(c0, c1 c2) and b(c0, c1, c2): +/// For normal expression join key, assume we have tables -- a(c0, c1 c2) and b(c0, c1, c2): /// (a.c0 = 10) => accum=[], accum_filter=[a.c0 = 10] /// (a.c0 + 1 = b.c0 * 2) => accum=[(a.c0 + 1, b.c0 * 2)], accum_filter=[] /// (a.c0 + b.c0 = 10) => accum=[], accum_filter=[a.c0 + b.c0 = 10] From 003ff98c6cf6550114d40e35e03ecdc788c03efb Mon Sep 17 00:00:00 2001 From: ygf11 Date: Wed, 16 Nov 2022 16:09:18 +0800 Subject: [PATCH 04/10] Update datafusion/sql/src/planner.rs Co-authored-by: Andrew Lamb --- datafusion/sql/src/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index eec6b5a2d16f..3735d2b35c4c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -3010,7 +3010,7 @@ fn wrap_projection_for_join_if_necessary( .cloned() .collect::>(); - let plan = if expr_join_keys.is_empty().not() { + let plan = if !expr_join_keys.is_empty() { let mut projection = vec![Expr::Wildcard]; projection.extend_from_slice(&expr_join_keys); From 0f7547ea7a6c57784984d5281d58a0550d58add6 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Wed, 16 Nov 2022 04:25:02 -0500 Subject: [PATCH 05/10] Add another test case --- datafusion/sql/src/planner.rs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 3735d2b35c4c..54cc89486e7b 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -41,7 +41,6 @@ use datafusion_expr::{ window_function::WindowFunction, BuiltinScalarFunction, TableSource, }; use std::collections::{HashMap, HashSet}; -use std::ops::Not; use std::str::FromStr; use std::sync::Arc; use std::{convert::TryInto, vec}; @@ -2862,7 +2861,7 @@ fn remove_join_expressions( /// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[] /// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1] /// -/// For normal expression join key, assume we have tables -- a(c0, c1 c2) and b(c0, c1, c2): +/// For equijoin join key, assume we have tables -- a(c0, c1 c2) and b(c0, c1, c2): /// (a.c0 = 10) => accum=[], accum_filter=[a.c0 = 10] /// (a.c0 + 1 = b.c0 * 2) => accum=[(a.c0 + 1, b.c0 * 2)], accum_filter=[] /// (a.c0 + b.c0 = 10) => accum=[], accum_filter=[a.c0 + b.c0 = 10] @@ -5692,6 +5691,21 @@ mod tests { quick_test(sql, expected); } + #[test] + fn test_one_side_constant_full_join() { + let sql = "SELECT id, order_id \ + FROM person \ + FULL OUTER JOIN orders \ + ON person.id = 10"; + + let expected = "Projection: person.id, orders.order_id\ + \n Filter: person.id = Int64(10)\ + \n CrossJoin:\ + \n TableScan: person\ + \n TableScan: orders"; + quick_test(sql, expected); + } + fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. } => { From 3a04cbdf2f4e5c1ade9ac375cd55186b74026564 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Wed, 16 Nov 2022 04:41:36 -0500 Subject: [PATCH 06/10] Add comment --- datafusion/sql/src/planner.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 54cc89486e7b..da0695e18c46 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -5693,6 +5693,8 @@ mod tests { #[test] fn test_one_side_constant_full_join() { + // TODO: this sql should transfer to join after + // https://github.com/apache/arrow-datafusion/issues/2877 is resolved. let sql = "SELECT id, order_id \ FROM person \ FULL OUTER JOIN orders \ From f605adfce41e314ae97a11701dc2fc7fe1ad2b56 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Wed, 16 Nov 2022 05:55:52 -0500 Subject: [PATCH 07/10] fix typo --- datafusion/sql/src/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 67a641043a56..0b2822a15899 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -5788,7 +5788,7 @@ mod tests { #[test] fn test_one_side_constant_full_join() { - // TODO: this sql should transfer to join after + // TODO: this sql should be parsed as join after // https://github.com/apache/arrow-datafusion/issues/2877 is resolved. let sql = "SELECT id, order_id \ FROM person \ From 1a9a672ec6f0465c1835732b16be30f4b214f9ff Mon Sep 17 00:00:00 2001 From: ygf11 Date: Thu, 17 Nov 2022 01:59:45 -0500 Subject: [PATCH 08/10] remove dumplicated keys in projection and add test cases --- datafusion/sql/src/planner.rs | 176 ++++++++++++++++++++++++++-------- 1 file changed, 138 insertions(+), 38 deletions(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 0b2822a15899..fbf067c7b9e2 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -721,7 +721,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .build() } else { // Wrap projection for left input if left join keys contain normal expression. - let left_child = + let (left_child, left_projected) = wrap_projection_for_join_if_necessary(&left_keys, left)?; let left_join_keys = left_keys .iter() @@ -732,7 +732,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>>()?; // Wrap projection for right input if right join keys contains normal expression. - let right_child = + let (right_child, right_projected) = wrap_projection_for_join_if_necessary(&right_keys, right)?; let right_join_keys = right_keys .iter() @@ -742,14 +742,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect::>>()?; - LogicalPlanBuilder::from(left_child) - .join( - &right_child, - join_type, - (left_join_keys, right_join_keys), - join_filter, - )? - .build() + let join_plan_builder = LogicalPlanBuilder::from(left_child).join( + &right_child, + join_type, + (left_join_keys, right_join_keys), + join_filter, + )?; + + // Remove temporary projection columns. + if left_projected || right_projected { + let final_join_result = join_schema + .fields() + .iter() + .map(|field| Expr::Column(field.qualified_column())) + .collect::>(); + join_plan_builder.project(final_join_result)?.build() + } else { + join_plan_builder.build() + } } } JoinConstraint::Using(idents) => { @@ -3047,16 +3057,17 @@ fn extract_possible_join_keys( fn wrap_projection_for_join_if_necessary( join_keys: &[Expr], input: LogicalPlan, -) -> Result { +) -> Result<(LogicalPlan, bool)> { let expr_join_keys = join_keys .iter() .flat_map(|expr| expr.try_into_col().is_err().then_some(expr)) .cloned() - .collect::>(); + .collect::>(); - let plan = if !expr_join_keys.is_empty() { + let need_project = !expr_join_keys.is_empty(); + let plan = if need_project { let mut projection = vec![Expr::Wildcard]; - projection.extend_from_slice(&expr_join_keys); + projection.extend(expr_join_keys.into_iter()); LogicalPlanBuilder::from(input) .project(projection)? @@ -3065,7 +3076,7 @@ fn wrap_projection_for_join_if_necessary( input }; - Ok(plan) + Ok((plan, need_project)) } #[cfg(test)] @@ -5716,11 +5727,12 @@ mod tests { ON orders.customer_id * 2 = person.id + 10"; let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ - \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ - \n TableScan: person\ - \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ - \n TableScan: orders"; + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ + \n TableScan: orders"; quick_test(sql, expected); } @@ -5732,11 +5744,12 @@ mod tests { ON person.id + 10 = orders.customer_id * 2"; let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ - \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ - \n TableScan: person\ - \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ - \n TableScan: orders"; + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ + \n TableScan: orders"; quick_test(sql, expected); } @@ -5748,11 +5761,12 @@ mod tests { ON person.id + person.age + 10 = orders.customer_id * 2 - orders.price"; let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: person.id + person.age + Int64(10) = orders.customer_id * Int64(2) - orders.price\ - \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + person.age + Int64(10)\ - \n TableScan: person\ - \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2) - orders.price\ - \n TableScan: orders"; + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + person.age + Int64(10) = orders.customer_id * Int64(2) - orders.price\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + person.age + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2) - orders.price\ + \n TableScan: orders"; quick_test(sql, expected); } @@ -5764,10 +5778,11 @@ mod tests { ON person.id + person.age + 10 = orders.customer_id"; let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: person.id + person.age + Int64(10) = orders.customer_id\ - \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + person.age + Int64(10)\ - \n TableScan: person\ - \n TableScan: orders"; + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + person.age + Int64(10) = orders.customer_id\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + person.age + Int64(10)\ + \n TableScan: person\ + \n TableScan: orders"; quick_test(sql, expected); } @@ -5779,10 +5794,11 @@ mod tests { ON person.id = orders.customer_id * 2 - orders.price"; let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: person.id = orders.customer_id * Int64(2) - orders.price\ - \n TableScan: person\ - \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2) - orders.price\ - \n TableScan: orders"; + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id = orders.customer_id * Int64(2) - orders.price\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2) - orders.price\ + \n TableScan: orders"; quick_test(sql, expected); } @@ -5803,6 +5819,90 @@ mod tests { quick_test(sql, expected); } + #[test] + fn test_select_all_inner_join() { + let sql = "SELECT * + FROM person \ + INNER JOIN orders \ + ON orders.customer_id * 2 = person.id + 10"; + + let expected = "Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_select_join_key_inner_join() { + let sql = "SELECT orders.customer_id * 2, person.id + 10 + FROM person + INNER JOIN orders + ON orders.customer_id * 2 = person.id + 10"; + + let expected = "Projection: orders.customer_id * Int64(2), person.id + Int64(10)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id + Int64(10)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id * Int64(2)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_non_projetion_after_inner_join() { + // There's no need to add projection for left and right, so does adding projection after join. + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders + ON orders.customer_id = person.id"; + + let expected = "Projection: person.id, person.age\ + \n Inner Join: person.id = orders.customer_id\ + \n TableScan: person\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_dumplicated_left_join_key_inner_join() { + // person.id * 2 happen twice in left side. + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders + ON person.id * 2 = orders.customer_id + 10 and person.id * 2 = orders.order_id"; + + let expected = "Projection: person.id, person.age\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10), person.id * Int64(2) = orders.order_id\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id * Int64(2)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id + Int64(10)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + + #[test] + fn test_dumplicated_right_join_key_inner_join() { + // orders.customer_id + 10 happen twice in right side. + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders + ON person.id * 2 = orders.customer_id + 10 and person.id = orders.customer_id + 10"; + + let expected = "Projection: person.id, person.age\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\n Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10), person.id = orders.customer_id + Int64(10)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.id * Int64(2)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, orders.customer_id + Int64(10)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. } => { From b5ec6536c4e8da392da55b63a6124e6e0c95b509 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Thu, 17 Nov 2022 03:16:49 -0500 Subject: [PATCH 09/10] fix comment --- datafusion/sql/src/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index c80aac3c7179..5775333f7cf6 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -775,7 +775,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { join_filter, )?; - // Remove temporary projection columns. + // Remove temporary projected columns if necessary. if left_projected || right_projected { let final_join_result = join_schema .fields() From 15188e882897a58b0c93d18ca2f57eb1d6e328dc Mon Sep 17 00:00:00 2001 From: ygf11 Date: Thu, 17 Nov 2022 04:08:27 -0500 Subject: [PATCH 10/10] fix typo --- datafusion/sql/src/planner.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 5775333f7cf6..59bc5c1ce335 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -5898,7 +5898,7 @@ mod tests { } #[test] - fn test_dumplicated_left_join_key_inner_join() { + fn test_duplicated_left_join_key_inner_join() { // person.id * 2 happen twice in left side. let sql = "SELECT person.id, person.age FROM person @@ -5916,7 +5916,7 @@ mod tests { } #[test] - fn test_dumplicated_right_join_key_inner_join() { + fn test_duplicated_right_join_key_inner_join() { // orders.customer_id + 10 happen twice in right side. let sql = "SELECT person.id, person.age FROM person