Skip to content

Commit

Permalink
tests: enable fuzz for filtered anti-semi NLJoin (#12360)
Browse files Browse the repository at this point in the history
* tests: enable fuzz for filtered anti-semi NLJoin

* tests: update filters in join fuzz tests

* tests: disable flaky tests for SortMergeJoin
  • Loading branch information
korowa authored Sep 9, 2024
1 parent 9bc39a0 commit 4569cbb
Showing 1 changed file with 109 additions and 83 deletions.
192 changes: 109 additions & 83 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::PhysicalExprRef;

use itertools::Itertools;
use rand::Rng;

use datafusion::common::JoinSide;
Expand Down Expand Up @@ -54,33 +55,6 @@ enum JoinTestType {
// because if both existing variants pass, that means SortMergeJoin and NestedLoopJoin also pass
HjSmj,
}
// Fuzz test for an unfiltered inner join (`JoinType::Inner`, filter `None`).
// Both inputs come from `make_staggered_batches(1000)`; the case is run for
// the `HjSmj` and `NljHj` comparison variants.
// NOTE(review): the meaning of the trailing `false` argument to `run_test`
// is not visible here — confirm against `JoinFuzzTestCase::run_test`.
#[tokio::test]
async fn test_inner_join_1k() {
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::Inner,
None,
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

/// Builds a [`JoinFilter`] for the predicate `a < 100`, reading `a` from the
/// left input.
///
/// The intermediate schema holds only `schema1`'s field `a`, which is mapped
/// to column index 0 of the left side. `_schema2` is unused.
///
/// Panics if `schema1` has no field named `a`.
fn less_than_100_join_filter(schema1: Arc<Schema>, _schema2: Arc<Schema>) -> JoinFilter {
    // Single-field schema the filter expression is evaluated against.
    let left_a = schema1.field_with_name("a").unwrap().to_owned();
    let filter_schema = Schema::new(vec![left_a]);

    // Position 0 of the filter schema is fed from left column 0.
    let filter_columns = vec![ColumnIndex {
        index: 0,
        side: JoinSide::Left,
    }];

    let column_a = Arc::new(Column::new("a", 0));
    let hundred = Arc::new(Literal::new(ScalarValue::from(100)));
    let predicate = Arc::new(BinaryExpr::new(column_a, Operator::Lt, hundred)) as _;

    JoinFilter::new(predicate, filter_columns, filter_schema)
}

fn col_lt_col_filter(schema1: Arc<Schema>, schema2: Arc<Schema>) -> JoinFilter {
let less_filter = Arc::new(BinaryExpr::new(
Expand Down Expand Up @@ -120,14 +94,14 @@ async fn test_inner_join_1k_filtered() {
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::Inner,
Some(Box::new(less_than_100_join_filter)),
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

#[tokio::test]
async fn test_inner_join_1k_smjoin() {
async fn test_inner_join_1k() {
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
Expand All @@ -151,14 +125,16 @@ async fn test_left_join_1k() {
}

#[tokio::test]
// flaky for HjSmj case
// https://github.com/apache/datafusion/issues/12359
async fn test_left_join_1k_filtered() {
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::Left,
Some(Box::new(less_than_100_join_filter)),
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[JoinTestType::NljHj], false)
.await
}

Expand All @@ -173,17 +149,18 @@ async fn test_right_join_1k() {
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}
// Add support for Right filtered joins
#[ignore]

#[tokio::test]
// flaky for HjSmj case
// https://github.com/apache/datafusion/issues/12359
async fn test_right_join_1k_filtered() {
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::Right,
Some(Box::new(less_than_100_join_filter)),
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[JoinTestType::NljHj], false)
.await
}

Expand All @@ -200,14 +177,16 @@ async fn test_full_join_1k() {
}

#[tokio::test]
// flaky for HjSmj case
// https://github.com/apache/datafusion/issues/12359
async fn test_full_join_1k_filtered() {
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::Full,
Some(Box::new(less_than_100_join_filter)),
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.run_test(&[JoinTestType::NljHj], false)
.await
}

Expand All @@ -225,15 +204,13 @@ async fn test_semi_join_1k() {

#[tokio::test]
async fn test_semi_join_1k_filtered() {
// NLJ vs HJ gives wrong result
// Tracked in https://github.com/apache/datafusion/issues/11537
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::LeftSemi,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -250,19 +227,16 @@ async fn test_anti_join_1k() {
}

#[tokio::test]
#[ignore]
// flaky test giving 1 row difference sometimes
// flaky for HjSmj case, giving 1 row difference sometimes
// https://github.com/apache/datafusion/issues/11555
async fn test_anti_join_1k_filtered() {
// NLJ vs HJ gives wrong result
// Tracked in https://github.com/apache/datafusion/issues/11537
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::LeftAnti,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj], false)
.run_test(&[JoinTestType::NljHj], false)
.await
}

Expand Down Expand Up @@ -292,27 +266,6 @@ impl JoinFuzzTestCase {
}
}

/// Returns the column indices used by the on-condition filter: columns 0 and
/// 1 of the left side followed by columns 0 and 1 of the right side.
fn column_indices(&self) -> Vec<ColumnIndex> {
    let positions = vec![
        (JoinSide::Left, 0),
        (JoinSide::Left, 1),
        (JoinSide::Right, 0),
        (JoinSide::Right, 1),
    ];

    positions
        .into_iter()
        .map(|(side, index)| ColumnIndex { index, side })
        .collect()
}

fn on_columns(&self) -> Vec<(PhysicalExprRef, PhysicalExprRef)> {
let schema1 = self.input1[0].schema();
let schema2 = self.input2[0].schema();
Expand All @@ -328,10 +281,20 @@ impl JoinFuzzTestCase {
]
}

/// Helper function for building NLJoin filter, returning intermediate
/// schema as a union of original filter intermediate schema and
/// on-condition schema
fn intermediate_schema(&self) -> Schema {
let filter_schema = if let Some(filter) = self.join_filter() {
filter.schema().to_owned()
} else {
Schema::empty()
};

let schema1 = self.input1[0].schema();
let schema2 = self.input2[0].schema();
Schema::new(vec![

let on_schema = Schema::new(vec![
schema1
.field_with_name("a")
.unwrap()
Expand All @@ -344,7 +307,81 @@ impl JoinFuzzTestCase {
.with_nullable(true),
schema2.field_with_name("a").unwrap().to_owned(),
schema2.field_with_name("b").unwrap().to_owned(),
])
]);

Schema::new(
filter_schema
.fields
.into_iter()
.cloned()
.chain(on_schema.fields.into_iter().cloned())
.collect_vec(),
)
}

/// Helper function for building NLJoin filter: combines the original filter
/// expression with the on-condition expression using `AND`.
///
/// When the test case has no join filter, a literal `true` stands in for the
/// filter expression. The on-condition is `a = a AND b = b`, and its column
/// positions are shifted by the number of fields in the original filter's
/// schema.
fn composite_filter_expression(&self) -> PhysicalExprRef {
    // Original predicate plus the number of leading columns it occupies.
    let (user_predicate, offset): (PhysicalExprRef, usize) = match self.join_filter() {
        Some(join_filter) => (
            join_filter.expression().to_owned(),
            join_filter.schema().fields().len(),
        ),
        None => (Arc::new(Literal::new(ScalarValue::from(true))) as _, 0),
    };

    // Builds `name@lhs = name@rhs` for a pair of column positions.
    let columns_equal = |name: &str, lhs: usize, rhs: usize| -> PhysicalExprRef {
        Arc::new(BinaryExpr::new(
            Arc::new(Column::new(name, lhs)),
            Operator::Eq,
            Arc::new(Column::new(name, rhs)),
        ))
    };

    let a_matches = columns_equal("a", offset, offset + 2);
    let b_matches = columns_equal("b", offset + 1, offset + 3);
    let on_predicate: PhysicalExprRef =
        Arc::new(BinaryExpr::new(a_matches, Operator::And, b_matches));

    Arc::new(BinaryExpr::new(user_predicate, Operator::And, on_predicate))
}

/// Helper function for building NLJoin filter: returns the original filter's
/// column indices (none when there is no filter) followed by the
/// on-condition column indices, i.e. columns 0 and 1 of the left side and
/// then columns 0 and 1 of the right side.
/// The result must line up with the intermediate schema.
fn column_indices(&self) -> Vec<ColumnIndex> {
    let mut combined = self
        .join_filter()
        .map(|join_filter| join_filter.column_indices().to_vec())
        .unwrap_or_default();

    let on_positions = vec![
        (JoinSide::Left, 0),
        (JoinSide::Left, 1),
        (JoinSide::Right, 0),
        (JoinSide::Right, 1),
    ];
    combined.extend(
        on_positions
            .into_iter()
            .map(|(side, index)| ColumnIndex { index, side }),
    );

    combined
}

fn left_right(&self) -> (Arc<MemoryExec>, Arc<MemoryExec>) {
Expand Down Expand Up @@ -400,26 +437,15 @@ impl JoinFuzzTestCase {

fn nested_loop_join(&self) -> Arc<NestedLoopJoinExec> {
let (left, right) = self.left_right();
// Nested loop join uses filter for joining records

let column_indices = self.column_indices();
let intermediate_schema = self.intermediate_schema();
let expression = self.composite_filter_expression();

let equal_a = Arc::new(BinaryExpr::new(
Arc::new(Column::new("a", 0)),
Operator::Eq,
Arc::new(Column::new("a", 2)),
)) as _;
let equal_b = Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Operator::Eq,
Arc::new(Column::new("b", 3)),
)) as _;
let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _;

let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema);
let filter = JoinFilter::new(expression, column_indices, intermediate_schema);

Arc::new(
NestedLoopJoinExec::try_new(left, right, Some(on_filter), &self.join_type)
NestedLoopJoinExec::try_new(left, right, Some(filter), &self.join_type)
.unwrap(),
)
}
Expand Down

0 comments on commit 4569cbb

Please sign in to comment.