Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/no_grouping.rs
Line | Count | Source
-----+-------+--------------------------------------------------------------------------------------------------------------------
   1 |       | // Licensed to the Apache Software Foundation (ASF) under one
   2 |       | // or more contributor license agreements.  See the NOTICE file
   3 |       | // distributed with this work for additional information
   4 |       | // regarding copyright ownership.  The ASF licenses this file
   5 |       | // to you under the Apache License, Version 2.0 (the
   6 |       | // "License"); you may not use this file except in compliance
   7 |       | // with the License.  You may obtain a copy of the License at
   8 |       | //
   9 |       | //   http://www.apache.org/licenses/LICENSE-2.0
  10 |       | //
  11 |       | // Unless required by applicable law or agreed to in writing,
  12 |       | // software distributed under the License is distributed on an
  13 |       | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14 |       | // KIND, either express or implied.  See the License for the
  15 |       | // specific language governing permissions and limitations
  16 |       | // under the License.
  17 |       |
  18 |       | //! Aggregate without grouping columns
  19 |       |
  20 |       | use crate::aggregates::{
  21 |       |     aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem,
  22 |       |     AggregateMode,
  23 |       | };
  24 |       | use crate::metrics::{BaselineMetrics, RecordOutput};
  25 |       | use crate::{RecordBatchStream, SendableRecordBatchStream};
  26 |       | use arrow::datatypes::SchemaRef;
  27 |       | use arrow::record_batch::RecordBatch;
  28 |       | use datafusion_common::Result;
  29 |       | use datafusion_execution::TaskContext;
  30 |       | use datafusion_physical_expr::PhysicalExpr;
  31 |       | use futures::stream::BoxStream;
  32 |       | use std::borrow::Cow;
  33 |       | use std::sync::Arc;
  34 |       | use std::task::{Context, Poll};
  35 |       |
  36 |       | use crate::filter::batch_filter;
  37 |       | use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
  38 |       | use futures::stream::{Stream, StreamExt};
  39 |       |
  40 |       | use super::AggregateExec;
  41 |       |
  42 |       | /// stream struct for aggregation without grouping columns
  43 |       | pub(crate) struct AggregateStream {
  44 |       |     stream: BoxStream<'static, Result<RecordBatch>>,
  45 |       |     schema: SchemaRef,
  46 |       | }
  47 |       |
  48 |       | /// Actual implementation of [`AggregateStream`].
  49 |       | ///
  50 |       | /// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
  51 |       | /// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
  52 |       | /// [`futures::stream::unfold`].
  53 |       | ///
  54 |       | /// The latter requires a state object, which is [`AggregateStreamInner`].
  55 |       | struct AggregateStreamInner {
  56 |       |     schema: SchemaRef,
  57 |       |     mode: AggregateMode,
  58 |       |     input: SendableRecordBatchStream,
  59 |       |     baseline_metrics: BaselineMetrics,
  60 |       |     aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
  61 |       |     filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
  62 |       |     accumulators: Vec<AccumulatorItem>,
  63 |       |     reservation: MemoryReservation,
  64 |       |     finished: bool,
  65 |       | }
  66 |       |
  67 |       | impl AggregateStream {
  68 |       |     /// Create a new AggregateStream
  69 |     2 |     pub fn new(
  70 |     2 |         agg: &AggregateExec,
  71 |     2 |         context: Arc<TaskContext>,
  72 |     2 |         partition: usize,
  73 |     2 |     ) -> Result<Self> {
  74 |     2 |         let agg_schema = Arc::clone(&agg.schema);
  75 |     2 |         let agg_filter_expr = agg.filter_expr.clone();
  76 |     2 |
  77 |     2 |         let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
  78 |     2 |         let input = agg.input.execute(partition, Arc::clone(&context))?;
  79 |       |
  80 |     2 |         let aggregate_expressions = aggregate_expressions(&agg.aggr_expr, &agg.mode, 0)?;
  81 |     2 |         let filter_expressions = match agg.mode {
  82 |       |             AggregateMode::Partial
  83 |       |             | AggregateMode::Single
  84 |     2 |             | AggregateMode::SinglePartitioned => agg_filter_expr,
  85 |       |             AggregateMode::Final | AggregateMode::FinalPartitioned => {
  86 |     0 |                 vec![None; agg.aggr_expr.len()]
  87 |       |             }
  88 |       |         };
  89 |     2 |         let accumulators = create_accumulators(&agg.aggr_expr)?;
  90 |       |
  91 |     2 |         let reservation = MemoryConsumer::new(format!("AggregateStream[{partition}]"))
  92 |     2 |             .register(context.memory_pool());
  93 |     2 |
  94 |     2 |         let inner = AggregateStreamInner {
  95 |     2 |             schema: Arc::clone(&agg.schema),
  96 |     2 |             mode: agg.mode,
  97 |     2 |             input,
  98 |     2 |             baseline_metrics,
  99 |     2 |             aggregate_expressions,
 100 |     2 |             filter_expressions,
 101 |     2 |             accumulators,
 102 |     2 |             reservation,
 103 |     2 |             finished: false,
 104 |     2 |         };
 105 |     2 |         let stream = futures::stream::unfold(inner, |mut this| async move {
 106 |     2 |             if this.finished {
 107 |     0 |                 return None;
 108 |     2 |             }
 109 |     2 |
 110 |     2 |             let elapsed_compute = this.baseline_metrics.elapsed_compute();
 111 |       |
 112 |       |             loop {
 113 |     2 |                 let result = match this.input.next().await {
 114 |     1 |                     Some(Ok(batch)) => {
 115 |     1 |                         let timer = elapsed_compute.timer();
 116 |     1 |                         let result = aggregate_batch(
 117 |     1 |                             &this.mode,
 118 |     1 |                             batch,
 119 |     1 |                             &mut this.accumulators,
 120 |     1 |                             &this.aggregate_expressions,
 121 |     1 |                             &this.filter_expressions,
 122 |     1 |                         );
 123 |     1 |
 124 |     1 |                         timer.done();
 125 |     1 |
 126 |     1 |                         // allocate memory
 127 |     1 |                         // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
 128 |     1 |                         // overshooting a bit. Also this means we either store the whole record batch or not.
 129 |     1 |                         match result
 130 |     1 |                             .and_then(|allocated| this.reservation.try_grow(allocated))
 131 |       |                         {
 132 |     0 |                             Ok(_) => continue,
 133 |     1 |                             Err(e) => Err(e),
 134 |       |                         }
 135 |       |                     }
 136 |     0 |                     Some(Err(e)) => Err(e),
 137 |       |                     None => {
 138 |     0 |                         this.finished = true;
 139 |     0 |                         let timer = this.baseline_metrics.elapsed_compute().timer();
 140 |     0 |                         let result =
 141 |     0 |                             finalize_aggregation(&mut this.accumulators, &this.mode)
 142 |     0 |                                 .and_then(|columns| {
 143 |     0 |                                     RecordBatch::try_new(
 144 |     0 |                                         Arc::clone(&this.schema),
 145 |     0 |                                         columns,
 146 |     0 |                                     )
 147 |     0 |                                     .map_err(Into::into)
 148 |     0 |                                 })
 149 |     0 |                                 .record_output(&this.baseline_metrics);
 150 |     0 |
 151 |     0 |                         timer.done();
 152 |     0 |
 153 |     0 |                         result
 154 |       |                     }
 155 |       |                 };
 156 |       |
 157 |     1 |                 this.finished = true;
 158 |     1 |                 return Some((result, this));
 159 |       |             }
 160 |     3 |         });
 161 |     2 |
 162 |     2 |         // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream.
 163 |     2 |         let stream = stream.fuse();
 164 |     2 |         let stream = Box::pin(stream);
 165 |     2 |
 166 |     2 |         Ok(Self {
 167 |     2 |             schema: agg_schema,
 168 |     2 |             stream,
 169 |     2 |         })
 170 |     2 |     }
 171 |       | }
 172 |       |
 173 |       | impl Stream for AggregateStream {
 174 |       |     type Item = Result<RecordBatch>;
 175 |       |
 176 |     3 |     fn poll_next(
 177 |     3 |         mut self: std::pin::Pin<&mut Self>,
 178 |     3 |         cx: &mut Context<'_>,
 179 |     3 |     ) -> Poll<Option<Self::Item>> {
 180 |     3 |         let this = &mut *self;
 181 |     3 |         this.stream.poll_next_unpin(cx)
 182 |     3 |     }
 183 |       | }
 184 |       |
 185 |       | impl RecordBatchStream for AggregateStream {
 186 |     0 |     fn schema(&self) -> SchemaRef {
 187 |     0 |         Arc::clone(&self.schema)
 188 |     0 |     }
 189 |       | }
 190 |       |
 191 |       | /// Perform group-by aggregation for the given [`RecordBatch`].
 192 |       | ///
 193 |       | /// If successful, this returns the additional number of bytes that were allocated during this process.
 194 |       | ///
 195 |       | /// TODO: Make this a member function
 196 |     1 | fn aggregate_batch(
 197 |     1 |     mode: &AggregateMode,
 198 |     1 |     batch: RecordBatch,
 199 |     1 |     accumulators: &mut [AccumulatorItem],
 200 |     1 |     expressions: &[Vec<Arc<dyn PhysicalExpr>>],
 201 |     1 |     filters: &[Option<Arc<dyn PhysicalExpr>>],
 202 |     1 | ) -> Result<usize> {
 203 |     1 |     let mut allocated = 0usize;
 204 |     1 |
 205 |     1 |     // 1.1 iterate accumulators and respective expressions together
 206 |     1 |     // 1.2 filter the batch if necessary
 207 |     1 |     // 1.3 evaluate expressions
 208 |     1 |     // 1.4 update / merge accumulators with the expressions' values
 209 |     1 |
 210 |     1 |     // 1.1
 211 |     1 |     accumulators
 212 |     1 |         .iter_mut()
 213 |     1 |         .zip(expressions)
 214 |     1 |         .zip(filters)
 215 |     1 |         .try_for_each(|((accum, expr), filter)| {
 216 |       |             // 1.2
 217 |     1 |             let batch = match filter {
 218 |     0 |                 Some(filter) => Cow::Owned(batch_filter(&batch, filter)?),
 219 |     1 |                 None => Cow::Borrowed(&batch),
 220 |       |             };
 221 |       |
 222 |       |             // 1.3
 223 |     1 |             let values = &expr
 224 |     1 |                 .iter()
 225 |     1 |                 .map(|e| {
 226 |     1 |                     e.evaluate(&batch)
 227 |     1 |                         .and_then(|v| v.into_array(batch.num_rows()))
 228 |     1 |                 })
 229 |     1 |                 .collect::<Result<Vec<_>>>()?;
 230 |       |
 231 |       |             // 1.4
 232 |     1 |             let size_pre = accum.size();
 233 |     1 |             let res = match mode {
 234 |       |                 AggregateMode::Partial
 235 |       |                 | AggregateMode::Single
 236 |     1 |                 | AggregateMode::SinglePartitioned => accum.update_batch(values),
 237 |       |                 AggregateMode::Final | AggregateMode::FinalPartitioned => {
 238 |     0 |                     accum.merge_batch(values)
 239 |       |                 }
 240 |       |             };
 241 |     1 |             let size_post = accum.size();
 242 |     1 |             allocated += size_post.saturating_sub(size_pre);
 243 |     1 |             res
 244 |     1 |         })?;
 245 |       |
 246 |     1 |     Ok(allocated)
 247 |     1 | }
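
The doc comment on `AggregateStreamInner` (lines 48-54) and the fuse note on line 162 describe the pattern this file relies on: drive a stateful, fallible computation with `futures::stream::unfold`, then `fuse()` and `Box::pin` the result so callers can keep polling after the stream has yielded `None`. The following is a minimal standalone sketch of that pattern, not DataFusion code; the `Inner` state, the `build_stream` function, and the summing logic are invented for illustration.

use futures::stream::{BoxStream, StreamExt};

// Illustrative state object, playing the role of `AggregateStreamInner`.
struct Inner {
    input: std::vec::IntoIter<u64>,
    sum: u64,
    finished: bool,
}

fn build_stream(inner: Inner) -> BoxStream<'static, Result<u64, String>> {
    // `unfold` threads the state through each step and yields one item per call.
    let stream = futures::stream::unfold(inner, |mut this| async move {
        if this.finished {
            // Once the final item has been produced, stop the unfold.
            return None;
        }
        // Drain the (synchronous, illustrative) input; the real code instead
        // awaits `this.input.next()` on a `SendableRecordBatchStream`.
        while let Some(v) = this.input.next() {
            this.sum += v;
        }
        this.finished = true;
        // Emit the single final result: exactly one item before `None`.
        Some((Ok(this.sum), this))
    });
    // Fuse so extra polls after completion stay harmless, then box/pin into a
    // type-erased `BoxStream`, just as `AggregateStream::new` does.
    Box::pin(stream.fuse())
}

fn main() {
    let inner = Inner {
        input: vec![1, 2, 3].into_iter(),
        sum: 0,
        finished: false,
    };
    let mut stream = build_stream(inner);
    futures::executor::block_on(async {
        while let Some(item) = stream.next().await {
            println!("{item:?}"); // prints `Ok(6)` once, then the loop ends
        }
    });
}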
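
The comment on lines 126-128 and the `size_pre`/`size_post` bookkeeping in `aggregate_batch` (lines 232-242) explain that memory is accounted for only after it has already been consumed. Below is a minimal sketch of that reservation pattern using the `datafusion_execution` memory-pool API seen in this file; the `GreedyMemoryPool` with a 1 MiB limit and the accumulator sizes are assumptions made up for illustration, where the real stream registers against `context.memory_pool()` (lines 91-92).

use std::sync::Arc;
use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

fn main() -> datafusion_common::Result<()> {
    // Stand-in pool with an arbitrary 1 MiB limit (assumption for this sketch).
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024 * 1024));
    let mut reservation = MemoryConsumer::new("AggregateStream[0]").register(&pool);

    // Made-up accumulator sizes standing in for `accum.size()` before and after
    // `update_batch`.
    let size_pre = 128usize;
    let size_post = 4096usize;
    let allocated = size_post.saturating_sub(size_pre);

    // Accounting happens after the memory was already used, accepting a brief
    // overshoot. If the pool cannot grant the growth, `try_grow` returns an
    // error and the operator fails cleanly instead of exceeding its budget.
    reservation.try_grow(allocated)?;
    println!("reserved {allocated} additional bytes");
    Ok(())
}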