diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 8e310d1f4e8aa..1c3186b762b71 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -194,9 +194,50 @@ fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { } } -/// Return true if a predicate only references columns in the specified schema -fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result { - let schema_columns = schema +/// Evaluates the columns referenced in the given expression to see if they refer +/// only to the left or right columns +#[derive(Debug)] +struct ColumnChecker<'a> { + /// schema of left join input + left_schema: &'a DFSchema, + /// columns in left_schema, computed on demand + left_columns: Option>, + /// schema of right join input + right_schema: &'a DFSchema, + /// columns in right_schema, computed on demand + right_columns: Option>, +} + +impl<'a> ColumnChecker<'a> { + fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self { + Self { + left_schema, + left_columns: None, + right_schema, + right_columns: None, + } + } + + /// Return true if the expression references only columns from the left side of the join + fn is_left_only(&mut self, predicate: &Expr) -> bool { + if self.left_columns.is_none() { + self.left_columns = Some(schema_columns(self.left_schema)); + } + has_all_column_refs(predicate, self.left_columns.as_ref().unwrap()) + } + + /// Return true if the expression references only columns from the right side of the join + fn is_right_only(&mut self, predicate: &Expr) -> bool { + if self.right_columns.is_none() { + self.right_columns = Some(schema_columns(self.right_schema)); + } + has_all_column_refs(predicate, self.right_columns.as_ref().unwrap()) + } +} + +/// Returns all columns in the schema +fn schema_columns(schema: &DFSchema) -> HashSet { + schema .iter() .flat_map(|(qualifier, field)| { [ @@ -205,8 +246,7 @@ 
fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result>(); - Ok(has_all_column_refs(predicate, &schema_columns)) + .collect::>() } /// Determine whether the predicate can evaluate as the join conditions @@ -291,16 +331,7 @@ fn extract_or_clauses_for_join<'a>( filters: &'a [Expr], schema: &'a DFSchema, ) -> impl Iterator + 'a { - let schema_columns = schema - .iter() - .flat_map(|(qualifier, field)| { - [ - Column::new(qualifier.cloned(), field.name()), - // we need to push down filter using unqualified column as well - Column::new_unqualified(field.name()), - ] - }) - .collect::>(); + let schema_columns = schema_columns(schema); // new formed OR clauses and their column references filters.iter().filter_map(move |expr| { @@ -403,12 +434,11 @@ fn push_down_all_join( let mut right_push = vec![]; let mut keep_predicates = vec![]; let mut join_conditions = vec![]; + let mut checker = ColumnChecker::new(left_schema, right_schema); for predicate in predicates { - if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? { + if left_preserved && checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_preserved - && can_pushdown_join_predicate(&predicate, right_schema)? - { + } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? { // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate @@ -421,11 +451,9 @@ fn push_down_all_join( // For infer predicates, if they can not push through join, just drop them for predicate in inferred_join_predicates { - if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? { + if left_preserved && checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_preserved - && can_pushdown_join_predicate(&predicate, right_schema)? 
- { + } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } } @@ -435,11 +463,9 @@ fn push_down_all_join( if !on_filter.is_empty() { for on in on_filter { - if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? { + if on_left_preserved && checker.is_left_only(&on) { left_push.push(on) - } else if on_right_preserved - && can_pushdown_join_predicate(&on, right_schema)? - { + } else if on_right_preserved && checker.is_right_only(&on) { right_push.push(on) } else { on_filter_join_conditions.push(on)