From 8e320225b8e97d0998071577aa41d7331e545683 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 26 Dec 2024 10:39:59 -0500 Subject: [PATCH] Move join type input swapping to pub methods on Joins --- datafusion/common/src/join_type.rs | 34 +++ .../src/physical_optimizer/join_selection.rs | 263 +++--------------- .../physical-plan/src/joins/cross_join.rs | 17 +- .../physical-plan/src/joins/hash_join.rs | 81 +++++- .../physical-plan/src/joins/join_filter.rs | 100 +++++++ datafusion/physical-plan/src/joins/mod.rs | 8 +- .../src/joins/nested_loop_join.rs | 39 ++- datafusion/physical-plan/src/joins/utils.rs | 142 ++++++---- 8 files changed, 397 insertions(+), 287 deletions(-) create mode 100644 datafusion/physical-plan/src/joins/join_filter.rs diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index bdca253c5f64..ac81d977b729 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -73,6 +73,40 @@ impl JoinType { pub fn is_outer(self) -> bool { self == JoinType::Left || self == JoinType::Right || self == JoinType::Full } + + /// Returns the `JoinType` if the (2) inputs were swapped + /// + /// Panics if [`Self::supports_swap`] returns false + pub fn swap(&self) -> JoinType { + match self { + JoinType::Inner => JoinType::Inner, + JoinType::Full => JoinType::Full, + JoinType::Left => JoinType::Right, + JoinType::Right => JoinType::Left, + JoinType::LeftSemi => JoinType::RightSemi, + JoinType::RightSemi => JoinType::LeftSemi, + JoinType::LeftAnti => JoinType::RightAnti, + JoinType::RightAnti => JoinType::LeftAnti, + JoinType::LeftMark => { + unreachable!("LeftMark join type does not support swapping") + } + } + } + + /// Does the join type support swapping inputs? 
+ pub fn supports_swap(&self) -> bool { + matches!( + self, + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + ) + } } impl Display for JoinType { diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 29c6e0078847..d7a2f1740141 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -32,15 +32,12 @@ use crate::physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, }; -use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; -use arrow_schema::Schema; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, JoinSide, JoinType}; use datafusion_expr::sort_properties::SortProperties; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::execution_plan::EmissionType; @@ -108,197 +105,49 @@ fn supports_collect_by_thresholds( } /// Predicate that checks whether the given join type supports input swapping. +#[deprecated(since = "45.0.0", note = "use JoinType::supports_swap instead")] +#[allow(dead_code)] pub(crate) fn supports_swap(join_type: JoinType) -> bool { - matches!( - join_type, - JoinType::Inner - | JoinType::Left - | JoinType::Right - | JoinType::Full - | JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - ) + join_type.supports_swap() } /// This function returns the new join type we get after swapping the given /// join's inputs. 
+#[deprecated(since = "45.0.0", note = "use JoinType::swap instead")] +#[allow(dead_code)] pub(crate) fn swap_join_type(join_type: JoinType) -> JoinType { - match join_type { - JoinType::Inner => JoinType::Inner, - JoinType::Full => JoinType::Full, - JoinType::Left => JoinType::Right, - JoinType::Right => JoinType::Left, - JoinType::LeftSemi => JoinType::RightSemi, - JoinType::RightSemi => JoinType::LeftSemi, - JoinType::LeftAnti => JoinType::RightAnti, - JoinType::RightAnti => JoinType::LeftAnti, - JoinType::LeftMark => { - unreachable!("LeftMark join type does not support swapping") - } - } -} - -/// This function swaps the given join's projection. -fn swap_join_projection( - left_schema_len: usize, - right_schema_len: usize, - projection: Option<&Vec<usize>>, - join_type: &JoinType, -) -> Option<Vec<usize>> { - match join_type { - // For Anti/Semi join types, projection should remain unmodified, - // since these joins output schema remains the same after swap - JoinType::LeftAnti - | JoinType::LeftSemi - | JoinType::RightAnti - | JoinType::RightSemi => projection.cloned(), - - _ => projection.map(|p| { - p.iter() - .map(|i| { - // If the index is less than the left schema length, it is from - // the left schema, so we add the right schema length to it. - // Otherwise, it is from the right schema, so we subtract the left - // schema length from it. - if *i < left_schema_len { - *i + right_schema_len - } else { - *i - left_schema_len - } - }) - .collect() - }), - } + join_type.swap() } /// This function swaps the inputs of the given join operator. /// This function is public so other downstream projects can use it /// to construct `HashJoinExec` with right side as the build side. 
+#[deprecated(since = "45.0.0", note = "use HashJoinExec::swap_inputs instead")] pub fn swap_hash_join( hash_join: &HashJoinExec, partition_mode: PartitionMode, ) -> Result> { - let left = hash_join.left(); - let right = hash_join.right(); - let new_join = HashJoinExec::try_new( - Arc::clone(right), - Arc::clone(left), - hash_join - .on() - .iter() - .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) - .collect(), - swap_join_filter(hash_join.filter()), - &swap_join_type(*hash_join.join_type()), - swap_join_projection( - left.schema().fields().len(), - right.schema().fields().len(), - hash_join.projection.as_ref(), - hash_join.join_type(), - ), - partition_mode, - hash_join.null_equals_null(), - )?; - // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again - if matches!( - hash_join.join_type(), - JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - ) || hash_join.projection.is_some() - { - Ok(Arc::new(new_join)) - } else { - // TODO avoid adding ProjectionExec again and again, only adding Final Projection - let proj = ProjectionExec::try_new( - swap_reverting_projection(&left.schema(), &right.schema()), - Arc::new(new_join), - )?; - Ok(Arc::new(proj)) - } + hash_join.swap_inputs(partition_mode) } /// Swaps inputs of `NestedLoopJoinExec` and wraps it into `ProjectionExec` is required +#[deprecated(since = "45.0.0", note = "use NestedLoopJoinExec::swap_inputs")] +#[allow(dead_code)] pub(crate) fn swap_nl_join(join: &NestedLoopJoinExec) -> Result> { - let new_filter = swap_join_filter(join.filter()); - let new_join_type = &swap_join_type(*join.join_type()); - - let new_join = NestedLoopJoinExec::try_new( - Arc::clone(join.right()), - Arc::clone(join.left()), - new_filter, - new_join_type, - )?; - - // For Semi/Anti joins, swap result will produce same output schema, - // no need to wrap them into additional projection - let plan: Arc = if 
matches!( - join.join_type(), - JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - ) { - Arc::new(new_join) - } else { - let projection = - swap_reverting_projection(&join.left().schema(), &join.right().schema()); - - Arc::new(ProjectionExec::try_new(projection, Arc::new(new_join))?) - }; - - Ok(plan) + join.swap_inputs() } -/// When the order of the join is changed by the optimizer, the columns in -/// the output should not be impacted. This function creates the expressions -/// that will allow to swap back the values from the original left as the first -/// columns and those on the right next. -pub(crate) fn swap_reverting_projection( - left_schema: &Schema, - right_schema: &Schema, -) -> Vec<(Arc, String)> { - let right_cols = right_schema.fields().iter().enumerate().map(|(i, f)| { - ( - Arc::new(Column::new(f.name(), i)) as Arc, - f.name().to_owned(), - ) - }); - let right_len = right_cols.len(); - let left_cols = left_schema.fields().iter().enumerate().map(|(i, f)| { - ( - Arc::new(Column::new(f.name(), right_len + i)) as Arc, - f.name().to_owned(), - ) - }); - - left_cols.chain(right_cols).collect() +/// Swaps join sides for filter column indices and produces new `JoinFilter` (if exists). 
+#[deprecated(since = "45.0.0", note = "use filter.map(JoinFilter::swap) instead")] +#[allow(dead_code)] +fn swap_join_filter(filter: Option<&JoinFilter>) -> Option { + filter.map(JoinFilter::swap) } -/// Swaps join sides for filter column indices and produces new JoinFilter +#[deprecated(since = "45.0.0", note = "use JoinFilter::swap instead")] +#[allow(dead_code)] pub(crate) fn swap_filter(filter: &JoinFilter) -> JoinFilter { - let column_indices = filter - .column_indices() - .iter() - .map(|idx| ColumnIndex { - index: idx.index, - side: idx.side.negate(), - }) - .collect(); - - JoinFilter::new( - Arc::clone(filter.expression()), - column_indices, - filter.schema().clone(), - ) -} - -/// Swaps join sides for filter column indices and produces new `JoinFilter` (if exists). -fn swap_join_filter(filter: Option<&JoinFilter>) -> Option { - filter.map(swap_filter) + filter.swap() } impl PhysicalOptimizerRule for JoinSelection { @@ -383,10 +232,10 @@ pub(crate) fn try_collect_left( match (left_can_collect, right_can_collect) { (true, true) => { - if supports_swap(*hash_join.join_type()) + if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? { - Ok(Some(swap_hash_join(hash_join, PartitionMode::CollectLeft)?)) + Ok(Some(hash_join.swap_inputs(PartitionMode::CollectLeft)?)) } else { Ok(Some(Arc::new(HashJoinExec::try_new( Arc::clone(left), @@ -411,8 +260,8 @@ pub(crate) fn try_collect_left( hash_join.null_equals_null(), )?))), (false, true) => { - if supports_swap(*hash_join.join_type()) { - swap_hash_join(hash_join, PartitionMode::CollectLeft).map(Some) + if hash_join.join_type().supports_swap() { + hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some) } else { Ok(None) } @@ -431,9 +280,9 @@ pub(crate) fn partitioned_hash_join( ) -> Result> { let left = hash_join.left(); let right = hash_join.right(); - if supports_swap(*hash_join.join_type()) && should_swap_join_order(&**left, &**right)? 
+ if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? { - swap_hash_join(hash_join, PartitionMode::Partitioned) + hash_join.swap_inputs(PartitionMode::Partitioned) } else { Ok(Arc::new(HashJoinExec::try_new( Arc::clone(left), @@ -476,10 +325,12 @@ fn statistical_join_selection_subrule( PartitionMode::Partitioned => { let left = hash_join.left(); let right = hash_join.right(); - if supports_swap(*hash_join.join_type()) + if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? { - swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)? + hash_join + .swap_inputs(PartitionMode::Partitioned) + .map(Some)? } else { None } @@ -489,23 +340,17 @@ fn statistical_join_selection_subrule( let left = cross_join.left(); let right = cross_join.right(); if should_swap_join_order(&**left, &**right)? { - let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); - // TODO avoid adding ProjectionExec again and again, only adding Final Projection - let proj: Arc = Arc::new(ProjectionExec::try_new( - swap_reverting_projection(&left.schema(), &right.schema()), - Arc::new(new_join), - )?); - Some(proj) + cross_join.swap_inputs().map(Some)? } else { None } } else if let Some(nl_join) = plan.as_any().downcast_ref::() { let left = nl_join.left(); let right = nl_join.right(); - if supports_swap(*nl_join.join_type()) + if nl_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? { - swap_nl_join(nl_join).map(Some)? + nl_join.swap_inputs().map(Some)? 
} else { None } @@ -718,10 +563,10 @@ fn swap_join_according_to_unboundedness( JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full, ) => internal_err!("{join_type} join cannot be swapped for unbounded input."), (PartitionMode::Partitioned, _) => { - swap_hash_join(hash_join, PartitionMode::Partitioned) + hash_join.swap_inputs(PartitionMode::Partitioned) } (PartitionMode::CollectLeft, _) => { - swap_hash_join(hash_join, PartitionMode::CollectLeft) + hash_join.swap_inputs(PartitionMode::CollectLeft) } (PartitionMode::Auto, _) => { internal_err!("Auto is not acceptable for unbounded input here.") @@ -751,12 +596,15 @@ mod tests_statistical { }; use arrow::datatypes::{DataType, Field}; + use arrow_schema::Schema; use datafusion_common::{stats::Precision, JoinType, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::expressions::BinaryExpr; use datafusion_physical_expr::PhysicalExprRef; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_plan::projection::ProjectionExec; use rstest::rstest; /// Return statistics for empty table @@ -1372,7 +1220,8 @@ mod tests_statistical { false, )?); - let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned) + let swapped = join + .swap_inputs(PartitionMode::Partitioned) .expect("swap_hash_join must support joins with projections"); let swapped_join = swapped.as_any().downcast_ref::().expect( "ProjectionExec won't be added above if HashJoinExec contains embedded projection", @@ -1384,32 +1233,6 @@ mod tests_statistical { Ok(()) } - #[tokio::test] - async fn test_swap_reverting_projection() { - let left_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ]); - - let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]); - - let proj = swap_reverting_projection(&left_schema, &right_schema); - - 
assert_eq!(proj.len(), 3); - - let (col, name) = &proj[0]; - assert_eq!(name, "a"); - assert_col_expr(col, "a", 1); - - let (col, name) = &proj[1]; - assert_eq!(name, "b"); - assert_col_expr(col, "b", 2); - - let (col, name) = &proj[2]; - assert_eq!(name, "c"); - assert_col_expr(col, "c", 0); - } - fn assert_col_expr(expr: &Arc, name: &str, index: usize) { let col = expr .as_any() @@ -1643,7 +1466,9 @@ mod hash_join_tests { use arrow::datatypes::{DataType, Field}; use arrow::record_batch::RecordBatch; + use arrow_schema::Schema; use datafusion_physical_expr::expressions::col; + use datafusion_physical_plan::projection::ProjectionExec; struct TestCase { case: String, @@ -1723,7 +1548,7 @@ mod hash_join_tests { initial_join_type: join_type, initial_mode: PartitionMode::CollectLeft, expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), + expected_join_type: join_type.swap(), expected_mode: PartitionMode::CollectLeft, expecting_swap: true, }); @@ -1766,7 +1591,7 @@ mod hash_join_tests { initial_join_type: join_type, initial_mode: PartitionMode::Partitioned, expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), + expected_join_type: join_type.swap(), expected_mode: PartitionMode::Partitioned, expecting_swap: true, }); @@ -1824,7 +1649,7 @@ mod hash_join_tests { initial_join_type: join_type, initial_mode: PartitionMode::Partitioned, expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), + expected_join_type: join_type.swap(), expected_mode: PartitionMode::Partitioned, expecting_swap: true, }); diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index b70eeb313b2a..69300fce7745 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -19,8 +19,8 @@ 
//! and producing batches in parallel for the right partitions use super::utils::{ - adjust_right_output_partitioning, BatchSplitter, BatchTransformer, - BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, + adjust_right_output_partitioning, reorder_output_after_swap, BatchSplitter, + BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, StatefulStreamResult, }; use crate::coalesce_partitions::CoalescePartitionsExec; @@ -168,6 +168,19 @@ impl CrossJoinExec { boundedness_from_children([left, right]), ) } + + /// Returns a new `ExecutionPlan` that computes the same join as this one, + /// with the left and right inputs swapped. A projection is added on top so + /// the output columns keep their original order. + pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> { + let new_join = + CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left)); + reorder_output_after_swap( + Arc::new(new_join), + &self.left.schema(), + &self.right.schema(), + ) + } } /// Asynchronously collect the result of the left child diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index dabe42ee43a2..a0fe0bd116ee 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use std::task::Poll; use std::{any::Any, vec}; -use super::utils::asymmetric_join_output_partitioning; +use super::utils::{asymmetric_join_output_partitioning, reorder_output_after_swap}; use super::{ utils::{OnceAsync, OnceFut}, PartitionMode, @@ -566,8 +566,87 @@ impl HashJoinExec { boundedness_from_children([left, right]), )) } + + /// Returns a new `ExecutionPlan` that computes the same join as this one, + /// with the left and right inputs swapped using the specified + /// `partition_mode`. + /// + /// # Notes: + /// + /// This function is public so other downstream projects can use it to + /// construct `HashJoinExec` with right side as the build side. 
+ pub fn swap_inputs( + &self, + partition_mode: PartitionMode, + ) -> Result> { + let left = self.left(); + let right = self.right(); + let new_join = HashJoinExec::try_new( + Arc::clone(right), + Arc::clone(left), + self.on() + .iter() + .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) + .collect(), + self.filter().map(JoinFilter::swap), + &self.join_type().swap(), + swap_join_projection( + left.schema().fields().len(), + right.schema().fields().len(), + self.projection.as_ref(), + self.join_type(), + ), + partition_mode, + self.null_equals_null(), + )?; + // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again + if matches!( + self.join_type(), + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + ) || self.projection.is_some() + { + Ok(Arc::new(new_join)) + } else { + reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema()) + } + } } +/// This function swaps the given join's projection. +fn swap_join_projection( + left_schema_len: usize, + right_schema_len: usize, + projection: Option<&Vec>, + join_type: &JoinType, +) -> Option> { + match join_type { + // For Anti/Semi join types, projection should remain unmodified, + // since these joins output schema remains the same after swap + JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::RightAnti + | JoinType::RightSemi => projection.cloned(), + + _ => projection.map(|p| { + p.iter() + .map(|i| { + // If the index is less than the left schema length, it is from + // the left schema, so we add the right schema length to it. + // Otherwise, it is from the right schema, so we subtract the left + // schema length from it. 
+ if *i < left_schema_len { + *i + right_schema_len + } else { + *i - left_schema_len + } + }) + .collect() + }), + } +} impl DisplayAs for HashJoinExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { diff --git a/datafusion/physical-plan/src/joins/join_filter.rs b/datafusion/physical-plan/src/joins/join_filter.rs new file mode 100644 index 000000000000..b99afd87c92a --- /dev/null +++ b/datafusion/physical-plan/src/joins/join_filter.rs @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::joins::utils::ColumnIndex; +use arrow_schema::Schema; +use datafusion_common::JoinSide; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +/// Filter applied before join output. Fields are crate-public to allow +/// downstream implementations to experiment with custom joins. 
+#[derive(Debug, Clone)] +pub struct JoinFilter { + /// Filter expression + pub(crate) expression: Arc, + /// Column indices required to construct intermediate batch for filtering + pub(crate) column_indices: Vec, + /// Physical schema of intermediate batch + pub(crate) schema: Schema, +} + +impl JoinFilter { + /// Creates new JoinFilter + pub fn new( + expression: Arc, + column_indices: Vec, + schema: Schema, + ) -> JoinFilter { + JoinFilter { + expression, + column_indices, + schema, + } + } + + /// Helper for building ColumnIndex vector from left and right indices + pub fn build_column_indices( + left_indices: Vec, + right_indices: Vec, + ) -> Vec { + left_indices + .into_iter() + .map(|i| ColumnIndex { + index: i, + side: JoinSide::Left, + }) + .chain(right_indices.into_iter().map(|i| ColumnIndex { + index: i, + side: JoinSide::Right, + })) + .collect() + } + + /// Filter expression + pub fn expression(&self) -> &Arc { + &self.expression + } + + /// Column indices for intermediate batch creation + pub fn column_indices(&self) -> &[ColumnIndex] { + &self.column_indices + } + + /// Intermediate batch schema + pub fn schema(&self) -> &Schema { + &self.schema + } + + /// Rewrites the join filter if the inputs to the join are rewritten + pub fn swap(&self) -> JoinFilter { + let column_indices = self + .column_indices() + .iter() + .map(|idx| ColumnIndex { + index: idx.index, + side: idx.side.negate(), + }) + .collect(); + + JoinFilter::new( + Arc::clone(self.expression()), + column_indices, + self.schema().clone(), + ) + } +} diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 6ddf19c51193..fa077d200833 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -31,18 +31,20 @@ mod stream_join_utils; mod symmetric_hash_join; pub mod utils; +mod join_filter; #[cfg(test)] pub mod test_utils; #[derive(Clone, Copy, Debug, PartialEq, Eq)] -/// Partitioning mode to use for 
hash join +/// Hash join Partitioning mode pub enum PartitionMode { /// Left/right children are partitioned using the left and right keys Partitioned, /// Left side will collected into one partition CollectLeft, - /// When set to Auto, DataFusion optimizer will decide which PartitionMode mode(Partitioned/CollectLeft) is optimal based on statistics. - /// It will also consider swapping the left and right inputs for the Join + /// DataFusion optimizer decides which `PartitionMode` + /// (`Partitioned`/`CollectLeft`) is optimal based on statistics. It will + /// also consider swapping the left and right inputs for the Join Auto, } diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 8caf5d9b5de1..c69fa2888806 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -24,8 +24,9 @@ use std::sync::Arc; use std::task::Poll; use super::utils::{ - asymmetric_join_output_partitioning, need_produce_result_in_final, BatchSplitter, - BatchTransformer, NoopBatchTransformer, StatefulStreamResult, + asymmetric_join_output_partitioning, need_produce_result_in_final, + reorder_output_after_swap, BatchSplitter, BatchTransformer, NoopBatchTransformer, + StatefulStreamResult, }; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::execution_plan::{boundedness_from_children, EmissionType}; @@ -296,6 +297,40 @@ impl NestedLoopJoinExec { ), ] } + + /// Returns a new `ExecutionPlan` that runs the nested loop join with the left + /// and right inputs swapped. 
+ pub fn swap_inputs(&self) -> Result> { + let new_filter = self.filter().map(JoinFilter::swap); + let new_join_type = &self.join_type().swap(); + + let new_join = NestedLoopJoinExec::try_new( + Arc::clone(self.right()), + Arc::clone(self.left()), + new_filter, + new_join_type, + )?; + + // For Semi/Anti joins, swap result will produce same output schema, + // no need to wrap them into additional projection + let plan: Arc = if matches!( + self.join_type(), + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + ) { + Arc::new(new_join) + } else { + reorder_output_after_swap( + Arc::new(new_join), + &self.left().schema(), + &self.right().schema(), + )? + }; + + Ok(plan) + } } impl DisplayAs for NestedLoopJoinExec { diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index d792e143046c..371949a32598 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -29,6 +29,8 @@ use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; use crate::{ ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics, }; +// compatibility +pub use super::join_filter::JoinFilter; use arrow::array::{ downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array, @@ -54,6 +56,7 @@ use datafusion_physical_expr::{ LexOrdering, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, }; +use crate::projection::ProjectionExec; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use hashbrown::raw::RawTable; @@ -549,66 +552,6 @@ pub struct ColumnIndex { pub side: JoinSide, } -/// Filter applied before join output. Fields are crate-public to allow -/// downstream implementations to experiment with custom joins. 
-#[derive(Debug, Clone)] -pub struct JoinFilter { - /// Filter expression - pub(crate) expression: Arc, - /// Column indices required to construct intermediate batch for filtering - pub(crate) column_indices: Vec, - /// Physical schema of intermediate batch - pub(crate) schema: Schema, -} - -impl JoinFilter { - /// Creates new JoinFilter - pub fn new( - expression: Arc, - column_indices: Vec, - schema: Schema, - ) -> JoinFilter { - JoinFilter { - expression, - column_indices, - schema, - } - } - - /// Helper for building ColumnIndex vector from left and right indices - pub fn build_column_indices( - left_indices: Vec, - right_indices: Vec, - ) -> Vec { - left_indices - .into_iter() - .map(|i| ColumnIndex { - index: i, - side: JoinSide::Left, - }) - .chain(right_indices.into_iter().map(|i| ColumnIndex { - index: i, - side: JoinSide::Right, - })) - .collect() - } - - /// Filter expression - pub fn expression(&self) -> &Arc { - &self.expression - } - - /// Column indices for intermediate batch creation - pub fn column_indices(&self) -> &[ColumnIndex] { - &self.column_indices - } - - /// Intermediate batch schema - pub fn schema(&self) -> &Schema { - &self.schema - } -} - /// Returns the output field given the input field. Outer joins may /// insert nulls even if the input was not null /// @@ -1788,6 +1731,50 @@ impl BatchTransformer for BatchSplitter { } } +/// When the order of the join inputs are changed, the output order of columns +/// must remain the same. +/// +/// Joins output columns from their left input followed by their right input. +/// Thus if the inputs are reordered, the output columns must be reordered to +/// match the original order. 
+pub(crate) fn reorder_output_after_swap( + plan: Arc, + left_schema: &Schema, + right_schema: &Schema, +) -> Result> { + let proj = ProjectionExec::try_new( + swap_reverting_projection(left_schema, right_schema), + plan, + )?; + Ok(Arc::new(proj)) +} + +/// When the order of the join is changed, the output order of columns must +/// remain the same. +/// +/// Returns the expressions that will allow to swap back the values from the +/// original left as the first columns and those on the right next. +fn swap_reverting_projection( + left_schema: &Schema, + right_schema: &Schema, +) -> Vec<(Arc, String)> { + let right_cols = right_schema.fields().iter().enumerate().map(|(i, f)| { + ( + Arc::new(Column::new(f.name(), i)) as Arc, + f.name().to_owned(), + ) + }); + let right_len = right_cols.len(); + let left_cols = left_schema.fields().iter().enumerate().map(|(i, f)| { + ( + Arc::new(Column::new(f.name(), right_len + i)) as Arc, + f.name().to_owned(), + ) + }); + + left_cols.chain(right_cols).collect() +} + #[cfg(test)] mod tests { use std::pin::Pin; @@ -2754,4 +2741,39 @@ mod tests { assert!(splitter.next().is_none()); assert_split_batches(batches, batch_size, num_rows); } + + #[tokio::test] + async fn test_swap_reverting_projection() { + let left_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]); + + let proj = swap_reverting_projection(&left_schema, &right_schema); + + assert_eq!(proj.len(), 3); + + let (col, name) = &proj[0]; + assert_eq!(name, "a"); + assert_col_expr(col, "a", 1); + + let (col, name) = &proj[1]; + assert_eq!(name, "b"); + assert_col_expr(col, "b", 2); + + let (col, name) = &proj[2]; + assert_eq!(name, "c"); + assert_col_expr(col, "c", 0); + } + + fn assert_col_expr(expr: &Arc, name: &str, index: usize) { + let col = expr + .as_any() + .downcast_ref::() + .expect("Projection items should be Column 
expression"); + assert_eq!(col.name(), name); + assert_eq!(col.index(), index); + } }