Skip to content

Commit

Permalink
Minor: Add routine to debug join fuzz tests
Browse files Browse the repository at this point in the history
  • Loading branch information
comphead committed Jun 21, 2024
1 parent 8285812 commit fa911c1
Showing 1 changed file with 23 additions and 25 deletions.
48 changes: 23 additions & 25 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use test_utils::stagger_batch_with_seed;
// Ideally all tests should match, but in reality some tests
// passes only partial cases
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum JoinTest {
enum JoinTestType {
// compare NestedLoopJoin and HashJoin
NljHj,
// compare HashJoin and SortMergeJoin, no need to compare SortMergeJoin and NestedLoopJoin
Expand All @@ -62,7 +62,7 @@ async fn test_inner_join_1k() {
JoinType::Inner,
None,
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand Down Expand Up @@ -114,7 +114,7 @@ async fn test_inner_join_1k_filtered() {
JoinType::Inner,
Some(Box::new(less_than_100_join_filter)),
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -126,7 +126,7 @@ async fn test_inner_join_1k_smjoin() {
JoinType::Inner,
None,
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -138,7 +138,7 @@ async fn test_left_join_1k() {
JoinType::Left,
None,
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -150,7 +150,7 @@ async fn test_left_join_1k_filtered() {
JoinType::Left,
Some(Box::new(less_than_100_join_filter)),
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -162,7 +162,7 @@ async fn test_right_join_1k() {
JoinType::Right,
None,
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}
// Add support for Right filtered joins
Expand All @@ -175,7 +175,7 @@ async fn test_right_join_1k_filtered() {
JoinType::Right,
Some(Box::new(less_than_100_join_filter)),
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -187,7 +187,7 @@ async fn test_full_join_1k() {
JoinType::Full,
None,
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -199,7 +199,7 @@ async fn test_full_join_1k_filtered() {
JoinType::Full,
Some(Box::new(less_than_100_join_filter)),
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -211,7 +211,7 @@ async fn test_semi_join_1k() {
JoinType::LeftSemi,
None,
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -227,7 +227,7 @@ async fn test_semi_join_1k_filtered() {
JoinType::LeftSemi,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTest::HjSmj], false)
.run_test(&[JoinTestType::HjSmj], false)
.await
}

Expand All @@ -239,7 +239,7 @@ async fn test_anti_join_1k() {
JoinType::LeftAnti,
None,
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -253,7 +253,7 @@ async fn test_anti_join_1k_filtered() {
JoinType::LeftAnti,
Some(Box::new(less_than_100_join_filter)),
)
.run_test(&[JoinTest::HjSmj, JoinTest::NljHj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand Down Expand Up @@ -417,9 +417,11 @@ impl JoinFuzzTestCase {
)
}

/// Perform sort-merge join and hash join on same input
/// and verify two outputs are equal
async fn run_test(&self, join_test: &[JoinTest], debug: bool) {
/// Perform joins tests on same inputs and verify outputs are equal
/// `join_tests` - identifies what join types to test
/// if `debug` flag is set the test will save randomly generated inputs and outputs to user folders,
/// so it is easy to debug a test on top of the failed data
async fn run_test(&self, join_tests: &[JoinTestType], debug: bool) {
for batch_size in self.batch_sizes {
let session_config = SessionConfig::new().with_batch_size(*batch_size);
let ctx = SessionContext::new_with_config(session_config);
Expand All @@ -438,19 +440,17 @@ impl JoinFuzzTestCase {
let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows());

// if debug flag is set the test will save randomly generated inputs and outputs to user folders
// so it is easy to run debug on top of the failed test data
if debug {
let out_dir_name = &format!("fuzz_test_debug_batch_size_{batch_size}");
Self::save_as_parquet(&self.input1, out_dir_name, "input1");
Self::save_as_parquet(&self.input2, out_dir_name, "input2");

if join_test.contains(&JoinTest::NljHj) {
if join_tests.contains(&JoinTestType::NljHj) {
Self::save_as_parquet(&nlj_collected, out_dir_name, "nlj");
Self::save_as_parquet(&hj_collected, out_dir_name, "hj");
}

if join_test.contains(&JoinTest::HjSmj) {
if join_tests.contains(&JoinTestType::HjSmj) {
Self::save_as_parquet(&hj_collected, out_dir_name, "hj");
Self::save_as_parquet(&smj_collected, out_dir_name, "smj");
}
Expand All @@ -475,12 +475,11 @@ impl JoinFuzzTestCase {
nlj_formatted.trim().lines().collect();
nlj_formatted_sorted.sort_unstable();

if join_test.contains(&JoinTest::NljHj) {
if join_tests.contains(&JoinTestType::NljHj) {
let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size);
assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str());

let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}", batch_size);

// row level compare if any of joins returns the result
// the reason is different formatting when there is no rows
for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
Expand All @@ -497,12 +496,11 @@ impl JoinFuzzTestCase {
}
}

if join_test.contains(&JoinTest::HjSmj) {
if join_tests.contains(&JoinTestType::HjSmj) {
let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size);
assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str());

let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size);

// row level compare if any of joins returns the result
// the reason is different formatting when there is no rows
if smj_rows > 0 || hj_rows > 0 {
Expand Down

0 comments on commit fa911c1

Please sign in to comment.