diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index 824f1eec4a85..8c2e24de56b9 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -22,6 +22,11 @@ use arrow::compute::SortOptions;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
 use arrow_schema::Schema;
+
+use datafusion_common::ScalarValue;
+use datafusion_physical_expr::expressions::Literal;
+use datafusion_physical_expr::PhysicalExprRef;
+
 use rand::Rng;
 
 use datafusion::common::JoinSide;
@@ -40,92 +45,214 @@ use test_utils::stagger_batch_with_seed;
 
 #[tokio::test]
 async fn test_inner_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Inner,
+        None,
+    )
+    .run_test()
+    .await
+}
+
+/// Builds a `JoinFilter` evaluating `a < 100` on column `a` of the left input
+/// (note: the literal is 100 even though the function name says 10).
+fn less_than_10_join_filter(schema1: Arc<Schema>, _schema2: Arc<Schema>) -> JoinFilter {
+    let less_than_100 = Arc::new(BinaryExpr::new(
+        Arc::new(Column::new("a", 0)),
+        Operator::Lt,
+        Arc::new(Literal::new(ScalarValue::from(100))),
+    )) as _;
+    let column_indices = vec![ColumnIndex {
+        index: 0,
+        side: JoinSide::Left,
+    }];
+    let intermediate_schema =
+        Schema::new(vec![schema1.field_with_name("a").unwrap().to_owned()]);
+
+    JoinFilter::new(less_than_100, column_indices, intermediate_schema)
+}
+
+#[tokio::test]
+async fn test_inner_join_1k_filtered() {
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Inner,
+        Some(Box::new(less_than_10_join_filter)),
+    )
+    .run_test()
+    .await
+}
+
+#[tokio::test]
+async fn test_inner_join_1k_smjoin() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Inner,
+        None,
     )
+    .run_test()
     .await
 }
 
 #[tokio::test]
 async fn test_left_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Left,
+        None,
+    )
+    .run_test()
+    .await
+}
+
+#[tokio::test]
+async fn test_left_join_1k_filtered() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Left,
+        Some(Box::new(less_than_10_join_filter)),
     )
+    .run_test()
     .await
 }
 
 #[tokio::test]
 async fn test_right_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Right,
+        None,
+    )
+    .run_test()
+    .await
+}
+// Ignored until Right filtered joins are supported
+#[ignore]
+#[tokio::test]
+async fn test_right_join_1k_filtered() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Right,
+        Some(Box::new(less_than_10_join_filter)),
     )
+    .run_test()
     .await
 }
 
 #[tokio::test]
 async fn test_full_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Full,
+        None,
     )
+    .run_test()
+    .await
+}
+
+#[tokio::test]
+async fn test_full_join_1k_filtered() {
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Full,
+        Some(Box::new(less_than_10_join_filter)),
+    )
+    .run_test()
     .await
 }
 
 #[tokio::test]
 async fn test_semi_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::LeftSemi,
+        None,
+    )
+    .run_test()
+    .await
+}
+
+#[tokio::test]
+async fn test_semi_join_1k_filtered() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::LeftSemi,
+        Some(Box::new(less_than_10_join_filter)),
     )
+    .run_test()
     .await
 }
 
 #[tokio::test]
 async fn test_anti_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::LeftAnti,
+        None,
+    )
+    .run_test()
+    .await
+}
+
+// Test currently fails, see https://github.com/apache/datafusion/issues/10872
+#[ignore]
+#[tokio::test]
+async fn test_anti_join_1k_filtered() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::LeftAnti,
+        Some(Box::new(less_than_10_join_filter)),
     )
+    .run_test()
     .await
 }
 
