Skip to content

Commit

Permalink
Fix hash join with sort push down (apache#13560)
Browse files Browse the repository at this point in the history
* fix: join with sort push down

* chore:
insert some value

* apply suggestion

* recover handle_custom_pushdown change

* apply suggestion

* add more test

* add partition
  • Loading branch information
haohuaijin authored and zhuliquan committed Dec 11, 2024
1 parent bd91271 commit 45926ab
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 44 deletions.
101 changes: 101 additions & 0 deletions datafusion/core/src/physical_optimizer/sort_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::tree_node::PlanContext;
use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use arrow_schema::SchemaRef;

use datafusion_common::tree_node::{
ConcreteTreeNode, Transformed, TreeNode, TreeNodeRecursion,
Expand All @@ -38,6 +39,8 @@ use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::PhysicalSortRequirement;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
use datafusion_physical_plan::joins::utils::ColumnIndex;
use datafusion_physical_plan::joins::HashJoinExec;

/// This is a "data class" we use within the [`EnforceSorting`] rule to push
/// down [`SortExec`] in the plan. In some cases, we can reduce the total
Expand Down Expand Up @@ -294,6 +297,8 @@ fn pushdown_requirement_to_children(
.then(|| LexRequirement::new(parent_required.to_vec()));
Ok(Some(vec![req]))
}
} else if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
handle_hash_join(hash_join, parent_required)
} else {
handle_custom_pushdown(plan, parent_required, maintains_input_order)
}
Expand Down Expand Up @@ -606,6 +611,102 @@ fn handle_custom_pushdown(
}
}

/// Tries to push the parent's sort requirement through a [`HashJoinExec`]
/// into its right child.
///
/// For hash join we only maintain the input order for the right child, and
/// only for the join types `Inner`, `Right`, `RightSemi` and `RightAnti`
/// (this is what `plan.maintains_input_order()[1]` reports).
///
/// Returns `Ok(None)` when the requirement cannot be pushed down, i.e. when:
/// - the parent has no requirement,
/// - the join does not maintain the order of its right input, or
/// - any column referenced by the requirement is not produced by the right
///   side of the join.
///
/// Otherwise returns `Ok(Some(vec![None, Some(req)]))`, where `req` is the
/// parent requirement with every column rewritten to refer to the right
/// child's schema. The left child never receives a requirement.
fn handle_hash_join(
    plan: &HashJoinExec,
    parent_required: &LexRequirement,
) -> Result<Option<Vec<Option<LexRequirement>>>> {
    // Nothing to push down, or the right input's order is not preserved.
    if parent_required.is_empty() || !plan.maintains_input_order()[1] {
        return Ok(None);
    }

    // Collect all unique column indices used in the parent-required sorting
    // expressions. These indices refer to the join's (projected) output schema.
    let all_indices: HashSet<usize> = parent_required
        .iter()
        .flat_map(|order| collect_columns(&order.expr))
        .map(|col| col.index())
        .collect();

    // Map every output column of the join back to (side, index in that side's
    // schema), honoring the join's embedded projection when there is one.
    let column_indices = build_join_column_index(plan);
    let projected_indices: Vec<_> = if let Some(projection) = &plan.projection {
        projection.iter().map(|&i| &column_indices[i]).collect()
    } else {
        column_indices.iter().collect()
    };

    // Decide per column, by its recorded side, whether it comes from the right
    // child. Comparing the output position against the number of left-side
    // columns would only be valid if every left column preceded every right
    // column, which an embedded projection is not required to guarantee. An
    // index outside the output schema is treated as "not from the right side"
    // so that we bail out instead of panicking below.
    let all_from_right_child = all_indices.iter().all(|&i| {
        matches!(projected_indices.get(i), Some(ci) if ci.side == JoinSide::Right)
    });
    if !all_from_right_child {
        return Ok(None);
    }

    // The right child's schema does not depend on the requirement being
    // rewritten, so look it up once.
    let child_schema = plan.children()[1].schema();

    // Transform the parent-required expressions for the child schema by
    // replacing every column with the matching right-child column.
    let updated_parent_req = parent_required
        .iter()
        .map(|req| {
            let updated_columns = Arc::clone(&req.expr)
                .transform_up(|expr| {
                    if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                        // In bounds: every referenced index was validated above.
                        let index = projected_indices[col.index()].index;
                        Ok(Transformed::yes(Arc::new(Column::new(
                            child_schema.field(index).name(),
                            index,
                        ))))
                    } else {
                        Ok(Transformed::no(expr))
                    }
                })?
                .data;
            Ok(PhysicalSortRequirement::new(updated_columns, req.options))
        })
        .collect::<Result<Vec<_>>>()?;

    // No requirement for the left child; the rewritten one for the right child.
    Ok(Some(vec![
        None,
        Some(LexRequirement::new(updated_parent_req)),
    ]))
}

