fix: join swap for projected semi/anti joins #13022

Merged 1 commit on Oct 21, 2024
85 changes: 65 additions & 20 deletions datafusion/core/src/physical_optimizer/join_selection.rs
@@ -140,20 +140,32 @@ fn swap_join_projection(
     left_schema_len: usize,
     right_schema_len: usize,
     projection: Option<&Vec<usize>>,
+    join_type: &JoinType,
 ) -> Option<Vec<usize>> {
-    projection.map(|p| {
-        p.iter()
-            .map(|i| {
-                // If the index is less than the left schema length, it is from the left schema, so we add the right schema length to it.
-                // Otherwise, it is from the right schema, so we subtract the left schema length from it.
-                if *i < left_schema_len {
-                    *i + right_schema_len
-                } else {
-                    *i - left_schema_len
-                }
-            })
-            .collect()
-    })
+    match join_type {
Contributor

👍

+        // For Anti/Semi join types, projection should remain unmodified,
+        // since these joins output schema remains the same after swap
+        JoinType::LeftAnti
+        | JoinType::LeftSemi
+        | JoinType::RightAnti
+        | JoinType::RightSemi => projection.cloned(),
+
+        _ => projection.map(|p| {
+            p.iter()
+                .map(|i| {
+                    // If the index is less than the left schema length, it is from
+                    // the left schema, so we add the right schema length to it.
+                    // Otherwise, it is from the right schema, so we subtract the left
+                    // schema length from it.
+                    if *i < left_schema_len {
+                        *i + right_schema_len
+                    } else {
+                        *i - left_schema_len
+                    }
+                })
+                .collect()
+        }),
+    }
 }

 /// This function swaps the inputs of the given join operator.
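
For readers who want to try the remapping rule outside DataFusion, here is a minimal standalone sketch of the logic in `swap_join_projection`. It is an illustration under stated assumptions, not the project's code: `JoinType` is a local stand-in enum that only mirrors the variant names, and `swap_projection` is a hypothetical function name.

```rust
/// Local stand-in for DataFusion's `JoinType` (assumption: only the
/// variant names are mirrored; this is not the real type).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Full,
    LeftSemi,
    LeftAnti,
    RightSemi,
    RightAnti,
}

/// Hypothetical helper: remap an embedded projection after the two join
/// inputs have traded places.
pub fn swap_projection(
    left_len: usize,
    right_len: usize,
    projection: Option<&[usize]>,
    join_type: JoinType,
) -> Option<Vec<usize>> {
    match join_type {
        // Semi/anti joins emit columns from a single input, so the
        // indices stay valid after the swap and are returned unchanged.
        JoinType::LeftSemi
        | JoinType::LeftAnti
        | JoinType::RightSemi
        | JoinType::RightAnti => projection.map(|p| p.to_vec()),

        // Other joins emit left columns followed by right columns, so a
        // left index moves past the right block and a right index moves
        // to the front.
        _ => projection.map(|p| {
            p.iter()
                .map(|&i| if i < left_len { i + right_len } else { i - left_len })
                .collect()
        }),
    }
}
```

For example, with a 2-column left input and a 3-column right input, `swap_projection(2, 3, Some(&[0, 4][..]), JoinType::Inner)` gives `Some(vec![3, 2])`, while any semi/anti join type returns the projection unchanged.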
@@ -179,6 +191,7 @@ pub fn swap_hash_join(
             left.schema().fields().len(),
             right.schema().fields().len(),
             hash_join.projection.as_ref(),
+            hash_join.join_type(),
         ),
         partition_mode,
         hash_join.null_equals_null(),
@@ -1289,27 +1302,59 @@ mod tests_statistical {
         );
     }

+    #[rstest(
+        join_type, projection, small_on_right,
+        case::inner(JoinType::Inner, vec![1], true),
+        case::left(JoinType::Left, vec![1], true),
+        case::right(JoinType::Right, vec![1], true),
+        case::full(JoinType::Full, vec![1], true),
+        case::left_anti(JoinType::LeftAnti, vec![0], false),
Contributor

I think it would be better to add a swap test for LeftAnti & LeftSemi?

Contributor Author

Do you mean a separate one? Initially I was going to do so, but then realized it would result in three almost equivalent tests ("normal" joins / left semi-anti / right semi-anti) with only minor differences, so I decided it would be better to parameterize the test code to highlight those differences and reduce code duplication.

Contributor

Since small_on_right is false for LeftAnti & LeftSemi, the swap rule doesn't apply to those tests. But it doesn't matter.

+        case::left_semi(JoinType::LeftSemi, vec![0], false),
+        case::right_anti(JoinType::RightAnti, vec![0], true),
+        case::right_semi(JoinType::RightSemi, vec![0], true),
+    )]
     #[tokio::test]
-    async fn test_hash_join_swap_on_joins_with_projections() -> Result<()> {
+    async fn test_hash_join_swap_on_joins_with_projections(
+        join_type: JoinType,
+        projection: Vec<usize>,
+        small_on_right: bool,
+    ) -> Result<()> {
         let (big, small) = create_big_and_small();

+        let left = if small_on_right { &big } else { &small };
+        let right = if small_on_right { &small } else { &big };

+        let left_on = if small_on_right {
+            "big_col"
+        } else {
+            "small_col"
+        };
+        let right_on = if small_on_right {
+            "small_col"
+        } else {
+            "big_col"
+        };

         let join = Arc::new(HashJoinExec::try_new(
-            Arc::clone(&big),
-            Arc::clone(&small),
+            Arc::clone(left),
+            Arc::clone(right),
             vec![(
-                Arc::new(Column::new_with_schema("big_col", &big.schema())?),
-                Arc::new(Column::new_with_schema("small_col", &small.schema())?),
+                Arc::new(Column::new_with_schema(left_on, &left.schema())?),
+                Arc::new(Column::new_with_schema(right_on, &right.schema())?),
             )],
             None,
-            &JoinType::Inner,
-            Some(vec![1]),
+            &join_type,
+            Some(projection),
             PartitionMode::Partitioned,
             false,
         )?);

         let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned)
             .expect("swap_hash_join must support joins with projections");
         let swapped_join = swapped.as_any().downcast_ref::<HashJoinExec>().expect(
             "ProjectionExec won't be added above if HashJoinExec contains embedded projection",
         );

         assert_eq!(swapped_join.projection, Some(vec![0_usize]));
         assert_eq!(swapped.schema().fields.len(), 1);
         assert_eq!(swapped.schema().fields[0].name(), "small_col");
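
The expected values in this test (`Some(vec![0])` and a single `small_col` field) can be checked by hand with a small model of join output schemas. The sketch below is illustrative only and makes several assumptions: `JoinType` is a local stand-in enum, `swapped_type`, `output_columns`, `projected`, and `projected_after_swap` are hypothetical helpers, and each input is assumed to have a single column named as in the test.

```rust
/// Local stand-in for the join types used in the test (not DataFusion's type).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum JoinType {
    Inner,
    LeftSemi,
    LeftAnti,
    RightSemi,
    RightAnti,
}

/// Join type after the inputs trade places: left/right variants are mirrored.
pub fn swapped_type(jt: JoinType) -> JoinType {
    match jt {
        JoinType::Inner => JoinType::Inner,
        JoinType::LeftSemi => JoinType::RightSemi,
        JoinType::LeftAnti => JoinType::RightAnti,
        JoinType::RightSemi => JoinType::LeftSemi,
        JoinType::RightAnti => JoinType::LeftAnti,
    }
}

/// Unprojected output columns: semi/anti joins keep one side only,
/// an inner join concatenates left and right.
pub fn output_columns<'a>(jt: JoinType, left: &[&'a str], right: &[&'a str]) -> Vec<&'a str> {
    match jt {
        JoinType::LeftSemi | JoinType::LeftAnti => left.to_vec(),
        JoinType::RightSemi | JoinType::RightAnti => right.to_vec(),
        JoinType::Inner => left.iter().chain(right.iter()).copied().collect(),
    }
}

/// Columns selected by `projection` before the swap.
pub fn projected<'a>(
    jt: JoinType,
    left: &[&'a str],
    right: &[&'a str],
    projection: &[usize],
) -> Vec<&'a str> {
    let cols = output_columns(jt, left, right);
    projection.iter().map(|&i| cols[i]).collect()
}

/// Columns selected by `swapped_projection` once the inputs have traded
/// places and the join type has been mirrored.
pub fn projected_after_swap<'a>(
    jt: JoinType,
    left: &[&'a str],
    right: &[&'a str],
    swapped_projection: &[usize],
) -> Vec<&'a str> {
    let cols = output_columns(swapped_type(jt), right, left);
    swapped_projection.iter().map(|&i| cols[i]).collect()
}
```

In this model, a `RightSemi` join with `big` on the left and `small` on the right selects `small_col` with projection `[0]` both before and after the swap, whereas an `Inner` join needs `[1]` before the swap and `[0]` after it to select the same column.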