diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 60b9ba3031a1..c47a86974cd2 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -18,12 +18,13 @@ //! [`ExtractEquijoinPredicate`] identifies equality join (equijoin) predicates use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::DFSchema; +use datafusion_common::tree_node::Transformed; use datafusion_common::Result; -use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair, split_conjunction}; +use datafusion_common::{internal_err, DFSchema}; +use datafusion_expr::utils::split_conjunction_owned; +use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator}; use std::sync::Arc; - // equijoin predicate type EquijoinPredicate = (Expr, Expr); @@ -51,15 +52,34 @@ impl ExtractEquijoinPredicate { impl OptimizerRule for ExtractEquijoinPredicate { fn try_optimize( &self, - plan: &LogicalPlan, + _plan: &LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called ExtractEquijoinPredicate::rewrite") + } + fn supports_rewrite(&self) -> bool { + true + } + + fn name(&self) -> &str { + "extract_equijoin_predicate" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { match plan { LogicalPlan::Join(Join { left, right, - on, - filter, + mut on, + filter: Some(expr), join_type, join_constraint, schema, @@ -67,66 +87,55 @@ impl OptimizerRule for ExtractEquijoinPredicate { }) => { let left_schema = left.schema(); let right_schema = right.schema(); - - filter.as_ref().map_or(Result::Ok(None), |expr| { - let (equijoin_predicates, non_equijoin_expr) = - split_eq_and_noneq_join_predicate( - expr, - left_schema, - right_schema, - )?; - - let optimized_plan = (!equijoin_predicates.is_empty()).then(|| { - let mut new_on = on.clone(); - new_on.extend(equijoin_predicates); - - LogicalPlan::Join(Join { - left: left.clone(), - right: right.clone(), - on: new_on, - filter: non_equijoin_expr, - join_type: *join_type, - join_constraint: *join_constraint, - schema: schema.clone(), - null_equals_null: *null_equals_null, - }) - }); - - Ok(optimized_plan) - }) + let (equijoin_predicates, non_equijoin_expr) = + split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?; + + if !equijoin_predicates.is_empty() { + on.extend(equijoin_predicates); + Ok(Transformed::yes(LogicalPlan::Join(Join { + left, + right, + on, + filter: non_equijoin_expr, + join_type, + join_constraint, + schema, + null_equals_null, + }))) + } else { + Ok(Transformed::no(LogicalPlan::Join(Join { + left, + right, + on, + filter: non_equijoin_expr, + join_type, + join_constraint, + schema, + null_equals_null, + }))) + } } - _ => Ok(None), + _ => Ok(Transformed::no(plan)), } } - - fn name(&self) -> &str { - "extract_equijoin_predicate" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } } fn split_eq_and_noneq_join_predicate( - filter: &Expr, + filter: Expr, left_schema: &Arc, right_schema: &Arc, ) -> Result<(Vec, Option)> { - let exprs = split_conjunction(filter); + let exprs = split_conjunction_owned(filter); let mut accum_join_keys: Vec<(Expr, Expr)> = vec![]; let mut accum_filters: Vec = vec![]; for expr in exprs { match expr { Expr::BinaryExpr(BinaryExpr { - left, + ref left, op: Operator::Eq, - right, + ref right, }) => { - let left = left.as_ref(); - let right = right.as_ref(); - let join_key_pair = find_valid_equijoin_key_pair( left, right, @@ -141,13 +150,13 @@ fn split_eq_and_noneq_join_predicate( if can_hash(&left_expr_type) && can_hash(&right_expr_type) { accum_join_keys.push((left_expr, right_expr)); } else { - accum_filters.push(expr.clone()); + accum_filters.push(expr); } } else { - accum_filters.push(expr.clone()); + accum_filters.push(expr); } } - _ => accum_filters.push(expr.clone()), + _ => accum_filters.push(expr), } }