/// Builds the mapping from each column of the hash join's unprojected output
/// to the join side it comes from and its index within that side's schema.
///
/// It is used when pushing sort requirements down to the right child of the
/// join:
/// - `Inner` / `Right`: all left columns followed by all right columns.
/// - `RightSemi` / `RightAnti`: the right columns only.
///
/// # Panics
///
/// Panics (via `unreachable!`) for any other join type.
fn build_join_column_index(plan: &HashJoinExec) -> Vec<ColumnIndex> {
    // One `ColumnIndex` per field of `schema`, in field order, tagged with `side`.
    fn columns_of(schema: SchemaRef, side: JoinSide) -> Vec<ColumnIndex> {
        (0..schema.fields().len())
            .map(|index| ColumnIndex { index, side })
            .collect()
    }

    let join_type = plan.join_type();
    match join_type {
        JoinType::Inner | JoinType::Right => {
            let mut indices = columns_of(plan.left().schema(), JoinSide::Left);
            indices.extend(columns_of(plan.right().schema(), JoinSide::Right));
            indices
        }
        JoinType::RightSemi | JoinType::RightAnti => {
            columns_of(plan.right().schema(), JoinSide::Right)
        }
        _ => unreachable!("unexpected join type: {}", join_type),
    }
}

/// Define the Requirements Compatibility
#[derive(Debug)]
enum RequirementsCompatibility {
Expand Down
171 changes: 127 additions & 44 deletions datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2864,13 +2864,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand Down Expand Up @@ -2905,13 +2905,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand Down Expand Up @@ -2967,10 +2967,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand Down Expand Up @@ -3003,10 +3003,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand Down Expand Up @@ -3061,13 +3061,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand All @@ -3083,13 +3083,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand Down Expand Up @@ -3143,10 +3143,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand All @@ -3160,10 +3160,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand Down Expand Up @@ -4313,3 +4313,86 @@ physical_plan
04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)]
05)--------MemoryExec: partitions=1, partition_sizes=[1]
06)--------MemoryExec: partitions=1, partition_sizes=[1]

# Test hash join sort push down
# Issue: https://github.com/apache/datafusion/issues/13559
statement ok
CREATE TABLE test(a INT, b INT, c INT)

statement ok
insert into test values (1,2,3), (4,5,6), (null, 7, 8), (8, null, 9), (9, 10, null)

statement ok
set datafusion.execution.target_partitions = 2;

query TT
explain select * from test where a in (select a from test where b > 3) order by c desc nulls first;
----
logical_plan
01)Sort: test.c DESC NULLS FIRST
02)--LeftSemi Join: test.a = __correlated_sq_1.a
03)----TableScan: test projection=[a, b, c]
04)----SubqueryAlias: __correlated_sq_1
05)------Projection: test.a
06)--------Filter: test.b > Int32(3)
07)----------TableScan: test projection=[a, b]
physical_plan
01)SortPreservingMergeExec: [c@2 DESC]
02)--CoalesceBatchesExec: target_batch_size=3
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)]
04)------CoalesceBatchesExec: target_batch_size=3
05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
06)----------CoalesceBatchesExec: target_batch_size=3
07)------------FilterExec: b@1 > 3, projection=[a@0]
08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
09)----------------MemoryExec: partitions=1, partition_sizes=[1]
10)------SortExec: expr=[c@2 DESC], preserve_partitioning=[true]
11)--------CoalesceBatchesExec: target_batch_size=3
12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
14)--------------MemoryExec: partitions=1, partition_sizes=[1]

query TT
explain select * from test where a in (select a from test where b > 3) order by c desc nulls last;
----
logical_plan
01)Sort: test.c DESC NULLS LAST
02)--LeftSemi Join: test.a = __correlated_sq_1.a
03)----TableScan: test projection=[a, b, c]
04)----SubqueryAlias: __correlated_sq_1
05)------Projection: test.a
06)--------Filter: test.b > Int32(3)
07)----------TableScan: test projection=[a, b]
physical_plan
01)SortPreservingMergeExec: [c@2 DESC NULLS LAST]
02)--CoalesceBatchesExec: target_batch_size=3
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)]
04)------CoalesceBatchesExec: target_batch_size=3
05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
06)----------CoalesceBatchesExec: target_batch_size=3
07)------------FilterExec: b@1 > 3, projection=[a@0]
08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
09)----------------MemoryExec: partitions=1, partition_sizes=[1]
10)------SortExec: expr=[c@2 DESC NULLS LAST], preserve_partitioning=[true]
11)--------CoalesceBatchesExec: target_batch_size=3
12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
14)--------------MemoryExec: partitions=1, partition_sizes=[1]

query III
select * from test where a in (select a from test where b > 3) order by c desc nulls first;
----
9 10 NULL
4 5 6

query III
select * from test where a in (select a from test where b > 3) order by c desc nulls last;
----
4 5 6
9 10 NULL

statement ok
DROP TABLE test

statement ok
set datafusion.execution.target_partitions = 1;

0 comments on commit 45926ab

Please sign in to comment.