diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index cf1742a30e663..01dedbb0fa7c9 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -186,7 +186,7 @@ async fn test_full_join_1k_filtered() { } #[tokio::test] -async fn test_semi_join_1k() { +async fn test_left_semi_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), @@ -198,7 +198,7 @@ async fn test_semi_join_1k() { } #[tokio::test] -async fn test_semi_join_1k_filtered() { +async fn test_left_semi_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), @@ -209,6 +209,30 @@ async fn test_semi_join_1k_filtered() { .await } +#[tokio::test] +async fn test_right_semi_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightSemi, + None, + ) + .run_test(&[HjSmj, NljHj], true) + .await +} + +#[tokio::test] +async fn test_right_semi_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], true) + .await +} + #[tokio::test] async fn test_anti_join_1k() { JoinFuzzTestCase::new( diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 5b1a296658687..141e98f03ff92 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -115,12 +115,6 @@ impl SortMergeJoinExec { let left_schema = left.schema(); let right_schema = right.schema(); - if join_type == JoinType::RightSemi { - return not_impl_err!( - "SortMergeJoinExec does not support JoinType::RightSemi" - ); - } - check_join_is_valid(&left_schema, &right_schema, &on)?; if sort_options.len() != on.len() { return plan_err!( @@ -148,6 +142,7 @@ impl SortMergeJoinExec { let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); + let cache = Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on); Ok(Self { @@ -732,7 +727,7 @@ struct JoinedRecordBatches { pub batches: Vec, /// Filter match mask for each row(matched/non-matched) pub filter_mask: BooleanBuilder, - /// Row indices to glue together rows in `batches` and `filter_mask` + /// Streamed row indices to glue together rows in `batches` and `filter_mask` pub row_indices: UInt64Builder, /// Which unique batch id the row belongs to /// It is necessary to differentiate rows that are distributed the way when they point to the same @@ -834,7 +829,7 @@ fn get_corrected_filter_mask( corrected_mask.extend(vec![Some(false); null_matched]); Some(corrected_mask.finish()) } - JoinType::LeftSemi => { + JoinType::LeftSemi | JoinType::RightSemi => { for i in 0..row_indices_length { let last_index = last_index_for_row(i, row_indices, batch_ids, row_indices_length); @@ -963,6 +958,7 @@ impl Stream for SortMergeJoinStream { | JoinType::LeftSemi | JoinType::LeftMark | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::Full ) @@ -1045,6 +1041,7 @@ impl Stream for SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::LeftMark | JoinType::Full @@ -1068,6 +1065,7 @@ impl Stream for SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::Full | JoinType::LeftMark @@ -1382,7 +1380,6 @@ impl SortMergeJoinStream { self.join_type, JoinType::Left | JoinType::Right - | JoinType::RightSemi | JoinType::Full | JoinType::LeftAnti | JoinType::LeftMark @@ -1391,7 +1388,10 @@ impl SortMergeJoinStream { } } Ordering::Equal => { - if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftMark) { + if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftMark | JoinType::RightSemi + ) { mark_row_as_match = matches!(self.join_type, JoinType::LeftMark); // if the join filter is specified then its needed to output the streamed index // only if it has not been emitted before @@ -1589,7 +1589,7 @@ impl SortMergeJoinStream { continue; } - let mut left_columns = self + let mut right_columns = self .streamed_batch .batch .columns() @@ -1599,9 +1599,12 @@ impl SortMergeJoinStream { // The row indices of joined buffered batch let right_indices: UInt64Array = chunk.buffered_indices.finish(); - let mut right_columns = if matches!(self.join_type, JoinType::LeftMark) { + let mut left_columns = if matches!(self.join_type, JoinType::LeftMark) { vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef] - } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + } else if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi + ) { vec![] } else if let Some(buffered_idx) = chunk.buffered_batch_idx { fetch_right_columns_by_idxs( @@ -1622,21 +1625,27 @@ impl SortMergeJoinStream { // Prepare the columns we apply join filter on later. // Only for joined rows between streamed and buffered. let filter_columns = if chunk.buffered_batch_idx.is_some() { - if !matches!(self.join_type, JoinType::Right) { - if matches!( - self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark - ) { - let right_cols = fetch_right_columns_by_idxs( - &self.buffered_data, - chunk.buffered_batch_idx.unwrap(), - &right_indices, - )?; + if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark + ) { + let buffered_cols = fetch_right_columns_by_idxs( + &self.buffered_data, + chunk.buffered_batch_idx.unwrap(), + &right_indices, + )?; - get_filter_column(&self.filter, &left_columns, &right_cols) - } else { - get_filter_column(&self.filter, &left_columns, &right_columns) - } + get_filter_column(&self.filter, &right_columns, &buffered_cols) + } else if matches!(self.join_type, JoinType::RightSemi) { + let buffered_cols = fetch_right_columns_by_idxs( + &self.buffered_data, + chunk.buffered_batch_idx.unwrap(), + &right_indices, + )?; + + get_filter_column(&self.filter, &buffered_cols, &right_columns) + } else if matches!(self.join_type, JoinType::Right) { + get_filter_column(&self.filter, &left_columns, &right_columns) } else { get_filter_column(&self.filter, &right_columns, &left_columns) } @@ -1647,14 +1656,15 @@ impl SortMergeJoinStream { }; let columns = if !matches!(self.join_type, JoinType::Right) { - left_columns.extend(right_columns); - left_columns - } else { right_columns.extend(left_columns); right_columns + } else { + left_columns.extend(right_columns); + left_columns }; let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + // Apply join filter if any if !filter_columns.is_empty() { if let Some(f) = &self.filter { @@ -1689,6 +1699,7 @@ impl SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::LeftMark | JoinType::Full @@ -1772,6 +1783,7 @@ impl SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::LeftMark | JoinType::Full @@ -1871,6 +1883,10 @@ impl SortMergeJoinStream { let output_column_indices = (0..left_columns_length).collect::>(); filtered_record_batch = filtered_record_batch.project(&output_column_indices)?; + } else if matches!(self.join_type, JoinType::RightSemi) { + let output_column_indices = (0..right_columns_length).collect::>(); + filtered_record_batch = + filtered_record_batch.project(&output_column_indices)?; } else if matches!(self.join_type, JoinType::Full) && corrected_mask.false_count() > 0 { @@ -2305,19 +2321,21 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow_array::builder::{BooleanBuilder, UInt64Builder}; use arrow_array::{BooleanArray, UInt64Array}; + use datafusion_expr::Operator; - use datafusion_common::JoinType::*; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; + use datafusion_common::{JoinSide, JoinType::*}; use datafusion_execution::config::SessionConfig; use datafusion_execution::disk_manager::DiskManagerConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; + use datafusion_physical_expr::expressions::BinaryExpr; use crate::expressions::Column; use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches}; - use crate::joins::utils::JoinOn; + use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn}; use crate::joins::SortMergeJoinExec; use crate::memory::MemoryExec; use crate::test::build_table_i32; @@ -2440,6 +2458,26 @@ mod tests { ) } + fn join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: JoinType, + sort_options: Vec, + null_equals_null: bool, + ) -> Result { + SortMergeJoinExec::try_new( + left, + right, + on, + Some(filter), + join_type, + sort_options, + null_equals_null, + ) + } + async fn join_collect( left: Arc, right: Arc, @@ -2474,6 +2512,25 @@ mod tests { Ok((columns, batches)) } + async fn join_collect_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + let sort_options = vec![SortOptions::default(); on.len()]; + + let task_ctx = Arc::new(TaskContext::default()); + let join = + join_with_filter(left, right, on, filter, join_type, sort_options, false)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + async fn join_collect_batch_size_equals_two( left: Arc, right: Arc, @@ -2862,7 +2919,7 @@ mod tests { } #[tokio::test] - async fn join_semi() -> Result<()> { + async fn join_left_semi() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 2, 3]), ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right @@ -2893,6 +2950,255 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_right_semi_one() -> Result<()> { + let left = build_table( + ("a1", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a2", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a2 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_two_with_filter() -> Result<()> { + let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30])); + let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c2", 1)), + Operator::Lt, + Arc::new(Column::new("c1", 0)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ]), + ); + let (_, batches) = + join_collect_with_filter(left, right, on, filter, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 10 | 20 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_with_nulls() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]), + ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field + ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 3 | 6 | |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_with_nulls_with_options() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(1), Some(0), Some(2)]), + ("b1", &vec![None, Some(5), Some(4), None, Some(5)]), + ("c2", &vec![Some(90), Some(80), Some(70), Some(60), None]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]), + ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field + ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect_with_options( + left, + right, + on, + RightSemi, + vec![ + SortOptions { + descending: true, + nulls_first: false, + }; + 2 + ], + true, + ) + .await?; + + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 3 | | 9 |", + "| 2 | 5 | |", + "| 2 | 5 | 8 |", + "| 1 | 4 | 7 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_output_two_batches() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = + join_collect_batch_size_equals_two(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 1); + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_left_mark() -> Result<()> { let left = build_table( @@ -3257,7 +3563,9 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = vec![ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() @@ -3335,7 +3643,9 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = vec![ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() @@ -3391,7 +3701,9 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = [ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() @@ -3492,7 +3804,9 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = [ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() @@ -3882,175 +4196,177 @@ mod tests { } #[tokio::test] - async fn test_left_semi_join_filtered_mask() -> Result<()> { - let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); - - let output = concat_batches(&schema, &joined_batches.batches)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); + async fn test_semi_join_filtered_mask() -> Result<()> { + for join_type in [LeftSemi, RightSemi] { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![true]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0]), - &[0usize; 2], - &BooleanArray::from(vec![true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true),]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, None, Some(true),]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, Some(true), None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, Some(true), None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, false]), - output.num_rows() + let corrected_mask = get_corrected_filter_mask( + join_type, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); - - let corrected_mask = get_corrected_filter_mask( - LeftSemi, - &out_indices, - &joined_batches.batch_ids, - &out_mask, - output.num_rows(), - ) - .unwrap(); - - assert_eq!( - corrected_mask, - BooleanArray::from(vec![ - Some(true), - None, - Some(true), - None, - Some(true), - None, - None, - None - ]) - ); - - let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + None, + None, + None + ]) + ); - assert_batches_eq!( - &[ - "+---+----+---+----+", - "| a | b | x | y |", - "+---+----+---+----+", - "| 1 | 10 | 1 | 11 |", - "| 1 | 11 | 1 | 12 |", - "| 1 | 12 | 1 | 13 |", - "+---+----+---+----+", - ], - &[filtered_rb] - ); + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); - // output null rows - let null_mask = arrow::compute::not(&corrected_mask)?; - assert_eq!( - null_mask, - BooleanArray::from(vec![ - Some(false), - None, - Some(false), - None, - Some(false), - None, - None, - None - ]) - ); + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + None, + None, + None + ]) + ); - let null_joined_batch = filter_record_batch(&output, &null_mask)?; + let null_joined_batch = filter_record_batch(&output, &null_mask)?; - assert_batches_eq!( - &[ - "+---+---+---+---+", - "| a | b | x | y |", - "+---+---+---+---+", - "+---+---+---+---+", - ], - &[null_joined_batch] - ); + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); + } Ok(()) } diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 9a20e7987ff63..cb2ba1e27c4ef 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -508,6 +508,124 @@ select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1. ---- +# RIGHTSEMI join tests + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b = t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b = t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 14 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 14 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 12 b union all + select 11 a, 14 b + ), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + + # Test LEFT ANTI with cross batch data distribution statement ok set datafusion.execution.batch_size = 1; @@ -647,6 +765,29 @@ NULL NULL 7 9 NULL NULL 8 10 NULL NULL 9 11 + +# Test RIGHTSEMI with cross batch data distribution + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b union all + select 12 a, 14 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 14 b union all + select 12 a, 15 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 14 +12 15 + + # return sql params back to default values statement ok set datafusion.optimizer.prefer_hash_join = true;