/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/no_grouping.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Aggregate without grouping columns |
19 | | |
20 | | use crate::aggregates::{ |
21 | | aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem, |
22 | | AggregateMode, |
23 | | }; |
24 | | use crate::metrics::{BaselineMetrics, RecordOutput}; |
25 | | use crate::{RecordBatchStream, SendableRecordBatchStream}; |
26 | | use arrow::datatypes::SchemaRef; |
27 | | use arrow::record_batch::RecordBatch; |
28 | | use datafusion_common::Result; |
29 | | use datafusion_execution::TaskContext; |
30 | | use datafusion_physical_expr::PhysicalExpr; |
31 | | use futures::stream::BoxStream; |
32 | | use std::borrow::Cow; |
33 | | use std::sync::Arc; |
34 | | use std::task::{Context, Poll}; |
35 | | |
36 | | use crate::filter::batch_filter; |
37 | | use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; |
38 | | use futures::stream::{Stream, StreamExt}; |
39 | | |
40 | | use super::AggregateExec; |
41 | | |
42 | | /// stream struct for aggregation without grouping columns |
43 | | pub(crate) struct AggregateStream { |
44 | | stream: BoxStream<'static, Result<RecordBatch>>, |
45 | | schema: SchemaRef, |
46 | | } |
47 | | |
48 | | /// Actual implementation of [`AggregateStream`]. |
49 | | /// |
50 | | /// This is wrapped into yet another struct because we need to interact with the async memory management subsystem |
51 | | /// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with |
52 | | /// [`futures::stream::unfold`]. |
53 | | /// |
54 | | /// The latter requires a state object, which is [`AggregateStreamInner`]. |
55 | | struct AggregateStreamInner { |
56 | | schema: SchemaRef, |
57 | | mode: AggregateMode, |
58 | | input: SendableRecordBatchStream, |
59 | | baseline_metrics: BaselineMetrics, |
60 | | aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>, |
61 | | filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>, |
62 | | accumulators: Vec<AccumulatorItem>, |
63 | | reservation: MemoryReservation, |
64 | | finished: bool, |
65 | | } |
66 | | |
67 | | impl AggregateStream { |
68 | | /// Create a new AggregateStream |
69 | 2 | pub fn new( |
70 | 2 | agg: &AggregateExec, |
71 | 2 | context: Arc<TaskContext>, |
72 | 2 | partition: usize, |
73 | 2 | ) -> Result<Self> { |
74 | 2 | let agg_schema = Arc::clone(&agg.schema); |
75 | 2 | let agg_filter_expr = agg.filter_expr.clone(); |
76 | 2 | |
77 | 2 | let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); |
78 | 2 | let input = agg.input.execute(partition, Arc::clone(&context))?0 ; |
79 | | |
80 | 2 | let aggregate_expressions = aggregate_expressions(&agg.aggr_expr, &agg.mode, 0)?0 ; |
81 | 2 | let filter_expressions = match agg.mode { |
82 | | AggregateMode::Partial |
83 | | | AggregateMode::Single |
84 | 2 | | AggregateMode::SinglePartitioned => agg_filter_expr, |
85 | | AggregateMode::Final | AggregateMode::FinalPartitioned => { |
86 | 0 | vec![None; agg.aggr_expr.len()] |
87 | | } |
88 | | }; |
89 | 2 | let accumulators = create_accumulators(&agg.aggr_expr)?0 ; |
90 | | |
91 | 2 | let reservation = MemoryConsumer::new(format!("AggregateStream[{partition}]")) |
92 | 2 | .register(context.memory_pool()); |
93 | 2 | |
94 | 2 | let inner = AggregateStreamInner { |
95 | 2 | schema: Arc::clone(&agg.schema), |
96 | 2 | mode: agg.mode, |
97 | 2 | input, |
98 | 2 | baseline_metrics, |
99 | 2 | aggregate_expressions, |
100 | 2 | filter_expressions, |
101 | 2 | accumulators, |
102 | 2 | reservation, |
103 | 2 | finished: false, |
104 | 2 | }; |
105 | 2 | let stream = futures::stream::unfold(inner, |mut this| async move { |
106 | 2 | if this.finished { |
107 | 0 | return None; |
108 | 2 | } |
109 | 2 | |
110 | 2 | let elapsed_compute = this.baseline_metrics.elapsed_compute(); |
111 | | |
112 | | loop { |
113 | 2 | let result1 = match this.input.next().await1 { |
114 | 1 | Some(Ok(batch)) => { |
115 | 1 | let timer = elapsed_compute.timer(); |
116 | 1 | let result = aggregate_batch( |
117 | 1 | &this.mode, |
118 | 1 | batch, |
119 | 1 | &mut this.accumulators, |
120 | 1 | &this.aggregate_expressions, |
121 | 1 | &this.filter_expressions, |
122 | 1 | ); |
123 | 1 | |
124 | 1 | timer.done(); |
125 | 1 | |
126 | 1 | // allocate memory |
127 | 1 | // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with |
128 | 1 | // overshooting a bit. Also this means we either store the whole record batch or not. |
129 | 1 | match result |
130 | 1 | .and_then(|allocated| this.reservation.try_grow(allocated)) |
131 | | { |
132 | 0 | Ok(_) => continue, |
133 | 1 | Err(e) => Err(e), |
134 | | } |
135 | | } |
136 | 0 | Some(Err(e)) => Err(e), |
137 | | None => { |
138 | 0 | this.finished = true; |
139 | 0 | let timer = this.baseline_metrics.elapsed_compute().timer(); |
140 | 0 | let result = |
141 | 0 | finalize_aggregation(&mut this.accumulators, &this.mode) |
142 | 0 | .and_then(|columns| { |
143 | 0 | RecordBatch::try_new( |
144 | 0 | Arc::clone(&this.schema), |
145 | 0 | columns, |
146 | 0 | ) |
147 | 0 | .map_err(Into::into) |
148 | 0 | }) |
149 | 0 | .record_output(&this.baseline_metrics); |
150 | 0 |
|
151 | 0 | timer.done(); |
152 | 0 |
|
153 | 0 | result |
154 | | } |
155 | | }; |
156 | | |
157 | 1 | this.finished = true; |
158 | 1 | return Some((result, this)); |
159 | | } |
160 | 3 | }); |
161 | 2 | |
162 | 2 | // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream. |
163 | 2 | let stream = stream.fuse(); |
164 | 2 | let stream = Box::pin(stream); |
165 | 2 | |
166 | 2 | Ok(Self { |
167 | 2 | schema: agg_schema, |
168 | 2 | stream, |
169 | 2 | }) |
170 | 2 | } |
171 | | } |
172 | | |
173 | | impl Stream for AggregateStream { |
174 | | type Item = Result<RecordBatch>; |
175 | | |
176 | 3 | fn poll_next( |
177 | 3 | mut self: std::pin::Pin<&mut Self>, |
178 | 3 | cx: &mut Context<'_>, |
179 | 3 | ) -> Poll<Option<Self::Item>> { |
180 | 3 | let this = &mut *self; |
181 | 3 | this.stream.poll_next_unpin(cx) |
182 | 3 | } |
183 | | } |
184 | | |
185 | | impl RecordBatchStream for AggregateStream { |
186 | 0 | fn schema(&self) -> SchemaRef { |
187 | 0 | Arc::clone(&self.schema) |
188 | 0 | } |
189 | | } |
190 | | |
191 | | /// Perform group-by aggregation for the given [`RecordBatch`]. |
192 | | /// |
193 | | /// If successful, this returns the additional number of bytes that were allocated during this process. |
194 | | /// |
195 | | /// TODO: Make this a member function |
196 | 1 | fn aggregate_batch( |
197 | 1 | mode: &AggregateMode, |
198 | 1 | batch: RecordBatch, |
199 | 1 | accumulators: &mut [AccumulatorItem], |
200 | 1 | expressions: &[Vec<Arc<dyn PhysicalExpr>>], |
201 | 1 | filters: &[Option<Arc<dyn PhysicalExpr>>], |
202 | 1 | ) -> Result<usize> { |
203 | 1 | let mut allocated = 0usize; |
204 | 1 | |
205 | 1 | // 1.1 iterate accumulators and respective expressions together |
206 | 1 | // 1.2 filter the batch if necessary |
207 | 1 | // 1.3 evaluate expressions |
208 | 1 | // 1.4 update / merge accumulators with the expressions' values |
209 | 1 | |
210 | 1 | // 1.1 |
211 | 1 | accumulators |
212 | 1 | .iter_mut() |
213 | 1 | .zip(expressions) |
214 | 1 | .zip(filters) |
215 | 1 | .try_for_each(|((accum, expr), filter)| { |
216 | | // 1.2 |
217 | 1 | let batch = match filter { |
218 | 0 | Some(filter) => Cow::Owned(batch_filter(&batch, filter)?), |
219 | 1 | None => Cow::Borrowed(&batch), |
220 | | }; |
221 | | |
222 | | // 1.3 |
223 | 1 | let values = &expr |
224 | 1 | .iter() |
225 | 1 | .map(|e| { |
226 | 1 | e.evaluate(&batch) |
227 | 1 | .and_then(|v| v.into_array(batch.num_rows())) |
228 | 1 | }) |
229 | 1 | .collect::<Result<Vec<_>>>()?0 ; |
230 | | |
231 | | // 1.4 |
232 | 1 | let size_pre = accum.size(); |
233 | 1 | let res = match mode { |
234 | | AggregateMode::Partial |
235 | | | AggregateMode::Single |
236 | 1 | | AggregateMode::SinglePartitioned => accum.update_batch(values), |
237 | | AggregateMode::Final | AggregateMode::FinalPartitioned => { |
238 | 0 | accum.merge_batch(values) |
239 | | } |
240 | | }; |
241 | 1 | let size_post = accum.size(); |
242 | 1 | allocated += size_post.saturating_sub(size_pre); |
243 | 1 | res |
244 | 1 | })?0 ; |
245 | | |
246 | 1 | Ok(allocated) |
247 | 1 | } |