Skip to content

Commit

Permalink
Use partial aggregation schema for spilling to avoid column mismatch …
Browse files Browse the repository at this point in the history
…in GroupedHashAggregateStream (#13995)

* Refactor spill handling in GroupedHashAggregateStream to use partial aggregate schema

* Implement aggregate functions with spill handling in tests

* Add tests for aggregate functions with and without spill handling

* Move test related imports into mod test

* Rename spill pool test functions for clarity and consistency

* Refactor aggregate function imports to use fully qualified paths

* Remove outdated comments regarding input batch schema for spilling in GroupedHashAggregateStream

* Update aggregate test to use AVG instead of MAX

* assert spill count

* Refactor partial aggregate schema creation to use create_schema function

* Refactor partial aggregation schema creation and remove redundant function

* Remove unused import of Schema from arrow::datatypes in row_hash.rs

* move spill pool testing for aggregate functions to physical-plan/src/aggregates

* Use Arc::clone for schema references in aggregate functions
  • Loading branch information
kosiew authored Jan 8, 2025
1 parent b5fe4a9 commit 81b50c4
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 11 deletions.
9 changes: 4 additions & 5 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@
#[cfg(feature = "parquet")]
mod parquet;

use std::any::Any;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;

use crate::arrow::record_batch::RecordBatch;
use crate::arrow::util::pretty;
use crate::datasource::file_format::csv::CsvFormatFactory;
Expand All @@ -43,6 +38,10 @@ use crate::physical_plan::{
ExecutionPlan, SendableRecordBatchStream,
};
use crate::prelude::SessionContext;
use std::any::Any;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
Expand Down
134 changes: 134 additions & 0 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,7 @@ mod tests {
use crate::execution_plan::Boundedness;
use crate::expressions::col;
use crate::memory::MemoryExec;
use crate::metrics::MetricValue;
use crate::test::assert_is_pending;
use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
use crate::RecordBatchStream;
Expand Down Expand Up @@ -2783,4 +2784,137 @@ mod tests {
assert_eq!(aggr_schema, expected_schema);
Ok(())
}

// test for https://github.com/apache/datafusion/issues/13949
async fn run_test_with_spill_pool_if_necessary(
pool_size: usize,
expect_spill: bool,
) -> Result<()> {
fn create_record_batch(
schema: &Arc<Schema>,
data: (Vec<u32>, Vec<f64>),
) -> Result<RecordBatch> {
Ok(RecordBatch::try_new(
Arc::clone(schema),
vec![
Arc::new(UInt32Array::from(data.0)),
Arc::new(Float64Array::from(data.1)),
],
)?)
}

let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::Float64, false),
]));

let batches = vec![
create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
];
let plan: Arc<dyn ExecutionPlan> =
Arc::new(MemoryExec::try_new(&[batches], Arc::clone(&schema), None)?);

let grouping_set = PhysicalGroupBy::new(
vec![(col("a", &schema)?, "a".to_string())],
vec![],
vec![vec![false]],
);

// Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
Arc::new(
AggregateExprBuilder::new(
datafusion_functions_aggregate::min_max::min_udaf(),
vec![col("b", &schema)?],
)
.schema(Arc::clone(&schema))
.alias("MIN(b)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
.schema(Arc::clone(&schema))
.alias("AVG(b)")
.build()?,
),
];

let single_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Single,
grouping_set,
aggregates,
vec![None, None],
plan,
Arc::clone(&schema),
)?);

let batch_size = 2;
let memory_pool = Arc::new(FairSpillPool::new(pool_size));
let task_ctx = Arc::new(
TaskContext::default()
.with_session_config(SessionConfig::new().with_batch_size(batch_size))
.with_runtime(Arc::new(
RuntimeEnvBuilder::new()
.with_memory_pool(memory_pool)
.build()?,
)),
);

let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

assert_spill_count_metric(expect_spill, single_aggregate);

