diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index dfaa7dbb8910..1c63df1f0281 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -140,20 +140,32 @@ fn swap_join_projection( left_schema_len: usize, right_schema_len: usize, projection: Option<&Vec>, + join_type: &JoinType, ) -> Option> { - projection.map(|p| { - p.iter() - .map(|i| { - // If the index is less than the left schema length, it is from the left schema, so we add the right schema length to it. - // Otherwise, it is from the right schema, so we subtract the left schema length from it. - if *i < left_schema_len { - *i + right_schema_len - } else { - *i - left_schema_len - } - }) - .collect() - }) + match join_type { + // For Anti/Semi join types, projection should remain unmodified, + // since these joins output schema remains the same after swap + JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::RightAnti + | JoinType::RightSemi => projection.cloned(), + + _ => projection.map(|p| { + p.iter() + .map(|i| { + // If the index is less than the left schema length, it is from + // the left schema, so we add the right schema length to it. + // Otherwise, it is from the right schema, so we subtract the left + // schema length from it. + if *i < left_schema_len { + *i + right_schema_len + } else { + *i - left_schema_len + } + }) + .collect() + }), + } } /// This function swaps the inputs of the given join operator. @@ -179,6 +191,7 @@ pub fn swap_hash_join( left.schema().fields().len(), right.schema().fields().len(), hash_join.projection.as_ref(), + hash_join.join_type(), ), partition_mode, hash_join.null_equals_null(), @@ -1289,27 +1302,59 @@ mod tests_statistical { ); } + #[rstest( + join_type, projection, small_on_right, + case::inner(JoinType::Inner, vec![1], true), + case::left(JoinType::Left, vec![1], true), + case::right(JoinType::Right, vec![1], true), + case::full(JoinType::Full, vec![1], true), + case::left_anti(JoinType::LeftAnti, vec![0], false), + case::left_semi(JoinType::LeftSemi, vec![0], false), + case::right_anti(JoinType::RightAnti, vec![0], true), + case::right_semi(JoinType::RightSemi, vec![0], true), + )] #[tokio::test] - async fn test_hash_join_swap_on_joins_with_projections() -> Result<()> { + async fn test_hash_join_swap_on_joins_with_projections( + join_type: JoinType, + projection: Vec, + small_on_right: bool, + ) -> Result<()> { let (big, small) = create_big_and_small(); + + let left = if small_on_right { &big } else { &small }; + let right = if small_on_right { &small } else { &big }; + + let left_on = if small_on_right { + "big_col" + } else { + "small_col" + }; + let right_on = if small_on_right { + "small_col" + } else { + "big_col" + }; + let join = Arc::new(HashJoinExec::try_new( - Arc::clone(&big), - Arc::clone(&small), + Arc::clone(left), + Arc::clone(right), vec![( - Arc::new(Column::new_with_schema("big_col", &big.schema())?), - Arc::new(Column::new_with_schema("small_col", &small.schema())?), + Arc::new(Column::new_with_schema(left_on, &left.schema())?), + Arc::new(Column::new_with_schema(right_on, &right.schema())?), )], None, - &JoinType::Inner, - Some(vec![1]), + &join_type, + Some(projection), PartitionMode::Partitioned, false, )?); + let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned) .expect("swap_hash_join must support joins with projections"); let swapped_join = swapped.as_any().downcast_ref::().expect( "ProjectionExec won't be added above if HashJoinExec contains embedded projection", ); + assert_eq!(swapped_join.projection, Some(vec![0_usize])); assert_eq!(swapped.schema().fields.len(), 1); assert_eq!(swapped.schema().fields[0].name(), "small_col");