diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 729c45426ff2..fc4eaef80903 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -386,9 +386,7 @@ fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option { None } // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if matches!(op, Operator::And | Operator::Or) => - { + Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => { let l = remove_join_expressions(*left, join_keys); let r = remove_join_expressions(*right, join_keys); match (l, r) { @@ -402,7 +400,20 @@ fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option { _ => None, } } - + Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => { + let l = remove_join_expressions(*left, join_keys); + let r = remove_join_expressions(*right, join_keys); + match (l, r) { + (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new( + Box::new(ll), + op, + Box::new(rr), + ))), + // When either `left` or `right` is empty, it means they are `true` + // so OR'ing anything with them will also be true + _ => None, + } + } _ => Some(expr), } } @@ -995,6 +1006,7 @@ mod tests { let t4 = test_table_scan_with_name("t4")?; // could eliminate to inner join + // filter: (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND tc.2 = 688) let plan1 = LogicalPlanBuilder::from(t1) .cross_join(t2)? .filter(binary_expr( @@ -1012,6 +1024,10 @@ mod tests { let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?; // could eliminate to inner join + // filter: + // ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688)) + // AND + // ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b)) let plan = LogicalPlanBuilder::from(plan1) .cross_join(plan2)? .filter(binary_expr( @@ -1057,7 +1073,7 @@ mod tests { "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", @@ -1084,6 +1100,12 @@ mod tests { let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?; // could eliminate to inner join + // Filter: + // ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688)) + // AND + // ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b)) + // AND + // ((t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688)) let plan = LogicalPlanBuilder::from(plan1) .cross_join(plan2)? .filter(binary_expr( @@ -1142,7 +1164,7 @@ mod tests { .build()?; let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 4f01d2b2c72b..84aeb3ebd766 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -1023,7 +1023,6 @@ statement ok DROP TABLE t3; -# Test issue: https://github.com/apache/datafusion/issues/11275 statement ok CREATE TABLE t0 (v1 BOOLEAN) AS VALUES (false), (null); @@ -1033,6 +1032,7 @@ CREATE TABLE t1 (v1 BOOLEAN) AS VALUES (false), (null), (false); statement ok CREATE TABLE t2 (v1 BOOLEAN) AS VALUES (false), (true); +# Test issue: https://github.com/apache/datafusion/issues/11275 query BB SELECT t2.v1, t1.v1 FROM t0, t1, t2 WHERE t2.v1 IS DISTINCT FROM t0.v1 ORDER BY 1,2; ---- @@ -1046,6 +1046,19 @@ true false true NULL true NULL +# Test issue: https://github.com/apache/datafusion/issues/11621 +query BB +SELECT * FROM t1 JOIN t2 ON t1.v1 = t2.v1 WHERE (t1.v1 == t2.v1) OR t1.v1; +---- +false false +false false + +query BB +SELECT * FROM t1 JOIN t2 ON t1.v1 = t2.v1 WHERE t1.v1 OR (t1.v1 == t2.v1); +---- +false false +false false + statement ok DROP TABLE t0;