From 8a4ac43d79084947155bcecb9b258e970bdc5ce6 Mon Sep 17 00:00:00 2001
From: comphead
Date: Tue, 20 Aug 2024 08:59:04 -0700
Subject: [PATCH 1/2] WIP: experiment with SMJ last buffered batch

---
 datafusion/core/tests/fuzz_cases/join_fuzz.rs |  81 +++++-----
 .../src/joins/sort_merge_join.rs              | 149 +++++++++++++++---
 2 files changed, 168 insertions(+), 62 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index f1cca66712d7..451ea18e26c9 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -15,19 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::sync::Arc;
-
 use arrow::array::{ArrayRef, Int32Array};
 use arrow::compute::SortOptions;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
 use arrow_schema::Schema;
-
 use datafusion_common::ScalarValue;
 use datafusion_physical_expr::expressions::Literal;
 use datafusion_physical_expr::PhysicalExprRef;
-
-use rand::Rng;
+use std::sync::Arc;
+use std::time::SystemTime;
 
 use datafusion::common::JoinSide;
 use datafusion::logical_expr::{JoinType, Operator};
@@ -39,6 +36,7 @@ use datafusion::physical_plan::joins::{
     HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec,
 };
 use datafusion::physical_plan::memory::MemoryExec;
+use rand::Rng;
 
 use datafusion::prelude::{SessionConfig, SessionContext};
 use test_utils::stagger_batch_with_seed;
@@ -450,11 +448,15 @@ impl JoinFuzzTestCase {
         let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
         let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
 
-        if debug {
+        if debug
+            && ((join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows)
+                || (join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows))
+        {
             let fuzz_debug = "fuzz_test_debug";
             std::fs::remove_dir_all(fuzz_debug).unwrap_or(());
             std::fs::create_dir_all(fuzz_debug).unwrap();
             let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}");
+            println!("Test result data mismatch found. HJ rows {}, SMJ rows {}, NLJ rows {}", hj_rows, smj_rows, nlj_rows);
             println!("The debug is ON. Input data will be saved to {out_dir_name}");
 
             Self::save_partitioned_batches_as_parquet(
                 &self.input1,
                 out_dir_name,
                 "input1",
             );
             Self::save_partitioned_batches_as_parquet(
                 &self.input2,
                 out_dir_name,
                 "input2",
             );
 
-            if join_tests.contains(&JoinTestType::NljHj) {
+            if join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows {
                 Self::save_partitioned_batches_as_parquet(
                     &nlj_collected,
                     out_dir_name,
                     "nlj",
                 );
                 Self::save_partitioned_batches_as_parquet(
                     &hj_collected,
                     out_dir_name,
                     "hj",
                 );
             }
 
-            if join_tests.contains(&JoinTestType::HjSmj) {
+            if join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows {
                 Self::save_partitioned_batches_as_parquet(
                     &hj_collected,
                     out_dir_name,
                     "hj",
                 );
                 Self::save_partitioned_batches_as_parquet(
                     &smj_collected,
                     out_dir_name,
                     "smj",
                 );
             }
@@ -537,6 +539,11 @@ impl JoinFuzzTestCase {
 
         if join_tests.contains(&JoinTestType::HjSmj) {
             let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size);
+            // println!("=============== HashJoinExec ==================");
+            // hj_formatted_sorted.iter().for_each(|s| println!("{}", s));
+            // println!("=============== SortMergeJoinExec ==================");
+            // smj_formatted_sorted.iter().for_each(|s| println!("{}", s));
+
             assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str());
 
             let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size);
@@ -578,34 +585,6 @@ impl JoinFuzzTestCase {
     ///     )
     ///     .run_test(&[JoinTestType::HjSmj], false)
     ///     .await;
-    ///
-    ///     let ctx: SessionContext = SessionContext::new();
-    ///     let df = ctx
-    ///         .read_parquet(
-    ///             "/tmp/input1/*.parquet",
-    ///             datafusion::prelude::ParquetReadOptions::default(),
-    ///         )
-    ///         .await
-    ///         .unwrap();
-    ///     let left = df.collect().await.unwrap();
-    ///
-    ///     let df = ctx
-    ///         .read_parquet(
-    ///             "/tmp/input2/*.parquet",
-    ///             datafusion::prelude::ParquetReadOptions::default(),
-    ///         )
-    ///         .await
-    ///         .unwrap();
-    ///
-    ///     let right = df.collect().await.unwrap();
-    ///     JoinFuzzTestCase::new(
-    ///         left,
-    ///         right,
-    ///         JoinType::LeftSemi,
-    ///         Some(Box::new(less_than_100_join_filter)),
-    ///     )
-    ///     .run_test()
-    ///     .await
     /// }
     fn save_partitioned_batches_as_parquet(
         input: &[RecordBatch],
@@ -617,9 +596,15 @@ impl JoinFuzzTestCase {
         std::fs::create_dir_all(out_path).unwrap();
 
         input.iter().enumerate().for_each(|(idx, batch)| {
-            let mut file =
-                std::fs::File::create(format!("{out_path}/file_{}.parquet", idx))
-                    .unwrap();
+            let file_path = format!("{out_path}/file_{}.parquet", idx);
+            let mut file = std::fs::File::create(&file_path).unwrap();
+            println!(
+                "{}: Saving batch idx {} rows {} to parquet {}",
+                &out_name,
+                idx,
+                batch.num_rows(),
+                &file_path
+            );
             let mut writer = parquet::arrow::ArrowWriter::try_new(
                 &mut file,
                 input.first().unwrap().schema(),
                 None,
             )
             .unwrap();
 
             writer.write(batch).unwrap();
             writer.close().unwrap();
         });
-
-        println!("The data {out_name} saved as parquet into {out_path}");
     }
 
     /// Read parquet files preserving partitions, i.e. 1 file -> 1 partition
@@ -643,10 +626,20 @@ impl JoinFuzzTestCase {
     ) -> std::io::Result<Vec<RecordBatch>> {
         let ctx: SessionContext = SessionContext::new();
         let mut batches: Vec<RecordBatch> = vec![];
+        let mut entries = std::fs::read_dir(dir)?
+            .map(|res| res.map(|e| e.path()))
+            .collect::<Result<Vec<_>, std::io::Error>>()?;
+
+        // important to read the files in the same order as they were written;
+        // sort by modification time
+        entries.sort_by_key(|path| {
+            std::fs::metadata(path)
+                .and_then(|metadata| metadata.modified())
+                .unwrap_or(SystemTime::UNIX_EPOCH)
+        });
 
-        for entry in std::fs::read_dir(dir)? {
-            let entry = entry?;
-            let path = entry.path();
+        for entry in entries {
+            let path = entry.as_path();
 
             if path.is_file() {
                 let mut batch = ctx
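
A side saved under fuzz_test_debug can be reloaded for an ad-hoc repro with the same
read_parquet pattern as the doc example above. A minimal sketch; the helper name and
directory are purely illustrative, and unlike load_partitioned_batches_from_parquet()
above, the glob read does not preserve the one-file-per-partition layout:

    // Reload every parquet file of one saved side into a flat Vec<RecordBatch>.
    async fn reload_saved_side(dir: &str) -> Vec<RecordBatch> {
        let ctx = SessionContext::new();
        ctx.read_parquet(
            format!("{dir}/*.parquet"),
            datafusion::prelude::ParquetReadOptions::default(),
        )
        .await
        .unwrap()
        .collect()
        .await
        .unwrap()
    }

    // e.g. reload_saved_side("fuzz_test_debug/batch_size_2/input1").await
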
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index 96d5ba728a30..7fe4c94aaf2a 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -38,9 +38,6 @@ use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
 use arrow::error::ArrowError;
 use arrow::ipc::reader::FileReader;
 use arrow_array::types::UInt64Type;
-use futures::{Stream, StreamExt};
-use hashbrown::HashSet;
-
 use datafusion_common::{
     exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType,
     Result,
@@ -51,6 +48,9 @@ use datafusion_execution::runtime_env::RuntimeEnv;
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::equivalence::join_equivalence_properties;
 use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
+use futures::{Stream, StreamExt};
+use hashbrown::HashSet;
+use itertools::Itertools;
 
 use crate::expressions::PhysicalSortExpr;
 use crate::joins::utils::{
@@ -730,6 +730,8 @@ impl Stream for SMJStream {
             match self.current_ordering {
                 Ordering::Less | Ordering::Equal => {
                     if !streamed_exhausted {
+                        //dbg!(streamed_exhausted);
+                        //dbg!(buffered_exhausted);
                         self.streamed_joined = false;
                         self.streamed_state = StreamedState::Init;
                     }
@@ -777,7 +779,9 @@ impl Stream for SMJStream {
                     self.join_partial()?;
 
                     if self.output_size < self.batch_size {
+                        //println!("scanning_finished:{} output_size:{} batch_size:{}", self.buffered_data.scanning_finished(), self.output_size, self.batch_size);
                         if self.buffered_data.scanning_finished() {
+                            //dbg!("reset");
                             self.buffered_data.scanning_reset();
                             self.state = SMJState::Init;
                         }
@@ -854,14 +858,17 @@ impl SMJStream {
     /// Poll next streamed row
     fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
         loop {
+            //dbg!(&self.streamed_state);
             match &self.streamed_state {
                 StreamedState::Init => {
                     if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows()
                     {
                         self.streamed_batch.idx += 1;
+                        //dbg!("inc streamed_batch.idx");
                         self.streamed_state = StreamedState::Ready;
                         return Poll::Ready(Some(Ok(())));
                     } else {
+                        //dbg!("streamed_batch polling");
                         self.streamed_state = StreamedState::Polling;
                     }
                 }
@@ -987,6 +994,9 @@ impl SMJStream {
                 Poll::Ready(Some(batch)) => {
                     self.join_metrics.input_batches.add(1);
                     self.join_metrics.input_rows.add(batch.num_rows());
+                    //self.buffered_data.current_scanning_batch = 0;
+                    //dbg!("Polling First");
+                    //dbg!(&batch);
 
                     if batch.num_rows() > 0 {
                         let buffered_batch =
@@ -1028,6 +1038,9 @@ impl SMJStream {
                         // Polling batches coming concurrently as multiple partitions
                         self.join_metrics.input_batches.add(1);
                         self.join_metrics.input_rows.add(batch.num_rows());
+                        //dbg!("Polling Rest");
+                        //dbg!(&batch);
+
                         if batch.num_rows() > 0 {
                             let buffered_batch = BufferedBatch::new(
                                 batch,
@@ -1059,19 +1072,20 @@ impl SMJStream {
             return Ok(Ordering::Less);
         }
 
-        return compare_join_arrays(
+        compare_join_arrays(
             &self.streamed_batch.join_arrays,
             self.streamed_batch.idx,
             &self.buffered_data.head_batch().join_arrays,
             self.buffered_data.head_batch().range.start,
             &self.sort_options,
             self.null_equals_null,
-        );
+        )
     }
 
     /// Produce join and fill output buffer until reaching target batch size
     /// or the join is finished
     fn join_partial(&mut self) -> Result<()> {
+        //dbg!("join_partial()");
         // Whether to join streamed rows
         let mut join_streamed = false;
         // Whether to join buffered rows
@@ -1135,6 +1149,8 @@ impl SMJStream {
         }
         if !join_streamed && !join_buffered {
             // no joined data
+            //dbg!("no joined data");
+            //self.buffered_data.current_scanning_batch = 0;
             self.buffered_data.scanning_finish();
             return Ok(());
         }
@@ -1145,8 +1161,14 @@ impl SMJStream {
             && self.output_size < self.batch_size
         {
             let scanning_idx = self.buffered_data.scanning_idx();
+            //dbg!(&scanning_idx);
+            //dbg!(self.buffered_data.scanning_batch_idx);
+            //dbg!(self.buffered_data.batches.len());
+            //dbg!(self.buffered_data.scanning_batch());
+            //dbg!(join_streamed);
             if join_streamed {
                 // Join streamed row and buffered row
+                //dbg!("append_output_pair()");
                 self.streamed_batch.append_output_pair(
                     Some(self.buffered_data.scanning_batch_idx),
                     Some(scanning_idx),
@@ -1251,7 +1273,21 @@ impl SMJStream {
     // Produces and stages record batch for all output indices found
     // for current streamed batch and clears staged output indices.
     fn freeze_streamed(&mut self) -> Result<()> {
+        //println!("\nfreeze_streamed()");
+
+        // Sometimes buffered batches are assigned to the stream row with a different join key
+        let matched_batches = self
+            .buffered_data
+            .batches
+            .iter()
+            .enumerate()
+            .filter(|(i, b)| {
+                (b.range.start > 0 && b.range.end > 0) || (b.range.start == 0 && i == &0)
+            })
+            .count();
+
         for chunk in self.streamed_batch.output_indices.iter_mut() {
+            //dbg!("chunk");
             // The row indices of joined streamed batch
             let streamed_indices = chunk.streamed_indices.finish();
 
@@ -1318,7 +1354,7 @@ impl SMJStream {
             };
 
             let columns = if matches!(self.join_type, JoinType::Right) {
-                buffered_columns.extend(streamed_columns.clone());
+                buffered_columns.extend(streamed_columns);
                 buffered_columns
             } else {
                 streamed_columns.extend(buffered_columns);
@@ -1334,7 +1370,7 @@ impl SMJStream {
                 // Construct batch with only filter columns
                 let filter_batch = RecordBatch::try_new(
                     Arc::new(f.schema().clone()),
-                    filter_columns.clone(),
+                    filter_columns.clone(),
                 )?;
 
                 let filter_result = f
@@ -1356,16 +1392,82 @@ impl SMJStream {
                     pre_mask.clone()
                 };
 
+                // Try to calculate if the buffered batch we scan is the last one for a specific stream row and join key.
+                // For batch_size == 1, self.buffered_data.scanning_finished() works well.
+                // For other scenarios it is an attempt to figure out that there are no more rows matching the same join key.
+                let last_batch = if self.batch_size == 1 {
+                    self.buffered_data.scanning_finished()
+                } else if matched_batches > 1 {
+                    self.buffered_data.current_scanning_batch >= matched_batches - 1
+                } else {
+                    self.buffered_data.scanning_offset == 0
+                };
                 // For certain join types, we need to adjust the initial mask to handle the join filter.
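+                // For LeftSemi/LeftAnti the corrected mask for a streamed row can only
+                // be finalized once the last buffered batch matching its join key has
+                // been scanned, which is what `last_batch` above approximates.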
                 let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> =
                     get_filtered_join_mask(
                         self.join_type,
                         &streamed_indices,
+                        &buffered_indices,
                         &mask,
                         &self.streamed_batch.join_filter_matched_idxs,
-                        &self.buffered_data.scanning_offset,
+                        last_batch,
                     );
 
+                let a = columns
+                    .get(0)
+                    .unwrap()
+                    .as_any()
+                    .downcast_ref::<Int32Array>()
+                    .unwrap();
+                let b = columns
+                    .get(1)
+                    .unwrap()
+                    .as_any()
+                    .downcast_ref::<Int32Array>()
+                    .unwrap();
+
+                if a.iter().contains(&Some(95))
+                    && b.iter().contains(&Some(46))
+                    && false
+                {
+                    println!("++++++++++++++++++++++++ START BATCH SIZE {} ++++++++++++++++++++++++++++++++++", self.batch_size);
+                    dbg!(last_batch);
+                    dbg!(matched_batches);
+                    //dbg!(&columns);
+                    dbg!(&self.streamed_batch.idx);
+                    //dbg!(&self.streamed_batch.batch.num_rows());
+                    //dbg!(&self.streamed_batch.batch);
+                    //dbg!(&columns);
+                    dbg!(&filter_columns);
+                    dbg!(&mask);
+                    dbg!(&maybe_filtered_join_mask);
+                    //dbg!(&pre_mask);
+                    dbg!(&self.buffered_data.scanning_offset);
+                    dbg!(&self.buffered_data.scanning_batch_idx);
+                    dbg!(self.buffered_data.current_scanning_batch);
+
+                    //dbg!(&self.buffered_data.scanning_batch_finished());
+                    //dbg!(&self.buffered_data.scanning_finished());
+                    dbg!(&buffered_indices);
+                    dbg!(&streamed_indices);
+                    //dbg!(&buffered_indices.is_empty());
+                    //dbg!(&buffered_indices.is_nullable());
+                    //dbg!(&self.streamed_batch.join_filter_matched_idxs);
+                    dbg!(&self.streamed_batch.buffered_batch_idx);
+                    dbg!(&self.buffered_data.batches.len());
+                    dbg!(&self.buffered_data.batches);
+                    dbg!(&chunk.buffered_batch_idx);
+                    dbg!(&chunk.streamed_indices.len());
+
+                    dbg!(self.output_size);
+                    //dbg!(&self.buffered_data.scanning_idx());
+                    //self.scanning_batch_idx == self.batches.len()
+                    //dbg!(i);
+                    //dbg!(streamed_output_indices);
+
+                    println!("++++++++++++++++++++++++ END BATCH SIZE {} ++++++++++++++++++++++++++++++++++", self.batch_size);
+                }
+
                 let mask = if let Some(ref filtered_join_mask) = maybe_filtered_join_mask {
                     self.streamed_batch
@@ -1400,9 +1502,10 @@ impl SMJStream {
                         // The masking behavior is like LeftAnti join.
                         JoinType::LeftAnti,
                         &streamed_indices,
+                        &buffered_indices,
                         mask,
                         &self.streamed_batch.join_filter_matched_idxs,
-                        &self.buffered_data.scanning_offset,
+                        true,
                     )
                     .unwrap()
                     .0;
@@ -1486,14 +1589,21 @@ impl SMJStream {
             } else {
                 self.output_record_batches.push(output_batch);
             }
+            if self.buffered_data.batches.len() > 1
+                && self.buffered_data.current_scanning_batch < matched_batches - 1
+            {
+                self.buffered_data.current_scanning_batch += 1;
+            } else {
+                self.buffered_data.current_scanning_batch = 0
+            };
         }
-
         self.streamed_batch.output_indices.clear();
 
         Ok(())
     }
 
     fn output_record_batch_and_reset(&mut self) -> Result<RecordBatch> {
+        //dbg!("output_record_batch_and_reset");
         let record_batch = concat_batches(&self.schema, &self.output_record_batches)?;
         self.join_metrics.output_batches.add(1);
         self.join_metrics.output_rows.add(record_batch.num_rows());
@@ -1633,9 +1743,10 @@ fn get_buffered_columns_from_batch(
 fn get_filtered_join_mask(
     join_type: JoinType,
     streamed_indices: &UInt64Array,
+    buffered_indices: &UInt64Array,
     mask: &BooleanArray,
     matched_indices: &HashSet<u64>,
-    scanning_buffered_offset: &usize,
+    last_buffered_batch: bool,
 ) -> Option<(BooleanArray, Vec<u64>)> {
     let mut seen_as_true: bool = false;
     let streamed_indices_length = streamed_indices.len();
@@ -1675,17 +1786,14 @@ fn get_filtered_join_mask(
             }
             Some((corrected_mask.finish(), filter_matched_indices))
         }
-        // LeftAnti semantics: return true if for every x in the collection the join matching filter is false.
+        // LeftAnti semantics: return true if for every element in the collection the join matching filter is false.
         // `filter_matched_indices` needs to be set once per streaming index
        // to prevent duplicates in the output
        JoinType::LeftAnti => {
            // have we seen a filter match for a streaming index before
            for i in 0..streamed_indices_length {
                let streamed_idx = streamed_indices.value(i);
-                if mask.value(i)
-                    && !seen_as_true
-                    && !matched_indices.contains(&streamed_idx)
-                {
+                if mask.value(i) && !seen_as_true {
                    seen_as_true = true;
                    filter_matched_indices.push(streamed_idx);
                }
 
                // Reset `seen_as_true` flag and calculate mask for the current row:
                // - if it is at the end of the all buffered batches for the given streaming index, 0 index comes last
                if (i < streamed_indices_length - 1
                    && streamed_idx != streamed_indices.value(i + 1))
-                    || (i == streamed_indices_length - 1
-                        && *scanning_buffered_offset == 0)
+                    || (i == streamed_indices_length - 1 && last_buffered_batch)
+                    || buffered_indices.is_null(i)
                {
                    corrected_mask.append_value(
                        !matched_indices.contains(&streamed_idx) && !seen_as_true,
                    );
@@ -1722,6 +1830,8 @@ struct BufferedData {
     pub scanning_batch_idx: usize,
     /// current scanning offset used in join_partial()
     pub scanning_offset: usize,
+    /// index of the current scanning batch among the buffered batches matched by the same join key
+    pub current_scanning_batch: usize,
 }
 
 impl BufferedData {
@@ -1742,6 +1852,7 @@ impl BufferedData {
     }
 
     pub fn scanning_reset(&mut self) {
+        //dbg!("scanning_reset");
         self.scanning_batch_idx = 0;
         self.scanning_offset = 0;
     }
 
     pub fn scanning_advance(&mut self) {
         self.scanning_offset += 1;
         while !self.scanning_finished() && self.scanning_batch_finished() {
+            //dbg!("scanning_advance");
             self.scanning_batch_idx += 1;
             self.scanning_offset = 0;
         }
@@ -1775,6 +1887,7 @@ impl BufferedData {
     }
 
     pub fn scanning_finish(&mut self) {
+        //dbg!("scanning_finish");
         self.scanning_batch_idx = self.batches.len();
         self.scanning_offset = 0;
     }
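
To pin down the new get_filtered_join_mask() contract for LeftAnti, a test-style
sketch in the spirit of the existing mask unit tests in sort_merge_join.rs
(illustrative values; assumes the new signature from the hunk above):

    #[test]
    fn left_anti_mask_emitted_on_last_buffered_batch() {
        // Streamed rows 0 and 1, each compared against two buffered rows.
        let streamed_indices = UInt64Array::from(vec![0, 0, 1, 1]);
        let buffered_indices = UInt64Array::from(vec![10, 11, 12, 13]);
        // Row 0 matched the join filter once, row 1 never matched.
        let mask = BooleanArray::from(vec![true, false, false, false]);
        let matched: HashSet<u64> = HashSet::new();

        let (corrected, filter_matched) = get_filtered_join_mask(
            JoinType::LeftAnti,
            &streamed_indices,
            &buffered_indices,
            &mask,
            &matched,
            true, // this is the last buffered batch for these streamed rows
        )
        .unwrap();

        // One verdict per streamed row, emitted at its last position:
        // row 0 saw a filter match (dropped), row 1 survives the anti join.
        assert_eq!(corrected, BooleanArray::from(vec![false, false, false, true]));
        assert_eq!(filter_matched, vec![0]);
    }
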
From 0bad8bda74300f091d9feb485f5bf7e13fc6ef5d Mon Sep 17 00:00:00 2001
From: comphead
Date: Tue, 20 Aug 2024 09:02:20 -0700
Subject: [PATCH 2/2] WIP: experiment with SMJ last buffered batch

---
 datafusion/core/tests/fuzz_cases/join_fuzz.rs | 20 ++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index 451ea18e26c9..1e5a6adf5279 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -248,20 +248,22 @@ async fn test_anti_join_1k() {
 }
 
 #[tokio::test]
-#[ignore]
+//#[ignore]
 // flaky test giving 1 rows difference sometimes
 // https://github.com/apache/datafusion/issues/11555
 async fn test_anti_join_1k_filtered() {
     // NLJ vs HJ gives wrong result
     // Tracked in https://github.com/apache/datafusion/issues/11537
-    JoinFuzzTestCase::new(
-        make_staggered_batches(1000),
-        make_staggered_batches(1000),
-        JoinType::LeftAnti,
-        Some(Box::new(col_lt_col_filter)),
-    )
-    .run_test(&[JoinTestType::HjSmj], false)
-    .await
+    for _ in 0..1000 {
+        JoinFuzzTestCase::new(
+            make_staggered_batches(1000),
+            make_staggered_batches(1000),
+            JoinType::LeftAnti,
+            Some(Box::new(col_lt_col_filter)),
+        )
+        .run_test(&[JoinTestType::HjSmj], false)
+        .await
+    }
 }
 
 type JoinFilterBuilder = Box<dyn Fn(Arc<Schema>, Arc<Schema>) -> JoinFilter>;
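
For context, builders such as col_lt_col_filter passed to JoinFuzzTestCase::new are
JoinFilterBuilder instances. A minimal sketch of one, with the predicate, column name,
and column index purely illustrative:

    use std::sync::Arc;

    use arrow_schema::Schema;
    use datafusion::common::JoinSide;
    use datafusion::logical_expr::Operator;
    use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
    use datafusion_common::ScalarValue;
    use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};

    // A JoinFilterBuilder: receives both input schemas and returns a JoinFilter
    // whose intermediate batch has a single column taken from the left side,
    // filtered on `x < 100`.
    fn less_than_100_join_filter(schema1: Arc<Schema>, _schema2: Arc<Schema>) -> JoinFilter {
        let less_than_100 = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("x", 0)),
            Operator::Lt,
            Arc::new(Literal::new(ScalarValue::from(100))),
        )) as _;
        let column_indices = vec![ColumnIndex {
            index: 0,
            side: JoinSide::Left,
        }];
        let intermediate_schema = Schema::new(vec![schema1.field(0).to_owned()]);

        JoinFilter::new(less_than_100, column_indices, intermediate_schema)
    }

A test then plugs it in as Some(Box::new(less_than_100_join_filter)), the same way
test_anti_join_1k_filtered above does with col_lt_col_filter.
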