Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: Do not return empty record batches from streams #13794

Merged
merged 17 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,9 +654,12 @@ impl Stream for GroupedHashAggregateStream {
}

if let Some(to_emit) = self.group_ordering.emit_to() {
let batch = extract_ok!(self.emit(to_emit, false));
self.exec_state = ExecutionState::ProducingOutput(batch);
timer.done();
let Some(batch) = extract_ok!(self.emit(to_emit, false))
else {
break 'reading_input;
};
self.exec_state = ExecutionState::ProducingOutput(batch);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think I would find this easier to read if it avoided a redundant break:

Suggested change
let Some(batch) = extract_ok!(self.emit(to_emit, false))
else {
break 'reading_input;
};
self.exec_state = ExecutionState::ProducingOutput(batch);
if let Some(batch) = extract_ok!(self.emit(to_emit, false)) {
self.exec_state = ExecutionState::ProducingOutput(batch);
}

// make sure the exec_state just set is not overwritten below
break 'reading_input;
}
Expand Down Expand Up @@ -693,9 +696,12 @@ impl Stream for GroupedHashAggregateStream {
}

if let Some(to_emit) = self.group_ordering.emit_to() {
let batch = extract_ok!(self.emit(to_emit, false));
self.exec_state = ExecutionState::ProducingOutput(batch);
timer.done();
let Some(batch) = extract_ok!(self.emit(to_emit, false))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as above

else {
break 'reading_input;
};
self.exec_state = ExecutionState::ProducingOutput(batch);
// make sure the exec_state just set is not overwritten below
break 'reading_input;
}
Expand Down Expand Up @@ -768,6 +774,7 @@ impl Stream for GroupedHashAggregateStream {
let output = batch.slice(0, size);
(ExecutionState::ProducingOutput(remaining), output)
};
debug_assert!(output_batch.num_rows() > 0);
return Poll::Ready(Some(Ok(
output_batch.record_output(&self.baseline_metrics)
)));
Expand Down Expand Up @@ -902,14 +909,14 @@ impl GroupedHashAggregateStream {

/// Create an output RecordBatch with the group keys and
/// accumulator states/values specified in emit_to
fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Option<RecordBatch>> {
let schema = if spilling {
Arc::clone(&self.spill_state.spill_schema)
} else {
self.schema()
};
if self.group_values.is_empty() {
return Ok(RecordBatch::new_empty(schema));
return Ok(None);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

}

let mut output = self.group_values.emit(emit_to)?;
Expand Down Expand Up @@ -937,7 +944,8 @@ impl GroupedHashAggregateStream {
// over the target memory size after emission, we can emit again rather than returning Err.
let _ = self.update_memory_reservation();
let batch = RecordBatch::try_new(schema, output)?;
Ok(batch)
debug_assert!(batch.num_rows() > 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it would be good to document in comments somewhere the expectation / behavior that no empty record batches are produced

Ok(Some(batch))
}

/// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly
Expand All @@ -963,7 +971,9 @@ impl GroupedHashAggregateStream {

/// Emit all rows, sort them, and store them on disk.
fn spill(&mut self) -> Result<()> {
let emit = self.emit(EmitTo::All, true)?;
let Some(emit) = self.emit(EmitTo::All, true)? else {
return Ok(());
};
let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?;
let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
// TODO: slice large `sorted` and write to multiple files in parallel
Expand Down Expand Up @@ -1008,8 +1018,9 @@ impl GroupedHashAggregateStream {
{
assert_eq!(self.mode, AggregateMode::Partial);
let n = self.group_values.len() / self.batch_size * self.batch_size;
let batch = self.emit(EmitTo::First(n), false)?;
self.exec_state = ExecutionState::ProducingOutput(batch);
if let Some(batch) = self.emit(EmitTo::First(n), false)? {
self.exec_state = ExecutionState::ProducingOutput(batch);
};
}
Ok(())
}
Expand All @@ -1019,7 +1030,9 @@ impl GroupedHashAggregateStream {
/// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
/// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
fn update_merged_stream(&mut self) -> Result<()> {
let batch = self.emit(EmitTo::All, true)?;
let Some(batch) = self.emit(EmitTo::All, true)? else {
return Ok(());
};
// clear up memory for streaming_merge
self.clear_all();
self.update_memory_reservation()?;
Expand Down Expand Up @@ -1067,7 +1080,7 @@ impl GroupedHashAggregateStream {
let timer = elapsed_compute.timer();
self.exec_state = if self.spill_state.spills.is_empty() {
let batch = self.emit(EmitTo::All, false)?;
ExecutionState::ProducingOutput(batch)
batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
} else {
// If spill files exist, stream-merge them.
self.update_merged_stream()?;
Expand Down Expand Up @@ -1096,8 +1109,9 @@ impl GroupedHashAggregateStream {
fn switch_to_skip_aggregation(&mut self) -> Result<()> {
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
if probe.should_skip() {
let batch = self.emit(EmitTo::All, false)?;
self.exec_state = ExecutionState::ProducingOutput(batch);
if let Some(batch) = self.emit(EmitTo::All, false)? {
self.exec_state = ExecutionState::ProducingOutput(batch);
};
}
}

Expand Down
59 changes: 30 additions & 29 deletions datafusion/physical-plan/src/sorts/partial_sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,31 +363,31 @@ impl PartialSortStream {
if self.is_closed {
return Poll::Ready(None);
}
let result = match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
if let Some(slice_point) =
self.get_slice_point(self.common_prefix_length, &batch)?
{
self.in_mem_batches.push(batch.slice(0, slice_point));
let remaining_batch =
batch.slice(slice_point, batch.num_rows() - slice_point);
let sorted_batch = self.sort_in_mem_batches();
self.in_mem_batches.push(remaining_batch);
sorted_batch
} else {
self.in_mem_batches.push(batch);
Ok(RecordBatch::new_empty(self.schema()))
loop {
return Poll::Ready(Some(match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
if let Some(slice_point) =
self.get_slice_point(self.common_prefix_length, &batch)?
{
self.in_mem_batches.push(batch.slice(0, slice_point));
let remaining_batch =
batch.slice(slice_point, batch.num_rows() - slice_point);
let sorted_batch = self.sort_in_mem_batches();
self.in_mem_batches.push(remaining_batch);
sorted_batch
} else {
self.in_mem_batches.push(batch);
continue;
}
}
}
Some(Err(e)) => Err(e),
None => {
self.is_closed = true;
// once input is consumed, sort the rest of the inserted batches
self.sort_in_mem_batches()
}
};

Poll::Ready(Some(result))
Some(Err(e)) => Err(e),
None => {
self.is_closed = true;
// once input is consumed, sort the rest of the inserted batches
self.sort_in_mem_batches()
}
}));
}
}

/// Returns a sorted RecordBatch from in_mem_batches and clears in_mem_batches
Expand All @@ -407,6 +407,7 @@ impl PartialSortStream {
self.is_closed = true;
}
}
debug_assert!(result.num_rows() > 0);
Ok(result)
}

Expand Down Expand Up @@ -731,7 +732,7 @@ mod tests {
let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
assert_eq!(
result.iter().map(|r| r.num_rows()).collect_vec(),
[0, 125, 125, 0, 150]
[125, 125, 150]
);

assert_eq!(
Expand Down Expand Up @@ -760,10 +761,10 @@ mod tests {
nulls_first: false,
};
for (fetch_size, expected_batch_num_rows) in [
(Some(50), vec![0, 50]),
(Some(120), vec![0, 120]),
(Some(150), vec![0, 125, 25]),
(Some(250), vec![0, 125, 125]),
(Some(50), vec![50]),
(Some(120), vec![120]),
(Some(150), vec![125, 25]),
(Some(250), vec![125, 125]),
] {
let partial_sort_executor = PartialSortExec::new(
LexOrdering::new(vec![
Expand Down
36 changes: 20 additions & 16 deletions datafusion/physical-plan/src/topk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,21 +200,22 @@ impl TopK {
} = self;
let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop

let mut batch = heap.emit()?;
metrics.baseline.output_rows().add(batch.num_rows());

// break into record batches as needed
let mut batches = vec![];
loop {
if batch.num_rows() <= batch_size {
batches.push(Ok(batch));
break;
} else {
batches.push(Ok(batch.slice(0, batch_size)));
let remaining_length = batch.num_rows() - batch_size;
batch = batch.slice(batch_size, remaining_length);
if let Some(mut batch) = heap.emit()? {
metrics.baseline.output_rows().add(batch.num_rows());

loop {
if batch.num_rows() <= batch_size {
batches.push(Ok(batch));
break;
} else {
batches.push(Ok(batch.slice(0, batch_size)));
let remaining_length = batch.num_rows() - batch_size;
batch = batch.slice(batch_size, remaining_length);
}
}
}
};
Ok(Box::pin(RecordBatchStreamAdapter::new(
schema,
futures::stream::iter(batches),
Expand Down Expand Up @@ -345,21 +346,21 @@ impl TopKHeap {

/// Returns the values stored in this heap, from values low to
/// high, as a single [`RecordBatch`], resetting the inner heap
pub fn emit(&mut self) -> Result<RecordBatch> {
pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
Ok(self.emit_with_state()?.0)
}

/// Returns the values stored in this heap, from values low to
/// high, as a single [`RecordBatch`], and a sorted vec of the
/// current heap's contents
pub fn emit_with_state(&mut self) -> Result<(RecordBatch, Vec<TopKRow>)> {
pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, Vec<TopKRow>)> {
let schema = Arc::clone(self.store.schema());

// generate sorted rows
let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();

if self.store.is_empty() {
return Ok((RecordBatch::new_empty(schema), topk_rows));
return Ok((None, topk_rows));
}

// Indices for each row within its respective RecordBatch
Expand Down Expand Up @@ -393,7 +394,7 @@ impl TopKHeap {
.collect::<Result<_>>()?;

let new_batch = RecordBatch::try_new(schema, output_columns)?;
Ok((new_batch, topk_rows))
Ok((Some(new_batch), topk_rows))
}

/// Compact this heap, rewriting all stored batches into a single
Expand All @@ -418,6 +419,9 @@ impl TopKHeap {
// Note: new batch is in the same order as inner
let num_rows = self.inner.len();
let (new_batch, mut topk_rows) = self.emit_with_state()?;
let Some(new_batch) = new_batch else {
return Ok(());
};

// clear all old entries in store (this invalidates all
// store_ids in `inner`)
Expand Down
Loading
Loading