diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 1c2d8ece2f36..20c329915254 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -27,6 +27,7 @@ use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::PhysicalExprRef; +use itertools::Itertools; use rand::Rng; use datafusion::common::JoinSide; @@ -54,33 +55,6 @@ enum JoinTestType { // because if existing variants both passed that means SortMergeJoin and NestedLoopJoin also passes HjSmj, } -#[tokio::test] -async fn test_inner_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::Inner, - None, - ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) - .await -} - -fn less_than_100_join_filter(schema1: Arc, _schema2: Arc) -> JoinFilter { - let less_than_100 = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Lt, - Arc::new(Literal::new(ScalarValue::from(100))), - )) as _; - let column_indices = vec![ColumnIndex { - index: 0, - side: JoinSide::Left, - }]; - let intermediate_schema = - Schema::new(vec![schema1.field_with_name("a").unwrap().to_owned()]); - - JoinFilter::new(less_than_100, column_indices, intermediate_schema) -} fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { let less_filter = Arc::new(BinaryExpr::new( @@ -120,14 +94,14 @@ async fn test_inner_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Inner, - Some(Box::new(less_than_100_join_filter)), + Some(Box::new(col_lt_col_filter)), ) .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } #[tokio::test] -async fn test_inner_join_1k_smjoin() { +async fn test_inner_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), @@ -151,14 +125,16 @@ async fn test_left_join_1k() { } 
#[tokio::test] +// flaky for HjSmj case +// https://github.com/apache/datafusion/issues/12359 async fn test_left_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Left, - Some(Box::new(less_than_100_join_filter)), + Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[JoinTestType::NljHj], false) .await } @@ -173,17 +149,18 @@ async fn test_right_join_1k() { .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } -// Add support for Right filtered joins -#[ignore] + #[tokio::test] +// flaky for HjSmj case +// https://github.com/apache/datafusion/issues/12359 async fn test_right_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Right, - Some(Box::new(less_than_100_join_filter)), + Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[JoinTestType::NljHj], false) .await } @@ -200,14 +177,16 @@ async fn test_full_join_1k() { } #[tokio::test] +// flaky for HjSmj case +// https://github.com/apache/datafusion/issues/12359 async fn test_full_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Full, - Some(Box::new(less_than_100_join_filter)), + Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[JoinTestType::NljHj], false) .await } @@ -225,15 +204,13 @@ async fn test_semi_join_1k() { #[tokio::test] async fn test_semi_join_1k_filtered() { - // NLJ vs HJ gives wrong result - // Tracked in https://github.com/apache/datafusion/issues/11537 JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } 
@@ -250,19 +227,16 @@ async fn test_anti_join_1k() { } #[tokio::test] -#[ignore] -// flaky test giving 1 rows difference sometimes +// flaky for HjSmj case, giving a 1-row difference sometimes // https://github.com/apache/datafusion/issues/11555 async fn test_anti_join_1k_filtered() { - // NLJ vs HJ gives wrong result - // Tracked in https://github.com/apache/datafusion/issues/11537 JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj], false) + .run_test(&[JoinTestType::NljHj], false) .await } @@ -292,27 +266,6 @@ impl JoinFuzzTestCase { } } - fn column_indices(&self) -> Vec { - vec![ - ColumnIndex { - index: 0, - side: JoinSide::Left, - }, - ColumnIndex { - index: 1, - side: JoinSide::Left, - }, - ColumnIndex { - index: 0, - side: JoinSide::Right, - }, - ColumnIndex { - index: 1, - side: JoinSide::Right, - }, - ] - } - fn on_columns(&self) -> Vec<(PhysicalExprRef, PhysicalExprRef)> { let schema1 = self.input1[0].schema(); let schema2 = self.input2[0].schema(); @@ -328,10 +281,20 @@ impl JoinFuzzTestCase { ] } + /// Helper function for building NLJoin filter, returning the intermediate + /// schema as a union of the original filter's intermediate schema and + /// the on-condition schema fn intermediate_schema(&self) -> Schema { + let filter_schema = if let Some(filter) = self.join_filter() { + filter.schema().to_owned() + } else { + Schema::empty() + }; + let schema1 = self.input1[0].schema(); let schema2 = self.input2[0].schema(); - Schema::new(vec![ + + let on_schema = Schema::new(vec![ schema1 .field_with_name("a") .unwrap() @@ -344,7 +307,81 @@ impl JoinFuzzTestCase { .with_nullable(true), schema2.field_with_name("a").unwrap().to_owned(), schema2.field_with_name("b").unwrap().to_owned(), - ]) + ]); + + Schema::new( + filter_schema + .fields + .into_iter() + .cloned() + .chain(on_schema.fields.into_iter().cloned()) + .collect_vec(), + ) + } + + /// Helper 
function for building NLJoin filter, returns the conjunction + /// (AND) of the original filter expression and the on-condition expression + fn composite_filter_expression(&self) -> PhysicalExprRef { + let (filter_expression, column_idx_offset) = + if let Some(filter) = self.join_filter() { + ( + filter.expression().to_owned(), + filter.schema().fields().len(), + ) + } else { + (Arc::new(Literal::new(ScalarValue::from(true))) as _, 0) + }; + + let equal_a = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", column_idx_offset)), + Operator::Eq, + Arc::new(Column::new("a", column_idx_offset + 2)), + )); + let equal_b = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", column_idx_offset + 1)), + Operator::Eq, + Arc::new(Column::new("b", column_idx_offset + 3)), + )); + let on_expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)); + + Arc::new(BinaryExpr::new( + filter_expression, + Operator::And, + on_expression, + )) + } + + /// Helper function for building NLJoin filter, returning the original + /// filter column indices followed by the on-condition column indices. + /// Result must match intermediate schema. 
+ fn column_indices(&self) -> Vec { + let mut column_indices = if let Some(filter) = self.join_filter() { + filter.column_indices().to_vec() + } else { + vec![] + }; + + let on_column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 1, + side: JoinSide::Right, + }, + ]; + + column_indices.extend(on_column_indices); + column_indices } fn left_right(&self) -> (Arc, Arc) { @@ -400,26 +437,15 @@ impl JoinFuzzTestCase { fn nested_loop_join(&self) -> Arc { let (left, right) = self.left_right(); - // Nested loop join uses filter for joining records + let column_indices = self.column_indices(); let intermediate_schema = self.intermediate_schema(); + let expression = self.composite_filter_expression(); - let equal_a = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Eq, - Arc::new(Column::new("a", 2)), - )) as _; - let equal_b = Arc::new(BinaryExpr::new( - Arc::new(Column::new("b", 1)), - Operator::Eq, - Arc::new(Column::new("b", 3)), - )) as _; - let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _; - - let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema); + let filter = JoinFilter::new(expression, column_indices, intermediate_schema); Arc::new( - NestedLoopJoinExec::try_new(left, right, Some(on_filter), &self.join_type) + NestedLoopJoinExec::try_new(left, right, Some(filter), &self.join_type) .unwrap(), ) }