diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 2baeb29df635..5e77becd1c5e 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -703,13 +703,21 @@ struct SMJStream { pub reservation: MemoryReservation, /// Runtime env pub runtime_env: Arc, + /// A unique number for each batch pub streamed_batch_counter: AtomicUsize, } +/// Joined batches with attached join filter information struct JoinedRecordBatches { + /// Joined batches. Each batch already contains the joined columns from the left and right sources pub batches: Vec, + /// Filter match mask for each row (matched/non-matched) pub filter_mask: BooleanBuilder, + /// Row indices to glue together rows in `batches` and `filter_mask` pub row_indices: UInt64Builder, + /// Which unique batch id the row belongs to + /// It is necessary to differentiate rows that point to the same + /// row index but belong to different batches pub batch_ids: Vec, } @@ -790,7 +798,7 @@ fn get_corrected_filter_mask( Some(corrected_mask.finish()) } - + // Only outer joins need to keep track of processed rows and apply the corrected filter mask _ => None, } } @@ -888,6 +896,11 @@ impl Stream for SMJStream { self.freeze_all()?; if !self.output_record_batches.batches.is_empty() { let record_batch = self.output_record_batch_and_reset()?; + // For a non-filtered join, output whenever the target output batch size + // is hit. 
For a filtered join, the output must be deferred to a later phase + // because the target output batch size can be hit in the middle of + // filtering, leaving the filtering incomplete and causing + // correctness issues let record_batch = if !(self.filter.is_some() && matches!( self.join_type, @@ -895,7 +908,7 @@ impl Stream for SMJStream { )) { record_batch } else { - RecordBatch::new_empty(Arc::clone(&self.schema)) + continue; }; return Poll::Ready(Some(Ok(record_batch))); @@ -1010,6 +1023,8 @@ impl SMJStream { self.join_metrics.input_rows.add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); + // Every incoming streamed batch should have its own unique id + // See the `JoinedRecordBatches.batch_ids` documentation self.streamed_batch_counter .fetch_add(1, std::sync::atomic::Ordering::SeqCst); self.streamed_state = StreamedState::Ready; @@ -1459,15 +1474,14 @@ impl SMJStream { }; let columns = if matches!(self.join_type, JoinType::Right) { - buffered_columns.extend(streamed_columns.clone()); - buffered_columns.clone() + buffered_columns.extend(streamed_columns); + buffered_columns } else { - streamed_columns.extend(buffered_columns.clone()); - streamed_columns.clone() + streamed_columns.extend(buffered_columns); + streamed_columns }; - let output_batch = - RecordBatch::try_new(Arc::clone(&self.schema), columns.clone())?; + let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; // Apply join filter if any if !filter_columns.is_empty() { @@ -1475,7 +1489,7 @@ impl SMJStream { // Construct batch with only filter columns let filter_batch = RecordBatch::try_new( Arc::new(f.schema().clone()), - filter_columns.clone(), + filter_columns, )?; let filter_result = f @@ -1574,10 +1588,8 @@ impl SMJStream { }; // Push the streamed/buffered batch joined nulls to the output - let null_joined_streamed_batch = RecordBatch::try_new( - Arc::clone(&self.schema), - columns.clone(), - )?; + let null_joined_streamed_batch = 
RecordBatch::try_new(Arc::clone(&self.schema), columns)?; self.output_record_batches .batches .push(null_joined_streamed_batch);