Skip to content

Commit

Permalink
Chore: Do not return empty record batches from streams (apache#13794)
Browse files Browse the repository at this point in the history
* do not emit empty record batches in plans

* change function signatures to Option<RecordBatch> if empty batches are possible

* format code

* shorten code

* change list_unnest_at_level for returning Option value

* add documentation
take concat_batches into compute_aggregates function again

* create unit test for row_hash.rs

* add test for unnest

* add test for unnest

* add test for partial sort

* add test for bounded window agg

* add test for window agg

* apply simplifications and fix typo

* apply simplifications and fix typo
  • Loading branch information
mertak-synnada authored Dec 18, 2024
1 parent 7e0fc14 commit 63ce486
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 111 deletions.
24 changes: 24 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2380,6 +2380,30 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn aggregate_assert_no_empty_batches() -> Result<()> {
// build plan using DataFrame API
let df = test_table().await?;
let group_expr = vec![col("c1")];
let aggr_expr = vec![
min(col("c12")),
max(col("c12")),
avg(col("c12")),
sum(col("c12")),
count(col("c12")),
count_distinct(col("c12")),
median(col("c12")),
];

let df: Vec<RecordBatch> = df.aggregate(group_expr, aggr_expr)?.collect().await?;
// Empty batches should not be produced
for batch in df {
assert!(batch.num_rows() > 0);
}

Ok(())
}

#[tokio::test]
async fn test_aggregate_with_pk() -> Result<()> {
// create the dataframe
Expand Down
43 changes: 43 additions & 0 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1246,6 +1246,43 @@ async fn unnest_aggregate_columns() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn unnest_no_empty_batches() -> Result<()> {
let mut shape_id_builder = UInt32Builder::new();
let mut tag_id_builder = UInt32Builder::new();

for shape_id in 1..=10 {
for tag_id in 1..=10 {
shape_id_builder.append_value(shape_id as u32);
tag_id_builder.append_value((shape_id * 10 + tag_id) as u32);
}
}

let batch = RecordBatch::try_from_iter(vec![
("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef),
("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef),
])?;

let ctx = SessionContext::new();
ctx.register_batch("shapes", batch)?;
let df = ctx.table("shapes").await?;

let results = df
.clone()
.aggregate(
vec![col("shape_id")],
vec![array_agg(col("tag_id")).alias("tag_id")],
)?
.collect()
.await?;

// Assert that there are no empty batches in result
for rb in results {
assert!(rb.num_rows() > 0);
}
Ok(())
}

#[tokio::test]
async fn unnest_array_agg() -> Result<()> {
let mut shape_id_builder = UInt32Builder::new();
Expand All @@ -1268,6 +1305,12 @@ async fn unnest_array_agg() -> Result<()> {
let df = ctx.table("shapes").await?;

let results = df.clone().collect().await?;

// Assert that there are no empty batches in result
for rb in results.clone() {
assert!(rb.num_rows() > 0);
}

let expected = vec![
"+----------+--------+",
"| shape_id | tag_id |",
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
Linear,
)?);
let task_ctx = ctx.task_ctx();
let mut collected_results =
collect(running_window_exec, task_ctx).await?;
collected_results.retain(|batch| batch.num_rows() > 0);
let collected_results = collect(running_window_exec, task_ctx).await?;
let input_batch_sizes = batches
.iter()
.map(|batch| batch.num_rows())
Expand All @@ -310,6 +308,8 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
.iter()
.map(|batch| batch.num_rows())
.collect::<Vec<_>>();
// There should be no empty batches at results
assert!(result_batch_sizes.iter().all(|e| *e > 0));
if causal {
// For causal window frames, we can generate results immediately
// for each input batch. Hence, batch sizes should match.
Expand Down Expand Up @@ -688,8 +688,8 @@ async fn run_window_test(
let collected_running = collect(running_window_exec, task_ctx)
.await?
.into_iter()
.filter(|b| b.num_rows() > 0)
.collect::<Vec<_>>();
assert!(collected_running.iter().all(|rb| rb.num_rows() > 0));

// BoundedWindowAggExec should produce more chunk than the usual WindowAggExec.
// Otherwise it means that we cannot generate result in running mode.
Expand Down
46 changes: 32 additions & 14 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,9 +654,13 @@ impl Stream for GroupedHashAggregateStream {
}

if let Some(to_emit) = self.group_ordering.emit_to() {
let batch = extract_ok!(self.emit(to_emit, false));
self.exec_state = ExecutionState::ProducingOutput(batch);
timer.done();
if let Some(batch) =
extract_ok!(self.emit(to_emit, false))
{
self.exec_state =
ExecutionState::ProducingOutput(batch);
};
// make sure the exec_state just set is not overwritten below
break 'reading_input;
}
Expand Down Expand Up @@ -693,9 +697,13 @@ impl Stream for GroupedHashAggregateStream {
}

if let Some(to_emit) = self.group_ordering.emit_to() {
let batch = extract_ok!(self.emit(to_emit, false));
self.exec_state = ExecutionState::ProducingOutput(batch);
timer.done();
if let Some(batch) =
extract_ok!(self.emit(to_emit, false))
{
self.exec_state =
ExecutionState::ProducingOutput(batch);
};
// make sure the exec_state just set is not overwritten below
break 'reading_input;
}
Expand Down Expand Up @@ -768,6 +776,9 @@ impl Stream for GroupedHashAggregateStream {
let output = batch.slice(0, size);
(ExecutionState::ProducingOutput(remaining), output)
};
// Empty record batches should not be emitted.
// They need to be treated as [`Option<RecordBatch>`]es and handled separately
debug_assert!(output_batch.num_rows() > 0);
return Poll::Ready(Some(Ok(
output_batch.record_output(&self.baseline_metrics)
)));
Expand Down Expand Up @@ -902,14 +913,14 @@ impl GroupedHashAggregateStream {

/// Create an output RecordBatch with the group keys and
/// accumulator states/values specified in emit_to
fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Option<RecordBatch>> {
let schema = if spilling {
Arc::clone(&self.spill_state.spill_schema)
} else {
self.schema()
};
if self.group_values.is_empty() {
return Ok(RecordBatch::new_empty(schema));
return Ok(None);
}

let mut output = self.group_values.emit(emit_to)?;
Expand Down Expand Up @@ -937,7 +948,8 @@ impl GroupedHashAggregateStream {
// over the target memory size after emission, we can emit again rather than returning Err.
let _ = self.update_memory_reservation();
let batch = RecordBatch::try_new(schema, output)?;
Ok(batch)
debug_assert!(batch.num_rows() > 0);
Ok(Some(batch))
}

/// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly
Expand All @@ -963,7 +975,9 @@ impl GroupedHashAggregateStream {

/// Emit all rows, sort them, and store them on disk.
fn spill(&mut self) -> Result<()> {
let emit = self.emit(EmitTo::All, true)?;
let Some(emit) = self.emit(EmitTo::All, true)? else {
return Ok(());
};
let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?;
let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
// TODO: slice large `sorted` and write to multiple files in parallel
Expand Down Expand Up @@ -1008,8 +1022,9 @@ impl GroupedHashAggregateStream {
{
assert_eq!(self.mode, AggregateMode::Partial);
let n = self.group_values.len() / self.batch_size * self.batch_size;
let batch = self.emit(EmitTo::First(n), false)?;
self.exec_state = ExecutionState::ProducingOutput(batch);
if let Some(batch) = self.emit(EmitTo::First(n), false)? {
self.exec_state = ExecutionState::ProducingOutput(batch);
};
}
Ok(())
}
Expand All @@ -1019,7 +1034,9 @@ impl GroupedHashAggregateStream {
/// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
/// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
fn update_merged_stream(&mut self) -> Result<()> {
let batch = self.emit(EmitTo::All, true)?;
let Some(batch) = self.emit(EmitTo::All, true)? else {
return Ok(());
};
// clear up memory for streaming_merge
self.clear_all();
self.update_memory_reservation()?;
Expand Down Expand Up @@ -1067,7 +1084,7 @@ impl GroupedHashAggregateStream {
let timer = elapsed_compute.timer();
self.exec_state = if self.spill_state.spills.is_empty() {
let batch = self.emit(EmitTo::All, false)?;
ExecutionState::ProducingOutput(batch)
batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
} else {
// If spill files exist, stream-merge them.
self.update_merged_stream()?;
Expand Down Expand Up @@ -1096,8 +1113,9 @@ impl GroupedHashAggregateStream {
fn switch_to_skip_aggregation(&mut self) -> Result<()> {
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
if probe.should_skip() {
let batch = self.emit(EmitTo::All, false)?;
self.exec_state = ExecutionState::ProducingOutput(batch);
if let Some(batch) = self.emit(EmitTo::All, false)? {
self.exec_state = ExecutionState::ProducingOutput(batch);
};
}
}

Expand Down
97 changes: 68 additions & 29 deletions datafusion/physical-plan/src/sorts/partial_sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,31 +363,31 @@ impl PartialSortStream {
if self.is_closed {
return Poll::Ready(None);
}
let result = match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
if let Some(slice_point) =
self.get_slice_point(self.common_prefix_length, &batch)?
{
self.in_mem_batches.push(batch.slice(0, slice_point));
let remaining_batch =
batch.slice(slice_point, batch.num_rows() - slice_point);
let sorted_batch = self.sort_in_mem_batches();
self.in_mem_batches.push(remaining_batch);
sorted_batch
} else {
self.in_mem_batches.push(batch);
Ok(RecordBatch::new_empty(self.schema()))
loop {
return Poll::Ready(Some(match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
if let Some(slice_point) =
self.get_slice_point(self.common_prefix_length, &batch)?
{
self.in_mem_batches.push(batch.slice(0, slice_point));
let remaining_batch =
batch.slice(slice_point, batch.num_rows() - slice_point);
let sorted_batch = self.sort_in_mem_batches();
self.in_mem_batches.push(remaining_batch);
sorted_batch
} else {
self.in_mem_batches.push(batch);
continue;
}
}
}
Some(Err(e)) => Err(e),
None => {
self.is_closed = true;
// once input is consumed, sort the rest of the inserted batches
self.sort_in_mem_batches()
}
};

Poll::Ready(Some(result))
Some(Err(e)) => Err(e),
None => {
self.is_closed = true;
// once input is consumed, sort the rest of the inserted batches
self.sort_in_mem_batches()
}
}));
}
}

/// Returns a sorted RecordBatch from in_mem_batches and clears in_mem_batches
Expand All @@ -407,6 +407,9 @@ impl PartialSortStream {
self.is_closed = true;
}
}
// Empty record batches should not be emitted.
// They need to be treated as [`Option<RecordBatch>`]es and handle separately
debug_assert!(result.num_rows() > 0);
Ok(result)
}

Expand Down Expand Up @@ -731,7 +734,7 @@ mod tests {
let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
assert_eq!(
result.iter().map(|r| r.num_rows()).collect_vec(),
[0, 125, 125, 0, 150]
[125, 125, 150]
);

assert_eq!(
Expand Down Expand Up @@ -760,10 +763,10 @@ mod tests {
nulls_first: false,
};
for (fetch_size, expected_batch_num_rows) in [
(Some(50), vec![0, 50]),
(Some(120), vec![0, 120]),
(Some(150), vec![0, 125, 25]),
(Some(250), vec![0, 125, 125]),
(Some(50), vec![50]),
(Some(120), vec![120]),
(Some(150), vec![125, 25]),
(Some(250), vec![125, 125]),
] {
let partial_sort_executor = PartialSortExec::new(
LexOrdering::new(vec![
Expand Down Expand Up @@ -810,6 +813,42 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_partial_sort_no_empty_batches() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let mem_exec = prepare_partitioned_input();
let schema = mem_exec.schema();
let option_asc = SortOptions {
descending: false,
nulls_first: false,
};
let fetch_size = Some(250);
let partial_sort_executor = PartialSortExec::new(
LexOrdering::new(vec![
PhysicalSortExpr {
expr: col("a", &schema)?,
options: option_asc,
},
PhysicalSortExpr {
expr: col("c", &schema)?,
options: option_asc,
},
]),
Arc::clone(&mem_exec),
1,
)
.with_fetch(fetch_size);

let partial_sort_exec =
Arc::new(partial_sort_executor.clone()) as Arc<dyn ExecutionPlan>;
let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
for rb in result {
assert!(rb.num_rows() > 0);
}

Ok(())
}

#[tokio::test]
async fn test_sort_metadata() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
Expand Down
Loading

0 comments on commit 63ce486

Please sign in to comment.