-/// Perform sort-merge join and hash join on same input
-/// and verify two outputs are equal
-async fn run_join_test(
+/// Builds a `JoinFilter` from the schemas of the first and second input.
+type JoinFilterBuilder = Box<dyn Fn(Arc<Schema>, Arc<Schema>) -> JoinFilter>;
+
+/// Runs sort-merge, hash and nested-loop joins over the same two inputs and
+/// checks that all of them produce the same rows.
+struct JoinFuzzTestCase {
+    /// Batch sizes every join is executed with
+    batch_sizes: &'static [usize],
     input1: Vec<RecordBatch>,
     input2: Vec<RecordBatch>,
     join_type: JoinType,
-) {
-    let batch_sizes = [1, 2, 7, 49, 50, 51, 100];
-    for batch_size in batch_sizes {
-        let session_config = SessionConfig::new().with_batch_size(batch_size);
-        let ctx = SessionContext::new_with_config(session_config);
-        let task_ctx = ctx.task_ctx();
-
-        let schema1 = input1[0].schema();
-        let schema2 = input2[0].schema();
-        let on_columns = vec![
-            (
-                Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
-                Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
-            ),
-            (
-                Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
-                Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
-            ),
-        ];
+    join_filter_builder: Option<JoinFilterBuilder>,
+}
 
-        // Nested loop join uses filter for joining records
-        let column_indices = vec![
+impl JoinFuzzTestCase {
+    fn new(
+        input1: Vec<RecordBatch>,
+        input2: Vec<RecordBatch>,
+        join_type: JoinType,
+        join_filter_builder: Option<JoinFilterBuilder>,
+    ) -> Self {
+        Self {
+            batch_sizes: &[1, 2, 7, 49, 50, 51, 100],
+            input1,
+            input2,
+            join_type,
+            join_filter_builder,
+        }
+    }
+
+    /// Column mapping used by the nested-loop join filter.
+    fn column_indices(&self) -> Vec<ColumnIndex> {
+        vec![
             ColumnIndex {
                 index: 0,
                 side: JoinSide::Left,
@@ -142,120 +269,198 @@ async fn run_join_test(
                 index: 1,
                 side: JoinSide::Right,
             },
-        ];
-        let intermediate_schema = Schema::new(vec![
-            schema1.field_with_name("a").unwrap().to_owned(),
-            schema1.field_with_name("b").unwrap().to_owned(),
-            schema2.field_with_name("a").unwrap().to_owned(),
-            schema2.field_with_name("b").unwrap().to_owned(),
-        ]);
+        ]
+    }
 
-        let equal_a = Arc::new(BinaryExpr::new(
-            Arc::new(Column::new("a", 0)),
-            Operator::Eq,
-            Arc::new(Column::new("a", 2)),
-        )) as _;
-        let equal_b = Arc::new(BinaryExpr::new(
-            Arc::new(Column::new("b", 1)),
-            Operator::Eq,
-            Arc::new(Column::new("b", 3)),
-        )) as _;
-        let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _;
+    /// Equi-join keys: columns `a` and `b` of both inputs.
+    fn on_columns(&self) -> Vec<(PhysicalExprRef, PhysicalExprRef)> {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
+        vec![
+            (
+                Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
+                Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
+            ),
+            (
+                Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
+                Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
+            ),
+        ]
+    }
 
-        let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema);
+    /// Schema of the nested-loop join filter input: `a`/`b` of the first input
+    /// marked nullable, followed by `a`/`b` of the second input unchanged.
+    fn intermediate_schema(&self) -> Schema {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
+        Schema::new(vec![
+            schema1
+                .field_with_name("a")
+                .unwrap()
+                .to_owned()
+                .with_nullable(true),
+            schema1
+                .field_with_name("b")
+                .unwrap()
+                .to_owned()
+                .with_nullable(true),
+            schema2.field_with_name("a").unwrap().to_owned(),
+            schema2.field_with_name("b").unwrap().to_owned(),
+        ])
+    }
 
-        // sort-merge join
+    /// Wraps both inputs in fresh `MemoryExec` plans.
+    fn left_right(&self) -> (Arc<MemoryExec>, Arc<MemoryExec>) {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
         let left = Arc::new(
-            MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
+            MemoryExec::try_new(&[self.input1.clone()], schema1.clone(), None).unwrap(),
         );
         let right = Arc::new(
-            MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
+            MemoryExec::try_new(&[self.input2.clone()], schema2.clone(), None).unwrap(),
         );
-        let smj = Arc::new(
+        (left, right)
+    }
+
+    /// Builds the optional join filter from the two input schemas.
+    fn join_filter(&self) -> Option<JoinFilter> {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
+        self.join_filter_builder
+            .as_ref()
+            .map(|builder| builder(schema1, schema2))
+    }
+
+    /// Plans the join as a `SortMergeJoinExec`.
+    fn sort_merge_join(&self) -> Arc<SortMergeJoinExec> {
+        let (left, right) = self.left_right();
+        Arc::new(
             SortMergeJoinExec::try_new(
                 left,
                 right,
-                on_columns.clone(),
-                None,
-                join_type,
+                self.on_columns(),
+                self.join_filter(),
+                self.join_type,
                 vec![SortOptions::default(), SortOptions::default()],
                 false,
             )
             .unwrap(),
-        );
-        let smj_collected = collect(smj, task_ctx.clone()).await.unwrap();
+        )
+    }
 
-        // hash join
-        let left = Arc::new(
-            MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
-        );
-        let right = Arc::new(
-            MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
-        );
-        let hj = Arc::new(
+    /// Plans the join as a `HashJoinExec`.
+    fn hash_join(&self) -> Arc<HashJoinExec> {
+        let (left, right) = self.left_right();
+        Arc::new(
             HashJoinExec::try_new(
                 left,
                 right,
-                on_columns.clone(),
-                None,
-                &join_type,
+                self.on_columns(),
+                self.join_filter(),
+                &self.join_type,
                 None,
                 PartitionMode::Partitioned,
                 false,
             )
             .unwrap(),
-        );
-        let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();
+        )
+    }
 
-        // nested loop join
-        let left = Arc::new(
-            MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
-        );
-        let right = Arc::new(
-            MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
-        );
-        let nlj = Arc::new(
-            NestedLoopJoinExec::try_new(left, right, Some(on_filter), &join_type)
+    /// Plans the join as a `NestedLoopJoinExec` with the key equality as filter.
+    fn nested_loop_join(&self) -> Arc<NestedLoopJoinExec> {
+        let (left, right) = self.left_right();
+        // Nested loop join uses filter for joining records
+        let column_indices = self.column_indices();
+        let intermediate_schema = self.intermediate_schema();
+
+        let equal_a = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Eq,
+            Arc::new(Column::new("a", 2)),
+        )) as _;
+        let equal_b = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("b", 1)),
+            Operator::Eq,
+            Arc::new(Column::new("b", 3)),
+        )) as _;
+        let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _;
+
+        let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema);
+
+        Arc::new(
+            NestedLoopJoinExec::try_new(left, right, Some(on_filter), &self.join_type)
                 .unwrap(),
-        );
-        let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
+        )
+    }
 
