From fa911c1afa6c61a2e647a9137cbb01815f99cc22 Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 18 Jun 2024 09:06:36 -0700 Subject: [PATCH] Minor: Add routine to debug join fuzz tests --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 48 +++++++++---------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 271f774237f4..07861cb51419 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -47,7 +47,7 @@ use test_utils::stagger_batch_with_seed; // Ideally all tests should match, but in reality some tests // passes only partial cases #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -enum JoinTest { +enum JoinTestType { // compare NestedLoopJoin and HashJoin NljHj, // compare HashJoin and SortMergeJoin, no need to compare SortMergeJoin and NestedLoopJoin @@ -62,7 +62,7 @@ async fn test_inner_join_1k() { JoinType::Inner, None, ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -114,7 +114,7 @@ async fn test_inner_join_1k_filtered() { JoinType::Inner, Some(Box::new(less_than_100_join_filter)), ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -126,7 +126,7 @@ async fn test_inner_join_1k_smjoin() { JoinType::Inner, None, ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -138,7 +138,7 @@ async fn test_left_join_1k() { JoinType::Left, None, ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -150,7 +150,7 @@ async fn test_left_join_1k_filtered() { JoinType::Left, Some(Box::new(less_than_100_join_filter)), ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -162,7 +162,7 @@ async fn test_right_join_1k() { JoinType::Right, None, ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } // Add support for Right filtered joins @@ -175,7 +175,7 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(less_than_100_join_filter)), ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -187,7 +187,7 @@ async fn test_full_join_1k() { JoinType::Full, None, ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -199,7 +199,7 @@ async fn test_full_join_1k_filtered() { JoinType::Full, Some(Box::new(less_than_100_join_filter)), ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -211,7 +211,7 @@ async fn test_semi_join_1k() { JoinType::LeftSemi, None, ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -227,7 +227,7 @@ async fn test_semi_join_1k_filtered() { JoinType::LeftSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTest::HjSmj], false) + .run_test(&[JoinTestType::HjSmj], false) .await } @@ -239,7 +239,7 @@ async fn test_anti_join_1k() { JoinType::LeftAnti, None, ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -253,7 +253,7 @@ async fn test_anti_join_1k_filtered() { JoinType::LeftAnti, Some(Box::new(less_than_100_join_filter)), ) - .run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -417,9 +417,11 @@ impl JoinFuzzTestCase { ) } - /// Perform sort-merge join and hash join on same input - /// and verify two outputs are equal - async fn run_test(&self, join_test: &[JoinTest], debug: bool) { + /// Perform joins tests on same inputs and verify outputs are equal + /// `join_tests` - identifies what join types to test + /// if `debug` flag is set the test will save randomly generated inputs and outputs to user folders, + /// so it is easy to debug a test on top of the failed data + async fn run_test(&self, join_tests: &[JoinTestType], debug: bool) { for batch_size in self.batch_sizes { let session_config = SessionConfig::new().with_batch_size(*batch_size); let ctx = SessionContext::new_with_config(session_config); @@ -438,19 +440,17 @@ impl JoinFuzzTestCase { let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); - // if debug flag is set the test will save randomly generated inputs and outputs to user folders - // so it is easy to run debug on top of the failed test data if debug { let out_dir_name = &format!("fuzz_test_debug_batch_size_{batch_size}"); Self::save_as_parquet(&self.input1, out_dir_name, "input1"); Self::save_as_parquet(&self.input2, out_dir_name, "input2"); - if join_test.contains(&JoinTest::NljHj) { + if join_tests.contains(&JoinTestType::NljHj) { Self::save_as_parquet(&nlj_collected, out_dir_name, "nlj"); Self::save_as_parquet(&hj_collected, out_dir_name, "hj"); } - if join_test.contains(&JoinTest::HjSmj) { + if join_tests.contains(&JoinTestType::HjSmj) { Self::save_as_parquet(&hj_collected, out_dir_name, "hj"); Self::save_as_parquet(&smj_collected, out_dir_name, "smj"); } @@ -475,12 +475,11 @@ impl JoinFuzzTestCase { nlj_formatted.trim().lines().collect(); nlj_formatted_sorted.sort_unstable(); - if join_test.contains(&JoinTest::NljHj) { + if join_tests.contains(&JoinTestType::NljHj) { let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size); assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}", batch_size); - // row level compare if any of joins returns the result // the reason is different formatting when there is no rows for (i, (nlj_line, hj_line)) in nlj_formatted_sorted @@ -497,12 +496,11 @@ impl JoinFuzzTestCase { } } - if join_test.contains(&JoinTest::HjSmj) { + if join_tests.contains(&JoinTestType::HjSmj) { let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size); assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str()); let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size); - // row level compare if any of joins returns the result // the reason is different formatting when there is no rows if smj_rows > 0 || hj_rows > 0 {