diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index 312a3263aa81..6d7c3c21bc2f 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -295,6 +295,8 @@ impl AggregateExec { self.aggr_expr.clone(), input, baseline_metrics, + context, + partition, )?)) } else if self.row_aggregate_supported() { Ok(StreamType::GroupedHashAggregateStreamV2( @@ -737,7 +739,7 @@ mod tests { use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result, ScalarValue}; - use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count}; + use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median}; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; use futures::{FutureExt, Stream}; use std::any::Any; @@ -1131,12 +1133,20 @@ mod tests { ); let task_ctx = session_ctx.task_ctx(); - let groups = PhysicalGroupBy { + let groups_none = PhysicalGroupBy::default(); + let groups_some = PhysicalGroupBy { expr: vec![(col("a", &input_schema)?, "a".to_string())], null_expr: vec![], groups: vec![vec![false]], }; + // something that allocates within the aggregator + let aggregates_v0: Vec> = vec![Arc::new(Median::new( + col("a", &input_schema)?, + "MEDIAN(a)".to_string(), + DataType::UInt32, + ))]; + // use slow-path in `hash.rs` let aggregates_v1: Vec> = vec![Arc::new(ApproxDistinct::new( @@ -1152,10 +1162,14 @@ mod tests { DataType::Float64, ))]; - for (version, aggregates) in [(1, aggregates_v1), (2, aggregates_v2)] { + for (version, groups, aggregates) in [ + (0, groups_none, aggregates_v0), + (1, groups_some.clone(), aggregates_v1), + (2, groups_some, aggregates_v2), + ] { let partial_aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Partial, - groups.clone(), + groups, aggregates, input.clone(), input_schema.clone(), @@ 
-1165,6 +1179,9 @@ mod tests { // ensure that we really got the version we wanted match version { + 0 => { + assert!(matches!(stream, StreamType::AggregateStream(_))); + } 1 => { assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_))); } diff --git a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs index f687c982c220..8c3556bb6f21 100644 --- a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs +++ b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs @@ -17,6 +17,9 @@ //! Aggregate without grouping columns +use crate::execution::context::TaskContext; +use crate::execution::memory_manager::proxy::MemoryConsumerProxy; +use crate::execution::MemoryConsumerId; use crate::physical_plan::aggregates::{ aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem, AggregateMode, @@ -28,22 +31,31 @@ use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; +use futures::stream::BoxStream; use std::sync::Arc; use std::task::{Context, Poll}; -use futures::{ - ready, - stream::{Stream, StreamExt}, -}; +use futures::stream::{Stream, StreamExt}; /// stream struct for aggregation without grouping columns pub(crate) struct AggregateStream { + stream: BoxStream<'static, ArrowResult>, + schema: SchemaRef, +} + +/// Actual implementation of [`AggregateStream`]. +/// +/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem +/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with +/// [`futures::stream::unfold`]. The latter requires a state object, which is [`AggregateStreamInner`]. 
+struct AggregateStreamInner { schema: SchemaRef, mode: AggregateMode, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, aggregate_expressions: Vec>>, accumulators: Vec, + memory_consumer: MemoryConsumerProxy, finished: bool, } @@ -55,19 +67,87 @@ impl AggregateStream { aggr_expr: Vec>, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, + context: Arc, + partition: usize, ) -> datafusion_common::Result { let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode, 0)?; let accumulators = create_accumulators(&aggr_expr)?; - - Ok(Self { - schema, + let memory_consumer = MemoryConsumerProxy::new( + "AggregationState", + MemoryConsumerId::new(partition), + Arc::clone(&context.runtime_env().memory_manager), + ); + + let inner = AggregateStreamInner { + schema: Arc::clone(&schema), mode, input, baseline_metrics, aggregate_expressions, accumulators, + memory_consumer, finished: false, - }) + }; + let stream = futures::stream::unfold(inner, |mut this| async move { + if this.finished { + return None; + } + + let elapsed_compute = this.baseline_metrics.elapsed_compute(); + + loop { + let result = match this.input.next().await { + Some(Ok(batch)) => { + let timer = elapsed_compute.timer(); + let result = aggregate_batch( + &this.mode, + &batch, + &mut this.accumulators, + &this.aggregate_expressions, + ); + + timer.done(); + + // allocate memory + // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with + // overshooting a bit. Also this means we either store the whole record batch or not. 
+ let result = match result { + Ok(allocated) => this.memory_consumer.alloc(allocated).await, + Err(e) => Err(e), + }; + + match result { + Ok(_) => continue, + Err(e) => Err(ArrowError::ExternalError(Box::new(e))), + } + } + Some(Err(e)) => Err(e), + None => { + this.finished = true; + let timer = this.baseline_metrics.elapsed_compute().timer(); + let result = finalize_aggregation(&this.accumulators, &this.mode) + .map_err(|e| ArrowError::ExternalError(Box::new(e))) + .and_then(|columns| { + RecordBatch::try_new(this.schema.clone(), columns) + }) + .record_output(&this.baseline_metrics); + + timer.done(); + + result + } + }; + + this.finished = true; + return Some((result, this)); + } + }); + + // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream. + let stream = stream.fuse(); + let stream = Box::pin(stream); + + Ok(Self { schema, stream }) } } @@ -79,49 +159,7 @@ impl Stream for AggregateStream { cx: &mut Context<'_>, ) -> Poll> { let this = &mut *self; - if this.finished { - return Poll::Ready(None); - } - - let elapsed_compute = this.baseline_metrics.elapsed_compute(); - - loop { - let result = match ready!(this.input.poll_next_unpin(cx)) { - Some(Ok(batch)) => { - let timer = elapsed_compute.timer(); - let result = aggregate_batch( - &this.mode, - &batch, - &mut this.accumulators, - &this.aggregate_expressions, - ); - - timer.done(); - - match result { - Ok(_) => continue, - Err(e) => Err(ArrowError::ExternalError(Box::new(e))), - } - } - Some(Err(e)) => Err(e), - None => { - this.finished = true; - let timer = this.baseline_metrics.elapsed_compute().timer(); - let result = finalize_aggregation(&this.accumulators, &this.mode) - .map_err(|e| ArrowError::ExternalError(Box::new(e))) - .and_then(|columns| { - RecordBatch::try_new(this.schema.clone(), columns) - }) - .record_output(&this.baseline_metrics); - - timer.done(); - result - } - }; - - this.finished = true; - return Poll::Ready(Some(result)); - } + 
this.stream.poll_next_unpin(cx) } } @@ -131,13 +169,19 @@ impl RecordBatchStream for AggregateStream { } } +/// Perform aggregation without grouping columns for the given [`RecordBatch`]. +/// +/// If successful, this returns the additional number of bytes that were allocated during this process. +/// /// TODO: Make this a member function fn aggregate_batch( mode: &AggregateMode, batch: &RecordBatch, accumulators: &mut [AccumulatorItem], expressions: &[Vec>], -) -> Result<()> { +) -> Result { + let mut allocated = 0usize; + // 1.1 iterate accumulators and respective expressions together // 1.2 evaluate expressions // 1.3 update / merge accumulators with the expressions' values @@ -155,11 +199,17 @@ fn aggregate_batch( .collect::>>()?; // 1.3 - match mode { + let size_pre = accum.size(); + let res = match mode { AggregateMode::Partial => accum.update_batch(values), AggregateMode::Final | AggregateMode::FinalPartitioned => { accum.merge_batch(values) } - } - }) + }; + let size_post = accum.size(); + allocated += size_post.saturating_sub(size_pre); + res + })?; + + Ok(allocated) }