Skip to content

Commit

Permalink
feat: ResourceExhausted for memory limit in AggregateStream
Browse files Browse the repository at this point in the history
Closes apache#3940.
  • Loading branch information
crepererum committed Nov 28, 2022
1 parent dafd957 commit 2ff37a4
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 59 deletions.
25 changes: 21 additions & 4 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ impl AggregateExec {
self.aggr_expr.clone(),
input,
baseline_metrics,
context,
partition,
)?))
} else if self.row_aggregate_supported() {
Ok(StreamType::GroupedHashAggregateStreamV2(
Expand Down Expand Up @@ -737,7 +739,7 @@ mod tests {
use arrow::error::{ArrowError, Result as ArrowResult};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count};
use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median};
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
use futures::{FutureExt, Stream};
use std::any::Any;
Expand Down Expand Up @@ -1131,12 +1133,20 @@ mod tests {
);
let task_ctx = session_ctx.task_ctx();

let groups = PhysicalGroupBy {
let groups_none = PhysicalGroupBy::default();
let groups_some = PhysicalGroupBy {
expr: vec![(col("a", &input_schema)?, "a".to_string())],
null_expr: vec![],
groups: vec![vec![false]],
};

// something that allocates within the aggregator
let aggregates_v0: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Median::new(
col("a", &input_schema)?,
"MEDIAN(a)".to_string(),
DataType::UInt32,
))];

// use slow-path in `hash.rs`
let aggregates_v1: Vec<Arc<dyn AggregateExpr>> =
vec![Arc::new(ApproxDistinct::new(
Expand All @@ -1152,10 +1162,14 @@ mod tests {
DataType::Float64,
))];

for (version, aggregates) in [(1, aggregates_v1), (2, aggregates_v2)] {
for (version, groups, aggregates) in [
(0, groups_none, aggregates_v0),
(1, groups_some.clone(), aggregates_v1),
(2, groups_some, aggregates_v2),
] {
let partial_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
groups.clone(),
groups,
aggregates,
input.clone(),
input_schema.clone(),
Expand All @@ -1165,6 +1179,9 @@ mod tests {

// ensure that we really got the version we wanted
match version {
0 => {
assert!(matches!(stream, StreamType::AggregateStream(_)));
}
1 => {
assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
}
Expand Down
160 changes: 105 additions & 55 deletions datafusion/core/src/physical_plan/aggregates/no_grouping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

//! Aggregate without grouping columns
use crate::execution::context::TaskContext;
use crate::execution::memory_manager::proxy::MemoryConsumerProxy;
use crate::execution::MemoryConsumerId;
use crate::physical_plan::aggregates::{
aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem,
AggregateMode,
Expand All @@ -28,22 +31,31 @@ use arrow::error::{ArrowError, Result as ArrowResult};
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
use futures::stream::BoxStream;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures::{
ready,
stream::{Stream, StreamExt},
};
use futures::stream::{Stream, StreamExt};

/// stream struct for aggregation without grouping columns
pub(crate) struct AggregateStream {
stream: BoxStream<'static, ArrowResult<RecordBatch>>,
schema: SchemaRef,
}

/// Actual implementation of [`AggregateStream`].
///
/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
/// [`futures::stream::unfold`]. The latter requires a state object, which is [`GroupedHashAggregateStreamV2Inner`].
struct AggregateStreamInner {
schema: SchemaRef,
mode: AggregateMode,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
accumulators: Vec<AccumulatorItem>,
memory_consumer: MemoryConsumerProxy,
finished: bool,
}

Expand All @@ -55,19 +67,87 @@ impl AggregateStream {
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
context: Arc<TaskContext>,
partition: usize,
) -> datafusion_common::Result<Self> {
let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode, 0)?;
let accumulators = create_accumulators(&aggr_expr)?;

Ok(Self {
schema,
let memory_consumer = MemoryConsumerProxy::new(
"AggregationState",
MemoryConsumerId::new(partition),
Arc::clone(&context.runtime_env().memory_manager),
);

let inner = AggregateStreamInner {
schema: Arc::clone(&schema),
mode,
input,
baseline_metrics,
aggregate_expressions,
accumulators,
memory_consumer,
finished: false,
})
};
let stream = futures::stream::unfold(inner, |mut this| async move {
if this.finished {
return None;
}

let elapsed_compute = this.baseline_metrics.elapsed_compute();

loop {
let result = match this.input.next().await {
Some(Ok(batch)) => {
let timer = elapsed_compute.timer();
let result = aggregate_batch(
&this.mode,
&batch,
&mut this.accumulators,
&this.aggregate_expressions,
);

timer.done();

// allocate memory
// This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
// overshooting a bit. Also this means we either store the whole record batch or not.
let result = match result {
Ok(allocated) => this.memory_consumer.alloc(allocated).await,
Err(e) => Err(e),
};

match result {
Ok(_) => continue,
Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
}
}
Some(Err(e)) => Err(e),
None => {
this.finished = true;
let timer = this.baseline_metrics.elapsed_compute().timer();
let result = finalize_aggregation(&this.accumulators, &this.mode)
.map_err(|e| ArrowError::ExternalError(Box::new(e)))
.and_then(|columns| {
RecordBatch::try_new(this.schema.clone(), columns)
})
.record_output(&this.baseline_metrics);

timer.done();

result
}
};

this.finished = true;
return Some((result, this));
}
});

// seems like some consumers call this stream even after it returned `None`, so let's fuse the stream.
let stream = stream.fuse();
let stream = Box::pin(stream);

Ok(Self { schema, stream })
}
}

Expand All @@ -79,49 +159,7 @@ impl Stream for AggregateStream {
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = &mut *self;
if this.finished {
return Poll::Ready(None);
}

let elapsed_compute = this.baseline_metrics.elapsed_compute();

loop {
let result = match ready!(this.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
let timer = elapsed_compute.timer();
let result = aggregate_batch(
&this.mode,
&batch,
&mut this.accumulators,
&this.aggregate_expressions,
);

timer.done();

match result {
Ok(_) => continue,
Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
}
}
Some(Err(e)) => Err(e),
None => {
this.finished = true;
let timer = this.baseline_metrics.elapsed_compute().timer();
let result = finalize_aggregation(&this.accumulators, &this.mode)
.map_err(|e| ArrowError::ExternalError(Box::new(e)))
.and_then(|columns| {
RecordBatch::try_new(this.schema.clone(), columns)
})
.record_output(&this.baseline_metrics);

timer.done();
result
}
};

this.finished = true;
return Poll::Ready(Some(result));
}
this.stream.poll_next_unpin(cx)
}
}

Expand All @@ -131,13 +169,19 @@ impl RecordBatchStream for AggregateStream {
}
}

/// Perform group-by aggregation for the given [`RecordBatch`].
///
/// If successfull, this returns the additional number of bytes that were allocated during this process.
///
/// TODO: Make this a member function
fn aggregate_batch(
mode: &AggregateMode,
batch: &RecordBatch,
accumulators: &mut [AccumulatorItem],
expressions: &[Vec<Arc<dyn PhysicalExpr>>],
) -> Result<()> {
) -> Result<usize> {
let mut allocated = 0usize;

// 1.1 iterate accumulators and respective expressions together
// 1.2 evaluate expressions
// 1.3 update / merge accumulators with the expressions' values
Expand All @@ -155,11 +199,17 @@ fn aggregate_batch(
.collect::<Result<Vec<_>>>()?;

// 1.3
match mode {
let size_pre = accum.size();
let res = match mode {
AggregateMode::Partial => accum.update_batch(values),
AggregateMode::Final | AggregateMode::FinalPartitioned => {
accum.merge_batch(values)
}
}
})
};
let size_post = accum.size();
allocated += size_post.saturating_sub(size_pre);
res
})?;

Ok(allocated)
}

0 comments on commit 2ff37a4

Please sign in to comment.