-        // compare
-        let smj_formatted = pretty_format_batches(&smj_collected).unwrap().to_string();
-        let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string();
-        let nlj_formatted = pretty_format_batches(&nlj_collected).unwrap().to_string();
+    /// Run sort-merge, hash and nested-loop joins on the same input
+    /// and verify that all three outputs are equal
+    async fn run_test(&self) {
+        for batch_size in self.batch_sizes {
+            let session_config = SessionConfig::new().with_batch_size(*batch_size);
+            let ctx = SessionContext::new_with_config(session_config);
+            let task_ctx = ctx.task_ctx();
+            let smj = self.sort_merge_join();
+            let smj_collected = collect(smj, task_ctx.clone()).await.unwrap();
 
-        let mut smj_formatted_sorted: Vec<&str> = smj_formatted.trim().lines().collect();
-        smj_formatted_sorted.sort_unstable();
+            let hj = self.hash_join();
+            let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();
 
-        let mut hj_formatted_sorted: Vec<&str> = hj_formatted.trim().lines().collect();
-        hj_formatted_sorted.sort_unstable();
+            let nlj = self.nested_loop_join();
+            let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
 
-        let mut nlj_formatted_sorted: Vec<&str> = nlj_formatted.trim().lines().collect();
-        nlj_formatted_sorted.sort_unstable();
+            // compare
+            let smj_formatted =
+                pretty_format_batches(&smj_collected).unwrap().to_string();
+            let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string();
+            let nlj_formatted =
+                pretty_format_batches(&nlj_collected).unwrap().to_string();
 
-        for (i, (smj_line, hj_line)) in smj_formatted_sorted
-            .iter()
-            .zip(&hj_formatted_sorted)
-            .enumerate()
-        {
-            assert_eq!(
-                (i, smj_line),
-                (i, hj_line),
-                "SortMergeJoinExec and HashJoinExec produced different results"
-            );
-        }
+            let mut smj_formatted_sorted: Vec<&str> =
+                smj_formatted.trim().lines().collect();
+            smj_formatted_sorted.sort_unstable();
+
+            let mut hj_formatted_sorted: Vec<&str> =
+                hj_formatted.trim().lines().collect();
+            hj_formatted_sorted.sort_unstable();
+
+            let mut nlj_formatted_sorted: Vec<&str> =
+                nlj_formatted.trim().lines().collect();
+            nlj_formatted_sorted.sort_unstable();
 
-        for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
-            .iter()
-            .zip(&hj_formatted_sorted)
-            .enumerate()
-        {
             assert_eq!(
-                (i, nlj_line),
-                (i, hj_line),
-                "NestedLoopJoinExec and HashJoinExec produced different results"
+                smj_formatted_sorted.len(),
+                hj_formatted_sorted.len(),
+                "SortMergeJoinExec and HashJoinExec produced different row counts"
             );
+            assert_eq!(
+                nlj_formatted_sorted.len(),
+                hj_formatted_sorted.len(),
+                "NestedLoopJoinExec and HashJoinExec produced different row counts"
+            );
+            for (i, (smj_line, hj_line)) in smj_formatted_sorted
+                .iter()
+                .zip(&hj_formatted_sorted)
+                .enumerate()
+            {
+                assert_eq!(
+                    (i, smj_line),
+                    (i, hj_line),
+                    "SortMergeJoinExec and HashJoinExec produced different results"
+                );
+            }
+
+            for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
+                .iter()
+                .zip(&hj_formatted_sorted)
+                .enumerate()
+            {
+                assert_eq!(
+                    (i, nlj_line),
+                    (i, hj_line),
+                    "NestedLoopJoinExec and HashJoinExec produced different results"
+                );
+            }
         }
     }
 }