From 22e46922a44725fb393f34b806bd67efcbba6948 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 12 Jun 2024 10:01:03 -0700 Subject: [PATCH] fix: Fix the incorrect null joined rows for outer join with join filter --- .../src/joins/sort_merge_join.rs | 273 +++++++++++------- .../test_files/sort_merge_join.slt | 29 +- 2 files changed, 189 insertions(+), 113 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 8da345cdfca6..d080ea5681b3 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -34,6 +34,7 @@ use arrow::array::*; use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; +use arrow_array::types::UInt64Type; use futures::{Stream, StreamExt}; use hashbrown::HashSet; @@ -476,6 +477,7 @@ struct StreamedJoinedChunk { /// Array builder for streamed indices streamed_indices: UInt64Builder, /// Array builder for buffered indices + /// This could contain nulls if the join is null-joined buffered_indices: UInt64Builder, } @@ -564,6 +566,9 @@ struct BufferedBatch { pub null_joined: Vec, /// Size estimation used for reserving / releasing memory pub size_estimation: usize, + /// The indices of buffered batch that failed the join filter. + /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. + pub join_filter_failed_idxs: HashSet, } impl BufferedBatch { @@ -595,6 +600,7 @@ impl BufferedBatch { join_arrays, null_joined: vec![], size_estimation, + join_filter_failed_idxs: HashSet::new(), } } } @@ -852,6 +858,7 @@ impl SMJStream { // pop previous buffered batches while !self.buffered_data.batches.is_empty() { let head_batch = self.buffered_data.head_batch(); + // If the head batch is fully processed, dequeue it and produce output of it. if head_batch.range.end == head_batch.batch.num_rows() { self.freeze_dequeuing_buffered()?; if let Some(buffered_batch) = @@ -860,6 +867,8 @@ impl SMJStream { self.reservation.shrink(buffered_batch.size_estimation); } } else { + // If the head batch is not fully processed, break the loop. + // Streamed batch will be joined with the head batch in the next step. break; } } @@ -1055,7 +1064,7 @@ impl SMJStream { Some(scanning_idx), ); } else { - // Join nulls and buffered row + // Join nulls and buffered row for full join self.buffered_data .scanning_batch_mut() .null_joined @@ -1088,7 +1097,7 @@ impl SMJStream { fn freeze_all(&mut self) -> Result<()> { self.freeze_streamed()?; - self.freeze_buffered(self.buffered_data.batches.len())?; + self.freeze_buffered(self.buffered_data.batches.len(), false)?; Ok(()) } @@ -1098,7 +1107,8 @@ impl SMJStream { // 2. freezes NULLs joined to dequeued buffered batch to "release" it fn freeze_dequeuing_buffered(&mut self) -> Result<()> { self.freeze_streamed()?; - self.freeze_buffered(1)?; + // Only freeze and produce the first batch in buffered_data as the batch is fully processed + self.freeze_buffered(1, true)?; Ok(()) } @@ -1106,7 +1116,11 @@ impl SMJStream { // NULLs on streamed side. // // Applicable only in case of Full join. - fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> { + fn freeze_buffered( + &mut self, + batch_count: usize, + output_join_filter_fail_batch: bool, + ) -> Result<()> { if !matches!(self.join_type, JoinType::Full) { return Ok(()); } @@ -1114,33 +1128,31 @@ impl SMJStream { let buffered_indices = UInt64Array::from_iter_values( buffered_batch.null_joined.iter().map(|&index| index as u64), ); - if buffered_indices.is_empty() { - continue; + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? { + self.output_record_batches.push(record_batch); } buffered_batch.null_joined.clear(); - // Take buffered (right) columns - let buffered_columns = buffered_batch - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>() - .map_err(Into::::into)?; - - // Create null streamed (left) columns - let mut streamed_columns = self - .streamed_schema - .fields() - .iter() - .map(|f| new_null_array(f.data_type(), buffered_indices.len())) - .collect::>(); - - streamed_columns.extend(buffered_columns); - let columns = streamed_columns; - - self.output_record_batches - .push(RecordBatch::try_new(self.schema.clone(), columns)?); + // For buffered rows which are joined with streamed side but failed on join filter + if output_join_filter_fail_batch { + let buffered_indices = UInt64Array::from_iter_values( + buffered_batch.join_filter_failed_idxs.iter().copied(), + ); + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? { + self.output_record_batches.push(record_batch); + } + buffered_batch.join_filter_failed_idxs.clear(); + } } Ok(()) } @@ -1149,6 +1161,7 @@ impl SMJStream { // for current streamed batch and clears staged output indices. fn freeze_streamed(&mut self) -> Result<()> { for chunk in self.streamed_batch.output_indices.iter_mut() { + // The row indices of joined streamed batch let streamed_indices = chunk.streamed_indices.finish(); if streamed_indices.is_empty() { @@ -1163,6 +1176,7 @@ impl SMJStream { .map(|column| take(column, &streamed_indices, None)) .collect::, ArrowError>>()?; + // The row indices of joined buffered batch let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); let mut buffered_columns = if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { @@ -1174,6 +1188,8 @@ impl SMJStream { &buffered_indices, )? } else { + // If buffered batch none, meaning it is null joined batch. + // We need to create null arrays for buffered columns to join with streamed rows. self.buffered_schema .fields() .iter() @@ -1205,7 +1221,8 @@ impl SMJStream { get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } } else { - // This chunk is for null joined rows (outer join), we don't need to apply join filter. + // This chunk is totally for null joined rows (outer join), we don't need to apply join filter. + // Any join filter applied only on either streamed or buffered side will be pushed already. vec![] }; @@ -1234,49 +1251,73 @@ impl SMJStream { .evaluate(&filter_batch)? .into_array(filter_batch.num_rows())?; - // The selection mask of the filter - let mut mask = + // The boolean selection mask of the join filter result + let pre_mask = datafusion_common::cast::as_boolean_array(&filter_result)?; + // If there are nulls in join filter result, exclude them from selecting + // the rows to output. + let mask = if pre_mask.null_count() > 0 { + compute::prep_null_mask_filter( + datafusion_common::cast::as_boolean_array(&filter_result)?, + ) + } else { + pre_mask.clone() + }; + + // For certain join types, we need to adjust the initial mask to handle the join filter. let maybe_filtered_join_mask: Option<(BooleanArray, Vec)> = get_filtered_join_mask( self.join_type, - streamed_indices, - mask, + &streamed_indices, + &mask, &self.streamed_batch.join_filter_matched_idxs, &self.buffered_data.scanning_offset, ); - if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { - mask = &filtered_join_mask.0; - self.streamed_batch - .join_filter_matched_idxs - .extend(&filtered_join_mask.1); - } + let mask = + if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { + self.streamed_batch + .join_filter_matched_idxs + .extend(&filtered_join_mask.1); + &filtered_join_mask.0 + } else { + &mask + }; - // Push the filtered batch to the output + // Push the filtered batch which contains rows passing join filter to the output let filtered_batch = compute::filter_record_batch(&output_batch, mask)?; self.output_record_batches.push(filtered_batch); - // For outer joins, we need to push the null joined rows to the output. + // For outer joins, we need to push the null joined rows to the output if + // all joined rows are failed on the join filter. + // I.e., if all rows joined from a streamed row are failed with the join filter, + // we need to join it with nulls as buffered side. if matches!( self.join_type, JoinType::Left | JoinType::Right | JoinType::Full ) { - // The reverse of the selection mask. For the rows not pass join filter above, - // we need to join them (left or right) with null rows for outer joins. - let not_mask = if mask.null_count() > 0 { - // If the mask contains nulls, we need to use `prep_null_mask_filter` to - // handle the nulls in the mask as false to produce rows where the mask - // was null itself. - compute::not(&compute::prep_null_mask_filter(mask))? - } else { - compute::not(mask)? - }; + // We need to get the mask for row indices that the joined rows are failed + // on the join filter. I.e., for a row in streamed side, if all joined rows + // between it and all buffered rows are failed on the join filter, we need to + // output it with null columns from buffered side. For the mask here, it + // behaves like LeftAnti join. + let null_mask: BooleanArray = get_filtered_join_mask( + // Set a mask slot as true only if all joined rows of same streamed index + // are failed on the join filter. + // The masking behavior is like LeftAnti join. + JoinType::LeftAnti, + &streamed_indices, + mask, + &self.streamed_batch.join_filter_matched_idxs, + &self.buffered_data.scanning_offset, + ) + .unwrap() + .0; let null_joined_batch = - compute::filter_record_batch(&output_batch, ¬_mask)?; + compute::filter_record_batch(&output_batch, &null_mask)?; let mut buffered_columns = self .buffered_schema @@ -1313,51 +1354,37 @@ impl SMJStream { streamed_columns }; + // Push the streamed/buffered batch joined nulls to the output let null_joined_streamed_batch = RecordBatch::try_new(self.schema.clone(), columns.clone())?; self.output_record_batches.push(null_joined_streamed_batch); - // For full join, we also need to output the null joined rows from the buffered side + // For full join, we also need to output the null joined rows from the buffered side. + // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with + // streamed side, it won't be outputted by `freeze_buffered`. + // We need to check if a buffered row is joined with streamed side and output. + // If it is joined with streamed side, but finally fails on the join filter, + // we need to output it with nulls as streamed side. if matches!(self.join_type, JoinType::Full) { - // Handle not mask for buffered side further. - // For buffered side, we want to output the rows that are not null joined with - // the streamed side. i.e. the rows that are not null in the `buffered_indices`. - let not_mask = if let Some(nulls) = buffered_indices.nulls() { - let mask = not_mask.values() & nulls.inner(); - BooleanArray::new(mask, None) - } else { - not_mask - }; - - let null_joined_batch = - compute::filter_record_batch(&output_batch, ¬_mask)?; - - let mut streamed_columns = self - .streamed_schema - .fields() - .iter() - .map(|f| { - new_null_array( - f.data_type(), - null_joined_batch.num_rows(), - ) - }) - .collect::>(); - - let buffered_columns = null_joined_batch - .columns() - .iter() - .skip(streamed_columns_length) - .cloned() - .collect::>(); - - streamed_columns.extend(buffered_columns); - - let null_joined_buffered_batch = RecordBatch::try_new( - self.schema.clone(), - streamed_columns, - )?; - self.output_record_batches.push(null_joined_buffered_batch); + for i in 0..pre_mask.len() { + let buffered_batch = &mut self.buffered_data.batches + [chunk.buffered_batch_idx.unwrap()]; + let buffered_index = buffered_indices.value(i); + + if !pre_mask.value(i) { + // For a buffered row that is joined with streamed side but failed on the join filter, + buffered_batch + .join_filter_failed_idxs + .insert(buffered_index); + } else if buffered_batch + .join_filter_failed_idxs + .contains(&buffered_index) + { + buffered_batch + .join_filter_failed_idxs + .remove(&buffered_index); + } + } } } } else { @@ -1422,6 +1449,38 @@ fn get_filter_column( filter_columns } +fn produce_buffered_null_batch( + schema: &SchemaRef, + streamed_schema: &SchemaRef, + buffered_indices: &PrimitiveArray, + buffered_batch: &BufferedBatch, +) -> Result> { + if buffered_indices.is_empty() { + return Ok(None); + } + + // Take buffered (right) columns + let buffered_columns = buffered_batch + .batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::, ArrowError>>() + .map_err(Into::::into)?; + + // Create null streamed (left) columns + let mut streamed_columns = streamed_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), buffered_indices.len())) + .collect::>(); + + streamed_columns.extend(buffered_columns); + let columns = streamed_columns; + + Ok(Some(RecordBatch::try_new(schema.clone(), columns)?)) +} + /// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` #[inline(always)] fn get_buffered_columns( @@ -1445,9 +1504,13 @@ fn get_buffered_columns( /// `streamed_indices` have the same length as `mask` /// `matched_indices` array of streaming indices that already has a join filter match /// `scanning_buffered_offset` current buffered offset across batches +/// +/// This return a tuple of: +/// - corrected mask with respect to the join type +/// - indices of rows in streamed batch that have a join filter match fn get_filtered_join_mask( join_type: JoinType, - streamed_indices: UInt64Array, + streamed_indices: &UInt64Array, mask: &BooleanArray, matched_indices: &HashSet, scanning_buffered_offset: &usize, @@ -2808,7 +2871,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 0, 1, 1]), + &UInt64Array::from(vec![0, 0, 1, 1]), &BooleanArray::from(vec![true, true, false, false]), &HashSet::new(), &0, @@ -2819,7 +2882,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, true]), &HashSet::new(), &0, @@ -2830,7 +2893,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![false, true]), &HashSet::new(), &0, @@ -2841,7 +2904,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, false]), &HashSet::new(), &0, @@ -2852,7 +2915,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, true, true, true, true, true]), &HashSet::new(), &0, @@ -2866,7 +2929,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, false, false, false, false, true]), &HashSet::new(), &0, @@ -2885,7 +2948,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 0, 1, 1]), + &UInt64Array::from(vec![0, 0, 1, 1]), &BooleanArray::from(vec![true, true, false, false]), &HashSet::new(), &0, @@ -2896,7 +2959,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, true]), &HashSet::new(), &0, @@ -2907,7 +2970,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![false, true]), &HashSet::new(), &0, @@ -2918,7 +2981,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, false]), &HashSet::new(), &0, @@ -2929,7 +2992,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, true, true, true, true, true]), &HashSet::new(), &0, @@ -2943,7 +3006,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, false, false, false, false, true]), &HashSet::new(), &0, diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index b4deb43a728e..5a6334602c22 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -84,7 +84,6 @@ SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 -Alice 50 NULL NULL Bob 1 NULL NULL query TITI rowsort @@ -112,7 +111,6 @@ SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 -NULL NULL Alice 2 query TITI rowsort SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b @@ -137,12 +135,9 @@ query TITI rowsort SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b ---- Alice 100 NULL NULL -Alice 100 NULL NULL Alice 50 Alice 2 -Alice 50 NULL NULL Bob 1 NULL NULL NULL NULL Alice 1 -NULL NULL Alice 1 NULL NULL Alice 2 query TITI rowsort @@ -151,10 +146,7 @@ SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 NULL NULL -Alice 50 NULL NULL Bob 1 NULL NULL -NULL NULL Alice 1 -NULL NULL Alice 2 statement ok DROP TABLE t1; @@ -613,6 +605,27 @@ select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1. ) order by 1, 2 ---- +query IIII +select * from ( +with t as ( + select id, id % 5 id1 from (select unnest(range(0,10)) id) +), t1 as ( + select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) +) +select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 +) order by 1, 2, 3, 4 +---- +5 0 0 2 +6 1 1 3 +7 2 2 4 +8 3 3 5 +9 4 4 6 +NULL NULL 5 7 +NULL NULL 6 8 +NULL NULL 7 9 +NULL NULL 8 10 +NULL NULL 9 11 + # return sql params back to default values statement ok set datafusion.optimizer.prefer_hash_join = true;