Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
comphead committed Oct 18, 2024
1 parent cad91c2 commit 3986741
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,13 +703,21 @@ struct SMJStream {
pub reservation: MemoryReservation,
/// Runtime env
pub runtime_env: Arc<RuntimeEnv>,
/// A unique number for each batch
pub streamed_batch_counter: AtomicUsize,
}

/// Joined batches together with the join filter bookkeeping attached to them.
///
/// `filter_mask`, `row_indices` and `batch_ids` carry per-row metadata for the
/// rows stored in `batches`.
struct JoinedRecordBatches {
/// Joined batches. Each batch already holds the joined columns from the left and right sources
pub batches: Vec<RecordBatch>,
/// Filter match mask with one entry per row (matched / non-matched)
pub filter_mask: BooleanBuilder,
/// Row indices used to glue together rows in `batches` and `filter_mask`
pub row_indices: UInt64Builder,
/// Id of the unique batch each row belongs to.
/// It is needed to tell apart rows that share the same row index
/// but come from different batches
pub batch_ids: Vec<usize>,
}

Expand Down Expand Up @@ -790,7 +798,7 @@ fn get_corrected_filter_mask(

Some(corrected_mask.finish())
}

// Only outer joins need to keep track of processed rows and apply the corrected filter mask
_ => None,
}
}
Expand Down Expand Up @@ -888,14 +896,19 @@ impl Stream for SMJStream {
self.freeze_all()?;
if !self.output_record_batches.batches.is_empty() {
let record_batch = self.output_record_batch_and_reset()?;
// For a non-filtered join, output whenever the target output batch size
// is hit. For a filtered join, the output has to happen in a later phase,
// because the target output batch size can be hit in the middle of
// filtering, leaving the filtering incomplete and causing
// correctness issues
let record_batch = if !(self.filter.is_some()
&& matches!(
self.join_type,
JoinType::Left | JoinType::LeftSemi
)) {
record_batch
} else {
RecordBatch::new_empty(Arc::clone(&self.schema))
continue;
};

return Poll::Ready(Some(Ok(record_batch)));
Expand Down Expand Up @@ -1010,6 +1023,8 @@ impl SMJStream {
self.join_metrics.input_rows.add(batch.num_rows());
self.streamed_batch =
StreamedBatch::new(batch, &self.on_streamed);
// Every incoming streamed batch gets its own unique id.
// See the `streamed_batch_counter` and `JoinedRecordBatches::batch_ids` documentation
self.streamed_batch_counter
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
self.streamed_state = StreamedState::Ready;
Expand Down Expand Up @@ -1459,23 +1474,22 @@ impl SMJStream {
};

let columns = if matches!(self.join_type, JoinType::Right) {
buffered_columns.extend(streamed_columns.clone());
buffered_columns.clone()
buffered_columns.extend(streamed_columns);
buffered_columns
} else {
streamed_columns.extend(buffered_columns.clone());
streamed_columns.clone()
streamed_columns.extend(buffered_columns);
streamed_columns
};

let output_batch =
RecordBatch::try_new(Arc::clone(&self.schema), columns.clone())?;
let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?;

// Apply join filter if any
if !filter_columns.is_empty() {
if let Some(f) = &self.filter {
// Construct batch with only filter columns
let filter_batch = RecordBatch::try_new(
Arc::new(f.schema().clone()),
filter_columns.clone(),
filter_columns,
)?;

let filter_result = f
Expand Down Expand Up @@ -1574,10 +1588,8 @@ impl SMJStream {
};

// Push the streamed/buffered batch joined with nulls to the output
let null_joined_streamed_batch = RecordBatch::try_new(
Arc::clone(&self.schema),
columns.clone(),
)?;
let null_joined_streamed_batch =
RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
self.output_record_batches
.batches
.push(null_joined_streamed_batch);
Expand Down

0 comments on commit 3986741

Please sign in to comment.