Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hash join with sort push down #13560

Merged
merged 12 commits into from
Dec 9, 2024
116 changes: 116 additions & 0 deletions datafusion/core/src/physical_optimizer/sort_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::PhysicalSortRequirement;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
use datafusion_physical_plan::joins::utils::ColumnIndex;
use datafusion_physical_plan::joins::HashJoinExec;

/// This is a "data class" we use within the [`EnforceSorting`] rule to push
/// down [`SortExec`] in the plan. In some cases, we can reduce the total
Expand Down Expand Up @@ -294,6 +296,8 @@ fn pushdown_requirement_to_children(
.then(|| LexRequirement::new(parent_required.to_vec()));
Ok(Some(vec![req]))
}
} else if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
handle_hash_join(hash_join, parent_required)
} else {
handle_custom_pushdown(plan, parent_required, maintains_input_order)
}
Expand Down Expand Up @@ -606,6 +610,118 @@ fn handle_custom_pushdown(
}
}

// For hash join we only maintain the input order for the right child
// for join type: Inner, Right, RightSemi, RightAnti
fn handle_hash_join(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be great eventually to get this kind of operator specific logic into the operators (e.g. some method in HashJoinExec). Definitely not in this PR, but having assumptions about the operator separate from its implementation gives us a larger chance of introducing inconsistencies I think

plan: &HashJoinExec,
parent_required: &LexRequirement,
) -> Result<Option<Vec<Option<LexRequirement>>>> {
// If there's no requirement from the parent or the plan has no children
// or the join type is not Inner, Right, RightSemi, RightAnti, return early
if parent_required.is_empty()
|| plan.children().is_empty()
haohuaijin marked this conversation as resolved.
Show resolved Hide resolved
|| !matches!(
plan.join_type(),
haohuaijin marked this conversation as resolved.
Show resolved Hide resolved
JoinType::Inner | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti
)
{
return Ok(None);
}

// Collect all unique column indices used in the parent-required sorting expression
let all_indices: HashSet<usize> = parent_required
.iter()
.flat_map(|order| {
collect_columns(&order.expr)
.iter()
.map(|col| col.index())
.collect::<HashSet<_>>()
})
.collect();
haohuaijin marked this conversation as resolved.
Show resolved Hide resolved

let column_indices = build_join_column_index(plan);
let projected_indices: Vec<_> = if let Some(projection) = &plan.projection {
projection.iter().map(|&i| &column_indices[i]).collect()
} else {
column_indices.iter().collect()
};
let len_of_left_fields = projected_indices
.iter()
.filter(|ci| ci.side == JoinSide::Left)
.count();

let all_from_right_child = all_indices.iter().all(|i| *i >= len_of_left_fields);

// If all columns are from the right child, update the parent requirements
if all_from_right_child {
// Transform the parent-required expression for the child schema by adjusting columns
let updated_parent_req = parent_required
.iter()
.map(|req| {
let child_schema = plan.children()[1].schema();
let updated_columns = Arc::clone(&req.expr)
.transform_up(|expr| {
if let Some(col) = expr.as_any().downcast_ref::<Column>() {
let index = projected_indices[col.index()].index;
Ok(Transformed::yes(Arc::new(Column::new(
child_schema.field(index).name(),
index,
))))
} else {
Ok(Transformed::no(expr))
}
})?
.data;
Ok(PhysicalSortRequirement::new(updated_columns, req.options))
})
.collect::<Result<Vec<_>>>()?;

// Populating with the updated requirements for children that maintain order
Ok(Some(vec![
None,
Some(LexRequirement::new(updated_parent_req)),
]))
} else {
Ok(None)
}
}

/// Builds the (unprojected) output-column mapping of a hash join, used when
/// pushing sort requirements down to the right child.
///
/// Each entry records which child (`JoinSide`) an output column comes from and
/// its index within that child's schema. For `Inner` and `Right` joins the
/// left columns are listed first, followed by the right columns; for
/// `RightSemi` and `RightAnti` joins only the right columns are listed.
///
/// # Panics
///
/// Panics for any other join type; callers are expected to have filtered
/// those out beforehand.
fn build_join_column_index(plan: &HashJoinExec) -> Vec<ColumnIndex> {
    let left_schema = plan.left().schema();
    let right_schema = plan.right().schema();

    // Lazy iterators: one `ColumnIndex` per field position on each side.
    let from_left = (0..left_schema.fields().len()).map(|index| ColumnIndex {
        index,
        side: JoinSide::Left,
    });
    let from_right = (0..right_schema.fields().len()).map(|index| ColumnIndex {
        index,
        side: JoinSide::Right,
    });

    match plan.join_type() {
        JoinType::RightSemi | JoinType::RightAnti => from_right.collect(),
        JoinType::Inner | JoinType::Right => from_left.chain(from_right).collect(),
        _ => unreachable!("unexpected join type: {}", plan.join_type()),
    }
}

/// Define the Requirements Compatibility
#[derive(Debug)]
enum RequirementsCompatibility {
Expand Down
114 changes: 70 additions & 44 deletions datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2864,13 +2864,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand Down Expand Up @@ -2905,13 +2905,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand Down Expand Up @@ -2967,10 +2967,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand Down Expand Up @@ -3003,10 +3003,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand Down Expand Up @@ -3061,13 +3061,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand All @@ -3083,13 +3083,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand Down Expand Up @@ -3143,10 +3143,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand All @@ -3160,10 +3160,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These plans imply that the sort has been pushed into the second (probe) input which makes sense I think : https://docs.rs/datafusion/latest/datafusion/physical_plan/joins/struct.HashJoinExec.html

03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand Down Expand Up @@ -4313,3 +4313,29 @@ physical_plan
04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)]
05)--------MemoryExec: partitions=1, partition_sizes=[1]
06)--------MemoryExec: partitions=1, partition_sizes=[1]

statement ok
CREATE TABLE test(a INT, b INT, c INT)

statement ok
insert into test values (1,2,3), (4,5,6)

query TT
explain select * from test where a not in (select a from test where b > 3) order by c desc;
----
logical_plan
01)Sort: test.c DESC NULLS FIRST
02)--LeftAnti Join: test.a = __correlated_sq_1.a
03)----TableScan: test projection=[a, b, c]
04)----SubqueryAlias: __correlated_sq_1
05)------Projection: test.a
06)--------Filter: test.b > Int32(3)
07)----------TableScan: test projection=[a, b]
physical_plan
01)CoalesceBatchesExec: target_batch_size=3
02)--HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(a@0, a@0)]
03)----CoalesceBatchesExec: target_batch_size=3
04)------FilterExec: b@1 > 3, projection=[a@0]
05)--------MemoryExec: partitions=1, partition_sizes=[1]
06)----SortExec: expr=[c@2 DESC], preserve_partitioning=[false]
07)------MemoryExec: partitions=1, partition_sizes=[1]
Loading