-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Encode all join conditions in a single expression field #7612
Changes from 5 commits
27e8ab3
301414c
cf651fd
590e7a9
b48cc0f
3f2d442
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
use async_recursion::async_recursion; | ||
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; | ||
use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef}; | ||
|
||
use datafusion::logical_expr::{ | ||
aggregate_function, window_function::find_df_window_func, BinaryExpr, | ||
BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, | ||
|
@@ -129,6 +130,51 @@ fn scalar_function_type_from_str(name: &str) -> Result<ScalarFunctionType> { | |
} | ||
} | ||
|
||
fn split_eq_and_noneq_join_predicate_with_nulls_equality( | ||
filter: &Expr, | ||
) -> (Vec<(Column, Column)>, bool, Option<Expr>) { | ||
let exprs = split_conjunction(filter); | ||
|
||
let mut accum_join_keys: Vec<(Column, Column)> = vec![]; | ||
let mut accum_filters: Vec<Expr> = vec![]; | ||
let mut nulls_equal_nulls = false; | ||
|
||
for expr in exprs { | ||
match expr { | ||
Expr::BinaryExpr(binary_expr) => match binary_expr { | ||
x @ (BinaryExpr { | ||
left, | ||
op: Operator::Eq, | ||
right, | ||
} | ||
| BinaryExpr { | ||
left, | ||
op: Operator::IsNotDistinctFrom, | ||
right, | ||
}) => { | ||
nulls_equal_nulls = match x.op { | ||
Operator::Eq => false, | ||
Operator::IsNotDistinctFrom => true, | ||
_ => unreachable!(), | ||
}; | ||
|
||
match (left.as_ref(), right.as_ref()) { | ||
(Expr::Column(l), Expr::Column(r)) => { | ||
accum_join_keys.push((l.clone(), r.clone())); | ||
} | ||
_ => accum_filters.push(expr.clone()), | ||
} | ||
} | ||
_ => accum_filters.push(expr.clone()), | ||
}, | ||
_ => accum_filters.push(expr.clone()), | ||
} | ||
} | ||
|
||
let join_filter = accum_filters.into_iter().reduce(Expr::and); | ||
(accum_join_keys, nulls_equal_nulls, join_filter) | ||
} | ||
|
||
/// Convert Substrait Plan to DataFusion DataFrame | ||
pub async fn from_substrait_plan( | ||
ctx: &mut SessionContext, | ||
|
@@ -331,7 +377,13 @@ pub async fn from_substrait_rel( | |
} | ||
} | ||
Some(RelType::Join(join)) => { | ||
let left = LogicalPlanBuilder::from( | ||
if join.post_join_filter.is_some() { | ||
return not_impl_err!( | ||
"JoinRel with post_join_filter is not yet supported" | ||
); | ||
} | ||
|
||
let left: LogicalPlanBuilder = LogicalPlanBuilder::from( | ||
from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?, | ||
); | ||
let right = LogicalPlanBuilder::from( | ||
|
@@ -341,65 +393,32 @@ pub async fn from_substrait_rel( | |
// The join condition expression needs full input schema and not the output schema from join since we lose columns from | ||
// certain join types such as semi and anti joins | ||
let in_join_schema = left.schema().join(right.schema())?; | ||
// Parse post join filter if exists | ||
let join_filter = match &join.post_join_filter { | ||
Some(filter) => { | ||
let parsed_filter = | ||
from_substrait_rex(filter, &in_join_schema, extensions).await?; | ||
Some(parsed_filter.as_ref().clone()) | ||
} | ||
None => None, | ||
}; | ||
|
||
// If join expression exists, parse the `on` condition expression, build join and return | ||
// Otherwise, build join with koin filter, without join keys | ||
// Otherwise, build join with only the filter, without join keys | ||
match &join.expression.as_ref() { | ||
Some(expr) => { | ||
let on = | ||
from_substrait_rex(expr, &in_join_schema, extensions).await?; | ||
let predicates = split_conjunction(&on); | ||
// TODO: collect only one null_eq_null | ||
let join_exprs: Vec<(Column, Column, bool)> = predicates | ||
.iter() | ||
.map(|p| match p { | ||
Expr::BinaryExpr(BinaryExpr { left, op, right }) => { | ||
match (left.as_ref(), right.as_ref()) { | ||
(Expr::Column(l), Expr::Column(r)) => match op { | ||
Operator::Eq => Ok((l.clone(), r.clone(), false)), | ||
Operator::IsNotDistinctFrom => { | ||
Ok((l.clone(), r.clone(), true)) | ||
} | ||
_ => plan_err!("invalid join condition op"), | ||
}, | ||
_ => plan_err!("invalid join condition expression"), | ||
} | ||
} | ||
_ => plan_err!( | ||
"Non-binary expression is not supported in join condition" | ||
), | ||
}) | ||
.collect::<Result<Vec<_>>>()?; | ||
let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, Vec<_>) = | ||
itertools::multiunzip(join_exprs); | ||
// The join expression can contain both equal and non-equal ops. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that the Something like left.join(
right.build()?,
join_type,
(vec![], vec![]),
on, // <-- use the filter directly here, let optimizer pass extract the equijoin columns
nulls_equal_nulls,
)? It makes me realize when looking at the API for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It turns out this is exactly what |
||
// As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. | ||
// So we extract each part as follows: | ||
// - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector | ||
// - Otherwise we add the expression to join_filter (use conjunction if filter already exists) | ||
let (join_ons, nulls_equal_nulls, join_filter) = | ||
split_eq_and_noneq_join_predicate_with_nulls_equality(&on); | ||
let (left_cols, right_cols): (Vec<_>, Vec<_>) = | ||
itertools::multiunzip(join_ons); | ||
left.join_detailed( | ||
right.build()?, | ||
join_type, | ||
(left_cols, right_cols), | ||
join_filter, | ||
null_eq_nulls[0], | ||
nulls_equal_nulls, | ||
)? | ||
.build() | ||
} | ||
None => match &join_filter { | ||
Some(_) => left | ||
.join( | ||
right.build()?, | ||
join_type, | ||
(Vec::<Column>::new(), Vec::<Column>::new()), | ||
join_filter, | ||
)? | ||
.build(), | ||
None => plan_err!("Join without join keys require a valid filter"), | ||
}, | ||
None => plan_err!("JoinRel without join condition is not allowed"), | ||
} | ||
} | ||
Some(RelType::Read(read)) => match &read.as_ref().read_type { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -277,39 +277,56 @@ pub fn to_substrait_rel( | |
// parse filter if exists | ||
let in_join_schema = join.left.schema().join(join.right.schema())?; | ||
let join_filter = match &join.filter { | ||
Some(filter) => Some(Box::new(to_substrait_rex( | ||
Some(filter) => Some(to_substrait_rex( | ||
filter, | ||
&Arc::new(in_join_schema), | ||
0, | ||
extension_info, | ||
)?)), | ||
)?), | ||
None => None, | ||
}; | ||
|
||
// map the left and right columns to binary expressions in the form `l = r` | ||
// build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` | ||
let eq_op = if join.null_equals_null { | ||
Operator::IsNotDistinctFrom | ||
} else { | ||
Operator::Eq | ||
}; | ||
|
||
let join_expr = to_substrait_join_expr( | ||
let join_on = to_substrait_join_expr( | ||
&join.on, | ||
eq_op, | ||
join.left.schema(), | ||
join.right.schema(), | ||
extension_info, | ||
)? | ||
.map(Box::new); | ||
)?; | ||
|
||
// create conjunction between `join_on` and `join_filter` to embed all join conditions, | ||
// whether equal or non-equal in a single expression | ||
let join_expr = match &join_on { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps you could use |
||
Some(on_expr) => match &join_filter { | ||
Some(filter) => Some(Box::new(make_binary_op_scalar_func( | ||
on_expr, | ||
filter, | ||
Operator::And, | ||
extension_info, | ||
))), | ||
None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist | ||
}, | ||
None => match &join_filter { | ||
Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist | ||
None => None, | ||
}, | ||
}; | ||
|
||
Ok(Box::new(Rel { | ||
rel_type: Some(RelType::Join(Box::new(JoinRel { | ||
common: None, | ||
left: Some(left), | ||
right: Some(right), | ||
r#type: join_type as i32, | ||
expression: join_expr, | ||
post_join_filter: join_filter, | ||
expression: join_expr.clone(), | ||
post_join_filter: None, | ||
advanced_extension: None, | ||
}))), | ||
})) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,15 +23,18 @@ use std::hash::Hash; | |
use std::sync::Arc; | ||
|
||
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; | ||
use datafusion::common::{DFSchema, DFSchemaRef}; | ||
use datafusion::error::Result; | ||
use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef}; | ||
use datafusion::error::{DataFusionError, Result}; | ||
use datafusion::execution::context::SessionState; | ||
use datafusion::execution::registry::SerializerRegistry; | ||
use datafusion::execution::runtime_env::RuntimeEnv; | ||
use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode}; | ||
use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; | ||
use datafusion::prelude::*; | ||
|
||
use substrait::proto::extensions::simple_extension_declaration::MappingType; | ||
use substrait::proto::rel::RelType; | ||
use substrait::proto::{plan_rel, Plan, Rel}; | ||
|
||
struct MockSerializerRegistry; | ||
|
||
|
@@ -378,12 +381,15 @@ async fn roundtrip_inner_join() -> Result<()> { | |
|
||
#[tokio::test] | ||
async fn roundtrip_non_equi_inner_join() -> Result<()> { | ||
roundtrip("SELECT data.a FROM data JOIN data2 ON data.a <> data2.a").await | ||
roundtrip_verify_post_join_filter( | ||
"SELECT data.a FROM data JOIN data2 ON data.a <> data2.a", | ||
) | ||
.await | ||
} | ||
|
||
#[tokio::test] | ||
async fn roundtrip_non_equi_join() -> Result<()> { | ||
roundtrip( | ||
roundtrip_verify_post_join_filter( | ||
"SELECT data.a FROM data, data2 WHERE data.a = data2.a AND data.e > data2.a", | ||
) | ||
.await | ||
|
@@ -615,6 +621,91 @@ async fn extension_logical_plan() -> Result<()> { | |
Ok(()) | ||
} | ||
|
||
fn check_post_join_filters(rel: &Rel) -> Result<()> { | ||
// search for target_rel and field value in proto | ||
match &rel.rel_type { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like it might be helpful (eventually) do define |
||
Some(RelType::Join(join)) => { | ||
// check if join filter is None | ||
if join.post_join_filter.is_some() { | ||
plan_err!( | ||
"DataFusion generated Susbtrait plan cannot have post_join_filter in JoinRel" | ||
) | ||
} else { | ||
// recursively check JoinRels | ||
match check_post_join_filters(join.left.as_ref().unwrap().as_ref()) { | ||
Err(e) => Err(e), | ||
Ok(_) => { | ||
check_post_join_filters(join.right.as_ref().unwrap().as_ref()) | ||
} | ||
} | ||
} | ||
} | ||
Some(RelType::Project(p)) => { | ||
check_post_join_filters(p.input.as_ref().unwrap().as_ref()) | ||
} | ||
Some(RelType::Filter(filter)) => { | ||
check_post_join_filters(filter.input.as_ref().unwrap().as_ref()) | ||
} | ||
Some(RelType::Fetch(fetch)) => { | ||
check_post_join_filters(fetch.input.as_ref().unwrap().as_ref()) | ||
} | ||
Some(RelType::Sort(sort)) => { | ||
check_post_join_filters(sort.input.as_ref().unwrap().as_ref()) | ||
} | ||
Some(RelType::Aggregate(agg)) => { | ||
check_post_join_filters(agg.input.as_ref().unwrap().as_ref()) | ||
} | ||
Some(RelType::Set(set)) => { | ||
for input in &set.inputs { | ||
match check_post_join_filters(input) { | ||
Err(e) => return Err(e), | ||
Ok(_) => continue, | ||
} | ||
} | ||
Ok(()) | ||
} | ||
Some(RelType::ExtensionSingle(ext)) => { | ||
check_post_join_filters(ext.input.as_ref().unwrap().as_ref()) | ||
} | ||
Some(RelType::ExtensionMulti(ext)) => { | ||
for input in &ext.inputs { | ||
match check_post_join_filters(input) { | ||
Err(e) => return Err(e), | ||
Ok(_) => continue, | ||
} | ||
} | ||
Ok(()) | ||
} | ||
Some(RelType::ExtensionLeaf(_)) | Some(RelType::Read(_)) => Ok(()), | ||
_ => not_impl_err!( | ||
"Unsupported RelType: {:?} in post join filter check", | ||
rel.rel_type | ||
), | ||
} | ||
} | ||
|
||
async fn verify_post_join_filter_value(proto: Box<Plan>) -> Result<()> { | ||
for relation in &proto.relations { | ||
match relation.rel_type.as_ref() { | ||
Some(rt) => match rt { | ||
plan_rel::RelType::Rel(rel) => match check_post_join_filters(rel) { | ||
Err(e) => return Err(e), | ||
Ok(_) => continue, | ||
}, | ||
plan_rel::RelType::Root(root) => { | ||
match check_post_join_filters(root.input.as_ref().unwrap()) { | ||
Err(e) => return Err(e), | ||
Ok(_) => continue, | ||
} | ||
} | ||
}, | ||
None => return plan_err!("Cannot parse plan relation: None"), | ||
} | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> { | ||
let mut ctx = create_context().await?; | ||
let df = ctx.sql(sql).await?; | ||
|
@@ -683,6 +774,25 @@ async fn roundtrip(sql: &str) -> Result<()> { | |
Ok(()) | ||
} | ||
|
||
async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { | ||
let mut ctx = create_context().await?; | ||
let df = ctx.sql(sql).await?; | ||
let plan = df.into_optimized_plan()?; | ||
let proto = to_substrait_plan(&plan, &ctx)?; | ||
let plan2 = from_substrait_plan(&mut ctx, &proto).await?; | ||
let plan2 = ctx.state().optimize(&plan2)?; | ||
|
||
println!("{plan:#?}"); | ||
println!("{plan2:#?}"); | ||
|
||
let plan1str = format!("{plan:?}"); | ||
let plan2str = format!("{plan2:?}"); | ||
assert_eq!(plan1str, plan2str); | ||
|
||
// verify that the join filters are None | ||
verify_post_join_filter_value(proto).await | ||
} | ||
|
||
async fn roundtrip_all_types(sql: &str) -> Result<()> { | ||
let mut ctx = create_all_type_context().await?; | ||
let df = ctx.sql(sql).await?; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