From f5a6d01a5ef9bc583ecad583c850aa30ca356ee2 Mon Sep 17 00:00:00 2001 From: Nuttiiya Seekhao <37189615+nseekhao@users.noreply.github.com> Date: Fri, 13 Oct 2023 14:10:08 -0400 Subject: [PATCH] Encode all join conditions in a single expression field (#7612) * Encode all join conditions in a single expression field * Removed all references to post_join_filter * Simplify from_substrait_rel() * Clippy fix * Added test to ensure that Substrait plans produced from DF do not contain a post_join_filter --------- Co-authored-by: Andrew Lamb --- .../substrait/src/logical_plan/consumer.rs | 108 +++++++++------- .../substrait/src/logical_plan/producer.rs | 33 +++-- .../tests/cases/roundtrip_logical_plan.rs | 118 +++++++++++++++++- 3 files changed, 205 insertions(+), 54 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 8d99d1981b91..82e457767bb4 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -18,6 +18,7 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef}; + use datafusion::logical_expr::{ aggregate_function, window_function::find_df_window_func, BinaryExpr, BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, @@ -129,6 +130,51 @@ fn scalar_function_type_from_str(name: &str) -> Result { } } +fn split_eq_and_noneq_join_predicate_with_nulls_equality( + filter: &Expr, +) -> (Vec<(Column, Column)>, bool, Option<Expr>) { + let exprs = split_conjunction(filter); + + let mut accum_join_keys: Vec<(Column, Column)> = vec![]; + let mut accum_filters: Vec<Expr> = vec![]; + let mut nulls_equal_nulls = false; + + for expr in exprs { + match expr { + Expr::BinaryExpr(binary_expr) => match binary_expr { + x @ (BinaryExpr { + left, + op: Operator::Eq, + right, + } + | BinaryExpr { + left, + op: Operator::IsNotDistinctFrom, 
+ right, + }) => { + nulls_equal_nulls = match x.op { + Operator::Eq => false, + Operator::IsNotDistinctFrom => true, + _ => unreachable!(), + }; + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(l), Expr::Column(r)) => { + accum_join_keys.push((l.clone(), r.clone())); + } + _ => accum_filters.push(expr.clone()), + } + } + _ => accum_filters.push(expr.clone()), + }, + _ => accum_filters.push(expr.clone()), + } + } + + let join_filter = accum_filters.into_iter().reduce(Expr::and); + (accum_join_keys, nulls_equal_nulls, join_filter) +} + /// Convert Substrait Plan to DataFusion DataFrame pub async fn from_substrait_plan( ctx: &mut SessionContext, @@ -336,7 +382,13 @@ pub async fn from_substrait_rel( } } Some(RelType::Join(join)) => { - let left = LogicalPlanBuilder::from( + if join.post_join_filter.is_some() { + return not_impl_err!( + "JoinRel with post_join_filter is not yet supported" + ); + } + + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?, ); let right = LogicalPlanBuilder::from( @@ -346,60 +398,32 @@ pub async fn from_substrait_rel( // The join condition expression needs full input schema and not the output schema from join since we lose columns from // certain join types such as semi and anti joins let in_join_schema = left.schema().join(right.schema())?; - // Parse post join filter if exists - let join_filter = match &join.post_join_filter { - Some(filter) => { - let parsed_filter = - from_substrait_rex(filter, &in_join_schema, extensions).await?; - Some(parsed_filter.as_ref().clone()) - } - None => None, - }; + // If join expression exists, parse the `on` condition expression, build join and return - // Otherwise, build join with koin filter, without join keys + // Otherwise, build join with only the filter, without join keys match &join.expression.as_ref() { Some(expr) => { let on = from_substrait_rex(expr, &in_join_schema, extensions).await?; - let predicates = 
split_conjunction(&on); - // TODO: collect only one null_eq_null - let join_exprs: Vec<(Column, Column, bool)> = predicates - .iter() - .map(|p| match p { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => match op { - Operator::Eq => Ok((l.clone(), r.clone(), false)), - Operator::IsNotDistinctFrom => { - Ok((l.clone(), r.clone(), true)) - } - _ => plan_err!("invalid join condition op"), - }, - _ => plan_err!("invalid join condition expression"), - } - } - _ => plan_err!( - "Non-binary expression is not supported in join condition" - ), - }) - .collect::<Result<Vec<_>>>()?; - let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, Vec<_>) = - itertools::multiunzip(join_exprs); + // The join expression can contain both equal and non-equal ops. + // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. + // So we extract each part as follows: + // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector + // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) + let (join_ons, nulls_equal_nulls, join_filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(&on); + let (left_cols, right_cols): (Vec<_>, Vec<_>) = + itertools::multiunzip(join_ons); left.join_detailed( right.build()?, join_type, (left_cols, right_cols), join_filter, - null_eq_nulls[0], + nulls_equal_nulls, )? .build() } - None => match &join_filter { - Some(_) => left - .join_on(right.build()?, join_type, join_filter)? 
- .build(), - None => plan_err!("Join without join keys require a valid filter"), - }, + None => plan_err!("JoinRel without join condition is not allowed"), } } Some(RelType::Read(read)) => match &read.as_ref().read_type { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 1124ea53a557..757bddf9fe58 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -278,14 +278,15 @@ pub fn to_substrait_rel( // parse filter if exists let in_join_schema = join.left.schema().join(join.right.schema())?; let join_filter = match &join.filter { - Some(filter) => Some(Box::new(to_substrait_rex( + Some(filter) => Some(to_substrait_rex( filter, &Arc::new(in_join_schema), 0, extension_info, - )?)), + )?), None => None, }; + // map the left and right columns to binary expressions in the form `l = r` // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` let eq_op = if join.null_equals_null { @@ -293,15 +294,31 @@ pub fn to_substrait_rel( } else { Operator::Eq }; - - let join_expr = to_substrait_join_expr( + let join_on = to_substrait_join_expr( &join.on, eq_op, join.left.schema(), join.right.schema(), extension_info, - )? 
- .map(Box::new); + )?; + + // create conjunction between `join_on` and `join_filter` to embed all join conditions, + // whether equal or non-equal in a single expression + let join_expr = match &join_on { + Some(on_expr) => match &join_filter { + Some(filter) => Some(Box::new(make_binary_op_scalar_func( + on_expr, + filter, + Operator::And, + extension_info, + ))), + None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist + }, + None => match &join_filter { + Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist + None => None, + }, + }; Ok(Box::new(Rel { rel_type: Some(RelType::Join(Box::new(JoinRel { @@ -309,8 +326,8 @@ pub fn to_substrait_rel( left: Some(left), right: Some(right), r#type: join_type as i32, - expression: join_expr, - post_join_filter: join_filter, + expression: join_expr.clone(), + post_join_filter: None, advanced_extension: None, }))), })) diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 9b9afa159c20..32416125de24 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -23,15 +23,18 @@ use std::hash::Hash; use std::sync::Arc; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use datafusion::common::{DFSchema, DFSchemaRef}; -use datafusion::error::Result; +use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef}; +use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode}; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; + 
use substrait::proto::extensions::simple_extension_declaration::MappingType; +use substrait::proto::rel::RelType; +use substrait::proto::{plan_rel, Plan, Rel}; struct MockSerializerRegistry; @@ -383,12 +386,15 @@ async fn roundtrip_inner_join() -> Result<()> { #[tokio::test] async fn roundtrip_non_equi_inner_join() -> Result<()> { - roundtrip("SELECT data.a FROM data JOIN data2 ON data.a <> data2.a").await + roundtrip_verify_post_join_filter( + "SELECT data.a FROM data JOIN data2 ON data.a <> data2.a", + ) + .await } #[tokio::test] async fn roundtrip_non_equi_join() -> Result<()> { - roundtrip( + roundtrip_verify_post_join_filter( "SELECT data.a FROM data, data2 WHERE data.a = data2.a AND data.e > data2.a", ) .await @@ -620,6 +626,91 @@ async fn extension_logical_plan() -> Result<()> { Ok(()) } +fn check_post_join_filters(rel: &Rel) -> Result<()> { + // search for target_rel and field value in proto + match &rel.rel_type { + Some(RelType::Join(join)) => { + // check if join filter is None + if join.post_join_filter.is_some() { + plan_err!( + "DataFusion generated Susbtrait plan cannot have post_join_filter in JoinRel" + ) + } else { + // recursively check JoinRels + match check_post_join_filters(join.left.as_ref().unwrap().as_ref()) { + Err(e) => Err(e), + Ok(_) => { + check_post_join_filters(join.right.as_ref().unwrap().as_ref()) + } + } + } + } + Some(RelType::Project(p)) => { + check_post_join_filters(p.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Filter(filter)) => { + check_post_join_filters(filter.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Fetch(fetch)) => { + check_post_join_filters(fetch.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Sort(sort)) => { + check_post_join_filters(sort.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Aggregate(agg)) => { + check_post_join_filters(agg.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Set(set)) => { + for input in &set.inputs { + match check_post_join_filters(input) { + Err(e) 
=> return Err(e), + Ok(_) => continue, + } + } + Ok(()) + } + Some(RelType::ExtensionSingle(ext)) => { + check_post_join_filters(ext.input.as_ref().unwrap().as_ref()) + } + Some(RelType::ExtensionMulti(ext)) => { + for input in &ext.inputs { + match check_post_join_filters(input) { + Err(e) => return Err(e), + Ok(_) => continue, + } + } + Ok(()) + } + Some(RelType::ExtensionLeaf(_)) | Some(RelType::Read(_)) => Ok(()), + _ => not_impl_err!( + "Unsupported RelType: {:?} in post join filter check", + rel.rel_type + ), + } +} + +async fn verify_post_join_filter_value(proto: Box<Plan>) -> Result<()> { + for relation in &proto.relations { + match relation.rel_type.as_ref() { + Some(rt) => match rt { + plan_rel::RelType::Rel(rel) => match check_post_join_filters(rel) { + Err(e) => return Err(e), + Ok(_) => continue, + }, + plan_rel::RelType::Root(root) => { + match check_post_join_filters(root.input.as_ref().unwrap()) { + Err(e) => return Err(e), + Ok(_) => continue, + } + } + }, + None => return plan_err!("Cannot parse plan relation: None"), + } + } + + Ok(()) +} + async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> { let mut ctx = create_context().await?; let df = ctx.sql(sql).await?; @@ -688,6 +779,25 @@ async fn roundtrip(sql: &str) -> Result<()> { Ok(()) } +async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { + let mut ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + println!("{plan:#?}"); + println!("{plan2:#?}"); + + let plan1str = format!("{plan:?}"); + let plan2str = format!("{plan2:?}"); + assert_eq!(plan1str, plan2str); + + // verify that the join filters are None + verify_post_join_filter_value(proto).await +} + async fn roundtrip_all_types(sql: &str) -> Result<()> { let mut ctx = create_all_type_context().await?; 
let df = ctx.sql(sql).await?;