#[rustfmt::skip]
assert_batches_sorted_eq!(
[
"+---+--------+--------+",
"| a | MIN(b) | AVG(b) |",
"+---+--------+--------+",
"| 2 | 1.0 | 1.0 |",
"| 3 | 2.0 | 2.0 |",
"| 4 | 3.0 | 3.5 |",
"+---+--------+--------+",
],
&result
);

Ok(())
}

fn assert_spill_count_metric(
expect_spill: bool,
single_aggregate: Arc<AggregateExec>,
) {
if let Some(metrics_set) = single_aggregate.metrics() {
let mut spill_count = 0;

// Inspect metrics for SpillCount
for metric in metrics_set.iter() {
if let MetricValue::SpillCount(count) = metric.value() {
spill_count = count.value();
break;
}
}

if expect_spill && spill_count == 0 {
panic!(
"Expected spill but SpillCount metric not found or SpillCount was 0."
);
} else if !expect_spill && spill_count > 0 {
panic!("Expected no spill but found SpillCount metric with value greater than 0.");
}
} else {
panic!("No metrics returned from the operator; cannot verify spilling.");
}
}

#[tokio::test]
async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
// test with spill
run_test_with_spill_pool_if_necessary(2_000, true).await?;
// test without spill
run_test_with_spill_pool_if_necessary(20_000, false).await?;
Ok(())
}
}
34 changes: 28 additions & 6 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ use std::vec;
use crate::aggregates::group_values::{new_group_values, GroupValues};
use crate::aggregates::order::GroupOrderingFull;
use crate::aggregates::{
evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode,
PhysicalGroupBy,
create_schema, evaluate_group_by, evaluate_many, evaluate_optional, group_schema,
AggregateMode, PhysicalGroupBy,
};
use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
use crate::sorts::sort::sort_batch;
Expand Down Expand Up @@ -490,6 +490,31 @@ impl GroupedHashAggregateStream {
.collect::<Result<_>>()?;

let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?;

// fix https://github.com/apache/datafusion/issues/13949
// Builds a **partial aggregation** schema by combining the group columns and
// the accumulator state columns produced by each aggregate expression.
//
// # Why Partial Aggregation Schema Is Needed
//
// In a multi-stage (partial/final) aggregation strategy, each partial-aggregate
// operator produces *intermediate* states (e.g., partial sums, counts) rather
// than final scalar values. These extra columns do **not** exist in the original
// input schema (which may be something like `[colA, colB, ...]`). Instead,
// each aggregator adds its own internal state columns (e.g., `[acc_state_1, acc_state_2, ...]`).
//
// Therefore, when we spill these intermediate states or pass them to another
// aggregation operator, we must use a schema that includes both the group
// columns **and** the partial-state columns.
let partial_agg_schema = create_schema(
&agg.input().schema(),
&agg_group_by,
&aggregate_exprs,
AggregateMode::Partial,
)?;

let partial_agg_schema = Arc::new(partial_agg_schema);

let spill_expr = group_schema
.fields
.into_iter()
Expand Down Expand Up @@ -522,7 +547,7 @@ impl GroupedHashAggregateStream {
let spill_state = SpillState {
spills: vec![],
spill_expr,
spill_schema: Arc::clone(&agg_schema),
spill_schema: partial_agg_schema,
is_stream_merging: false,
merging_aggregate_arguments,
merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
Expand Down Expand Up @@ -964,9 +989,6 @@ impl GroupedHashAggregateStream {
&& self.update_memory_reservation().is_err()
{
assert_ne!(self.mode, AggregateMode::Partial);
// Use input batch (Partial mode) schema for spilling because
// the spilled data will be merged and re-evaluated later.
self.spill_state.spill_schema = batch.schema();
self.spill()?;
self.clear_shrink(batch);
}
Expand Down

0 comments on commit 81b50c4

Please sign in to comment.