/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/row_hash.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Hash aggregation |
19 | | |
20 | | use std::sync::Arc; |
21 | | use std::task::{Context, Poll}; |
22 | | use std::vec; |
23 | | |
24 | | use crate::aggregates::group_values::{new_group_values, GroupValues}; |
25 | | use crate::aggregates::order::GroupOrderingFull; |
26 | | use crate::aggregates::{ |
27 | | evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode, |
28 | | PhysicalGroupBy, |
29 | | }; |
30 | | use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; |
31 | | use crate::sorts::sort::sort_batch; |
32 | | use crate::sorts::streaming_merge::StreamingMergeBuilder; |
33 | | use crate::spill::{read_spill_as_stream, spill_record_batch_by_size}; |
34 | | use crate::stream::RecordBatchStreamAdapter; |
35 | | use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr}; |
36 | | use crate::{RecordBatchStream, SendableRecordBatchStream}; |
37 | | |
38 | | use arrow::array::*; |
39 | | use arrow::datatypes::SchemaRef; |
40 | | use arrow_schema::SortOptions; |
41 | | use datafusion_common::{internal_err, DataFusionError, Result}; |
42 | | use datafusion_execution::disk_manager::RefCountedTempFile; |
43 | | use datafusion_execution::memory_pool::proxy::VecAllocExt; |
44 | | use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; |
45 | | use datafusion_execution::runtime_env::RuntimeEnv; |
46 | | use datafusion_execution::TaskContext; |
47 | | use datafusion_expr::{EmitTo, GroupsAccumulator}; |
48 | | use datafusion_physical_expr::expressions::Column; |
49 | | use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; |
50 | | |
51 | | use datafusion_physical_expr::aggregate::AggregateFunctionExpr; |
52 | | use futures::ready; |
53 | | use futures::stream::{Stream, StreamExt}; |
54 | | use log::debug; |
55 | | |
56 | | use super::order::GroupOrdering; |
57 | | use super::AggregateExec; |
58 | | |
59 | | #[derive(Debug, Clone)] |
60 | | /// Tracks the current phase of a [`GroupedHashAggregateStream`]: reading input, producing output, skipping aggregation, or done |
61 | | pub(crate) enum ExecutionState { |
62 | | ReadingInput, |
63 | | /// When producing output, the remaining rows to output are stored |
64 | | /// here and are sliced off as needed in `batch_size` chunks |
65 | | ProducingOutput(RecordBatch), |
66 | | /// Produce intermediate aggregate state for each input row without |
67 | | /// aggregation. |
68 | | /// |
69 | | /// See the "partial aggregation" discussion on [`GroupedHashAggregateStream`] |
70 | | SkippingAggregation, |
71 | | /// All input has been consumed and all groups have been emitted |
72 | | Done, |
73 | | } |
74 | | |
75 | | /// Encapsulates the state used for spilling and for merging spilled data |
76 | | struct SpillState { |
77 | | // ======================================================================== |
78 | | // PROPERTIES: |
79 | | // These fields are initialized at the start and remain constant throughout |
80 | | // the execution. |
81 | | // ======================================================================== |
82 | | /// Sort expressions for spilling batches |
83 | | spill_expr: Vec<PhysicalSortExpr>, |
84 | | |
85 | | /// Schema for spilling batches |
86 | | spill_schema: SchemaRef, |
87 | | |
88 | | /// Aggregate arguments used when merging spilled data |
89 | | merging_aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>, |
90 | | |
91 | | /// GROUP BY expressions used when merging spilled data |
92 | | merging_group_by: PhysicalGroupBy, |
93 | | |
94 | | // ======================================================================== |
95 | | // STATES: |
96 | | // Fields that change during execution. They can be buffers or state flags that |
97 | | // influence the execution in the parent `GroupedHashAggregateStream` |
98 | | // ======================================================================== |
99 | | /// If data has previously been spilled, the locations of the |
100 | | /// spill files (in Arrow IPC format) |
101 | | spills: Vec<RefCountedTempFile>, |
102 | | |
103 | | /// `true` while a streaming merge is in progress |
104 | | is_stream_merging: bool, |
105 | | } |
106 | | |
107 | | /// Tracks if the aggregate should skip partial aggregations |
108 | | /// |
109 | | /// See the "partial aggregation" discussion on [`GroupedHashAggregateStream`] |
110 | | struct SkipAggregationProbe { |
111 | | // ======================================================================== |
112 | | // PROPERTIES: |
113 | | // These fields are initialized at the start and remain constant throughout |
114 | | // the execution. |
115 | | // ======================================================================== |
116 | | /// Aggregation ratio check is performed once the number of input rows reaches |
117 | | /// or exceeds this threshold (from `SessionConfig`) |
118 | | probe_rows_threshold: usize, |
119 | | /// Maximum ratio of `num_groups` to `input_rows` for continuing aggregation |
120 | | /// (from `SessionConfig`). If the ratio reaches or exceeds this value, aggregation |
121 | | /// is skipped and input rows are directly converted to output |
122 | | probe_ratio_threshold: f64, |
123 | | |
124 | | // ======================================================================== |
125 | | // STATES: |
126 | | // Fields that change during execution. They can be buffers or state flags that |
127 | | // influence the execution in the parent `GroupedHashAggregateStream` |
128 | | // ======================================================================== |
129 | | /// Number of processed input rows (updated during probing) |
130 | | input_rows: usize, |
131 | | /// Number of total group values for `input_rows` (updated during probing) |
132 | | num_groups: usize, |
133 | | |
134 | | /// Flag indicating further data aggregation may be skipped (decision made |
135 | | /// when probing completes) |
136 | | should_skip: bool, |
137 | | /// Flag indicating further updates of `SkipAggregationProbe` state will |
138 | | /// have no effect (set either while probing or on probing completion) |
139 | | is_locked: bool, |
140 | | |
141 | | /// Number of rows where state was output without aggregation. |
142 | | /// |
143 | | /// * If 0, all input rows were aggregated (should_skip was always false) |
144 | | /// |
145 | | /// * If greater than zero, the number of rows which were output directly |
146 | | /// without aggregation |
147 | | skipped_aggregation_rows: metrics::Count, |
148 | | } |
149 | | |
150 | | impl SkipAggregationProbe { |
151 | 44 | fn new( |
152 | 44 | probe_rows_threshold: usize, |
153 | 44 | probe_ratio_threshold: f64, |
154 | 44 | skipped_aggregation_rows: metrics::Count, |
155 | 44 | ) -> Self { |
156 | 44 | Self { |
157 | 44 | input_rows: 0, |
158 | 44 | num_groups: 0, |
159 | 44 | probe_rows_threshold, |
160 | 44 | probe_ratio_threshold, |
161 | 44 | should_skip: false, |
162 | 44 | is_locked: false, |
163 | 44 | skipped_aggregation_rows, |
164 | 44 | } |
165 | 44 | } |
166 | | |
167 | | /// Updates `SkipAggregationProbe` state (does nothing once the probe is locked): |
168 | | /// - increments the number of input rows |
169 | | /// - replaces the number of groups with the new value |
170 | | /// - once `probe_rows_threshold` is reached or exceeded, calculates the |
171 | | /// aggregation ratio and sets the `should_skip` flag accordingly |
172 | | /// - at that point also locks further state updates, whatever the flag value |
173 | 51 | fn update_state(&mut self, input_rows: usize, num_groups: usize) { |
174 | 51 | if self.is_locked { |
175 | 0 | return; |
176 | 51 | } |
177 | 51 | self.input_rows += input_rows; |
178 | 51 | self.num_groups = num_groups; |
179 | 51 | if self.input_rows >= self.probe_rows_threshold { |
180 | 2 | self.should_skip = self.num_groups as f64 / self.input_rows as f64 |
181 | 2 | >= self.probe_ratio_threshold; |
182 | 2 | self.is_locked = true; |
183 | 49 | } |
184 | 51 | } |
185 | | |
186 | 77 | fn should_skip(&self) -> bool { |
187 | 77 | self.should_skip |
188 | 77 | } |
189 | | |
190 | | /// Adds the row count of `batch` to the `skipped_aggregation_rows` metric (rows output directly without aggregation) |
191 | 2 | fn record_skipped(&mut self, batch: &RecordBatch) { |
192 | 2 | self.skipped_aggregation_rows.add(batch.num_rows()); |
193 | 2 | } |
194 | | } |
195 | | |
196 | | /// HashTable based Grouping Aggregator |
197 | | /// |
198 | | /// # Design Goals |
199 | | /// |
200 | | /// This structure is designed so that updating the aggregates can be |
201 | | /// vectorized (done in a tight loop) without allocations. The |
202 | | /// accumulator state is *not* managed by this operator (e.g in the |
203 | | /// hash table) and instead is delegated to the individual |
204 | | /// accumulators which have type specialized inner loops that perform |
205 | | /// the aggregation. |
206 | | /// |
207 | | /// # Architecture |
208 | | /// |
209 | | /// ```text |
210 | | /// |
211 | | /// Assigns a consecutive group internally stores aggregate values |
212 | | /// index for each unique set for all groups |
213 | | /// of group values |
214 | | /// |
215 | | /// ┌────────────┐ ┌──────────────┐ ┌──────────────┐ |
216 | | /// │ ┌────────┐ │ │┌────────────┐│ │┌────────────┐│ |
217 | | /// │ │ "A" │ │ ││accumulator ││ ││accumulator ││ |
218 | | /// │ ├────────┤ │ ││ 0 ││ ││ N ││ |
219 | | /// │ │ "Z" │ │ ││ ┌────────┐ ││ ││ ┌────────┐ ││ |
220 | | /// │ └────────┘ │ ││ │ state │ ││ ││ │ state │ ││ |
221 | | /// │ │ ││ │┌─────┐ │ ││ ... ││ │┌─────┐ │ ││ |
222 | | /// │ ... │ ││ │├─────┤ │ ││ ││ │├─────┤ │ ││ |
223 | | /// │ │ ││ │└─────┘ │ ││ ││ │└─────┘ │ ││ |
224 | | /// │ │ ││ │ │ ││ ││ │ │ ││ |
225 | | /// │ ┌────────┐ │ ││ │ ... │ ││ ││ │ ... │ ││ |
226 | | /// │ │ "Q" │ │ ││ │ │ ││ ││ │ │ ││ |
227 | | /// │ └────────┘ │ ││ │┌─────┐ │ ││ ││ │┌─────┐ │ ││ |
228 | | /// │ │ ││ │└─────┘ │ ││ ││ │└─────┘ │ ││ |
229 | | /// └────────────┘ ││ └────────┘ ││ ││ └────────┘ ││ |
230 | | /// │└────────────┘│ │└────────────┘│ |
231 | | /// └──────────────┘ └──────────────┘ |
232 | | /// |
233 | | /// group_values accumulators |
234 | | /// |
235 | | /// ``` |
236 | | /// |
237 | | /// For example, given a query like `COUNT(x), SUM(y) ... GROUP BY z`, |
238 | | /// [`group_values`] will store the distinct values of `z`. There will |
239 | | /// be one accumulator for `COUNT(x)`, specialized for the data type |
240 | | /// of `x` and one accumulator for `SUM(y)`, specialized for the data |
241 | | /// type of `y`. |
242 | | /// |
243 | | /// # Discussion |
244 | | /// |
245 | | /// [`group_values`] does not store any aggregate state inline. It only |
246 | | /// assigns "group indices", one for each (distinct) group value. The |
247 | | /// accumulators manage the in-progress aggregate state for each |
248 | | /// group, while the group values themselves are stored in |
249 | | /// [`group_values`] at the corresponding group index. |
250 | | /// |
251 | | /// The accumulator state (e.g partial sums) is managed by and stored |
252 | | /// by a [`GroupsAccumulator`] accumulator. There is one accumulator |
253 | | /// per aggregate expression (COUNT, AVG, etc) in the |
254 | | /// stream. Internally, each `GroupsAccumulator` manages the state for |
255 | | /// multiple groups, and is passed `group_indexes` during update. Note |
256 | | /// that the accumulator state is not managed by this operator (e.g. in the |
257 | | /// hash table). |
258 | | /// |
259 | | /// [`group_values`]: Self::group_values |
260 | | /// |
261 | | /// # Partial Aggregate and multi-phase grouping |
262 | | /// |
263 | | /// As described on [`Accumulator::state`], this operator is used in the context of |
264 | | /// "multi-phase" grouping when the mode is [`AggregateMode::Partial`]. |
265 | | /// |
266 | | /// An important optimization for multi-phase partial aggregation is to skip |
267 | | /// partial aggregation when it is not effective enough to warrant the memory or |
268 | | /// CPU cost, as is often the case for queries with many distinct groups (high |
269 | | /// cardinality group by). Memory is particularly important because each Partial |
270 | | /// aggregator must store the intermediate state for each group. |
271 | | /// |
272 | | /// If the ratio of the number of groups to the number of input rows exceeds a |
273 | | /// threshold, and [`GroupsAccumulator::supports_convert_to_state`] is |
274 | | /// supported, this operator will stop applying Partial aggregation and directly |
275 | | /// pass the input rows to the next aggregation phase. |
276 | | /// |
277 | | /// [`Accumulator::state`]: datafusion_expr::Accumulator::state |
278 | | /// |
279 | | /// # Spilling (to disk) |
280 | | /// |
281 | | /// The sizes of group values and accumulators can become large. Before that causes out of memory, |
282 | | /// this hash aggregator outputs partial states early for partial aggregation or spills to local |
283 | | /// disk using Arrow IPC format for final aggregation. For every input [`RecordBatch`], the memory |
284 | | /// manager checks whether the new input size meets the memory configuration. If not, outputting or |
285 | | /// spilling happens. For outputting, the final aggregation takes care of re-grouping. For spilling, |
286 | | /// later stream-merge sort on reading back the spilled data does re-grouping. Note that rows cannot |
287 | | /// be grouped once spilled onto disk, so the data read back needs to be re-grouped. In addition, |
288 | | /// re-grouping may run out of memory again. Thus, re-grouping has to be a sort-based aggregation. |
289 | | /// |
290 | | /// ```text |
291 | | /// Partial Aggregation [batch_size = 2] (max memory = 3 rows) |
292 | | /// |
293 | | /// INPUTS PARTIALLY AGGREGATED (UPDATE BATCH) OUTPUTS |
294 | | /// ┌─────────┐ ┌─────────────────┐ ┌─────────────────┐ |
295 | | /// │ a │ b │ │ a │ AVG(b) │ │ a │ AVG(b) │ |
296 | | /// │---│-----│ │ │[count]│[sum]│ │ │[count]│[sum]│ |
297 | | /// │ 3 │ 3.0 │ ─▶ │---│-------│-----│ │---│-------│-----│ |
298 | | /// │ 2 │ 2.0 │ │ 2 │ 1 │ 2.0 │ ─▶ early emit ─▶ │ 2 │ 1 │ 2.0 │ |
299 | | /// └─────────┘ │ 3 │ 2 │ 7.0 │ │ │ 3 │ 2 │ 7.0 │ |
300 | | /// ┌─────────┐ ─▶ │ 4 │ 1 │ 8.0 │ │ └─────────────────┘ |
301 | | /// │ 3 │ 4.0 │ └─────────────────┘ └▶ ┌─────────────────┐ |
302 | | /// │ 4 │ 8.0 │ ┌─────────────────┐ │ 4 │ 1 │ 8.0 │ |
303 | | /// └─────────┘ │ a │ AVG(b) │ ┌▶ │ 1 │ 1 │ 1.0 │ |
304 | | /// ┌─────────┐ │---│-------│-----│ │ └─────────────────┘ |
305 | | /// │ 1 │ 1.0 │ ─▶ │ 1 │ 1 │ 1.0 │ ─▶ early emit ─▶ ┌─────────────────┐ |
306 | | /// │ 3 │ 2.0 │ │ 3 │ 1 │ 2.0 │ │ 3 │ 1 │ 2.0 │ |
307 | | /// └─────────┘ └─────────────────┘ └─────────────────┘ |
308 | | /// |
309 | | /// |
310 | | /// Final Aggregation [batch_size = 2] (max memory = 3 rows) |
311 | | /// |
312 | | /// PARTIALLY INPUTS FINAL AGGREGATION (MERGE BATCH) RE-GROUPED (SORTED) |
313 | | /// ┌─────────────────┐ [keep using the partial schema] [Real final aggregation |
314 | | /// │ a │ AVG(b) │ ┌─────────────────┐ output] |
315 | | /// │ │[count]│[sum]│ │ a │ AVG(b) │ ┌────────────┐ |
316 | | /// │---│-------│-----│ ─▶ │ │[count]│[sum]│ │ a │ AVG(b) │ |
317 | | /// │ 3 │ 3 │ 3.0 │ │---│-------│-----│ ─▶ spill ─┐ │---│--------│ |
318 | | /// │ 2 │ 2 │ 1.0 │ │ 2 │ 2 │ 1.0 │ │ │ 1 │ 4.0 │ |
319 | | /// └─────────────────┘ │ 3 │ 4 │ 8.0 │ ▼ │ 2 │ 1.0 │ |
320 | | /// ┌─────────────────┐ ─▶ │ 4 │ 1 │ 7.0 │ Streaming ─▶ └────────────┘ |
321 | | /// │ 3 │ 1 │ 5.0 │ └─────────────────┘ merge sort ─▶ ┌────────────┐ |
322 | | /// │ 4 │ 1 │ 7.0 │ ┌─────────────────┐ ▲ │ a │ AVG(b) │ |
323 | | /// └─────────────────┘ │ a │ AVG(b) │ │ │---│--------│ |
324 | | /// ┌─────────────────┐ │---│-------│-----│ ─▶ memory ─┘ │ 3 │ 2.0 │ |
325 | | /// │ 1 │ 2 │ 8.0 │ ─▶ │ 1 │ 2 │ 8.0 │ │ 4 │ 7.0 │ |
326 | | /// │ 2 │ 2 │ 3.0 │ │ 2 │ 2 │ 3.0 │ └────────────┘ |
327 | | /// └─────────────────┘ └─────────────────┘ |
328 | | /// ``` |
329 | | pub(crate) struct GroupedHashAggregateStream { |
330 | | // ======================================================================== |
331 | | // PROPERTIES: |
332 | | // These fields are initialized at the start and remain constant throughout |
333 | | // the execution. |
334 | | // ======================================================================== |
335 | | schema: SchemaRef, |
336 | | input: SendableRecordBatchStream, |
337 | | mode: AggregateMode, |
338 | | |
339 | | /// Arguments to pass to each accumulator. |
340 | | /// |
341 | | /// `accumulators[i]` is passed the arguments in `aggregate_arguments[i]` |
342 | | /// |
343 | | /// The argument to each accumulator is itself a `Vec` because |
344 | | /// some aggregates such as `CORR` can accept more than one |
345 | | /// argument. |
346 | | aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>, |
347 | | |
348 | | /// Optional filter expression to evaluate, one for each |
349 | | /// accumulator. If present, only those rows for which the filter |
350 | | /// evaluates to true should be included in the aggregate results. |
351 | | /// |
352 | | /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`, |
353 | | /// the filter expression is `x >= 100`. |
354 | | filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>, |
355 | | |
356 | | /// GROUP BY expressions |
357 | | group_by: PhysicalGroupBy, |
358 | | |
359 | | /// max rows in output RecordBatches |
360 | | batch_size: usize, |
361 | | |
362 | | /// Optional soft limit on the number of `group_values` in a batch |
363 | | /// If the number of `group_values` in a single batch exceeds this value, |
364 | | /// the `GroupedHashAggregateStream` operation immediately switches to |
365 | | /// output mode and emits all groups. |
366 | | group_values_soft_limit: Option<usize>, |
367 | | |
368 | | // ======================================================================== |
369 | | // STATE FLAGS: |
370 | | // These fields are updated during the execution and control the flow of |
371 | | // the execution. |
372 | | // ======================================================================== |
373 | | /// Tracks whether this stream is reading input or producing output (see [`ExecutionState`]) |
374 | | exec_state: ExecutionState, |
375 | | |
376 | | /// Have we seen the end of the input |
377 | | input_done: bool, |
378 | | |
379 | | // ======================================================================== |
380 | | // STATE BUFFERS: |
381 | | // These fields will accumulate intermediate results during the execution. |
382 | | // ======================================================================== |
383 | | /// An interning store of group keys |
384 | | group_values: Box<dyn GroupValues>, |
385 | | |
386 | | /// scratch space for the current input [`RecordBatch`] being |
387 | | /// processed. Reused across batches here to avoid reallocations |
388 | | current_group_indices: Vec<usize>, |
389 | | |
390 | | /// Accumulators, one for each `AggregateFunctionExpr` in the query |
391 | | /// |
392 | | /// For example, if the query has aggregates, `SUM(x)`, |
393 | | /// `COUNT(y)`, there will be two accumulators, each one |
394 | | /// specialized for that particular aggregate and its input types |
395 | | accumulators: Vec<Box<dyn GroupsAccumulator>>, |
396 | | |
397 | | // ======================================================================== |
398 | | // TASK-SPECIFIC STATES: |
399 | | // Inner state objects, each grouping together the properties and state for a specific task. |
400 | | // ======================================================================== |
401 | | /// Optional ordering information, that might allow groups to be |
402 | | /// emitted from the hash table prior to seeing the end of the |
403 | | /// input |
404 | | group_ordering: GroupOrdering, |
405 | | |
406 | | /// The spill state object |
407 | | spill_state: SpillState, |
408 | | |
409 | | /// Optional probe for skipping data aggregation, if supported by |
410 | | /// current stream. |
411 | | skip_aggregation_probe: Option<SkipAggregationProbe>, |
412 | | |
413 | | // ======================================================================== |
414 | | // EXECUTION RESOURCES: |
415 | | // Fields related to managing execution resources and monitoring performance. |
416 | | // ======================================================================== |
417 | | /// The memory reservation for this grouping |
418 | | reservation: MemoryReservation, |
419 | | |
420 | | /// Execution metrics |
421 | | baseline_metrics: BaselineMetrics, |
422 | | |
423 | | /// The [`RuntimeEnv`] associated with the [`TaskContext`] argument |
424 | | runtime: Arc<RuntimeEnv>, |
425 | | } |
426 | | |
427 | | impl GroupedHashAggregateStream { |
428 | | /// Create a new `GroupedHashAggregateStream` that aggregates partition `partition` of the input of `agg` |
429 | 70 | pub fn new( |
430 | 70 | agg: &AggregateExec, |
431 | 70 | context: Arc<TaskContext>, |
432 | 70 | partition: usize, |
433 | 70 | ) -> Result<Self> { |
434 | 70 | debug!("Creating GroupedHashAggregateStream"0 ); |
435 | 70 | let agg_schema = Arc::clone(&agg.schema); |
436 | 70 | let agg_group_by = agg.group_by.clone(); |
437 | 70 | let agg_filter_expr = agg.filter_expr.clone(); |
438 | 70 | |
439 | 70 | let batch_size = context.session_config().batch_size(); |
440 | 70 | let input = agg.input.execute(partition, Arc::clone(&context))?0 ; |
441 | 70 | let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); |
442 | 70 | |
443 | 70 | let timer = baseline_metrics.elapsed_compute().timer(); |
444 | 70 | |
445 | 70 | let aggregate_exprs = agg.aggr_expr.clone(); |
446 | | |
447 | | // arguments for each aggregate, one vec of expressions per |
448 | | // aggregate |
449 | 70 | let aggregate_arguments = aggregates::aggregate_expressions( |
450 | 70 | &agg.aggr_expr, |
451 | 70 | &agg.mode, |
452 | 70 | agg_group_by.expr.len(), |
453 | 70 | )?0 ; |
454 | | // arguments for aggregating spilled data are the same as those for final aggregation |
455 | 70 | let merging_aggregate_arguments = aggregates::aggregate_expressions( |
456 | 70 | &agg.aggr_expr, |
457 | 70 | &AggregateMode::Final, |
458 | 70 | agg_group_by.expr.len(), |
459 | 70 | )?0 ; |
460 | | |
461 | 70 | let filter_expressions = match agg.mode { |
462 | | AggregateMode::Partial |
463 | | | AggregateMode::Single |
464 | 53 | | AggregateMode::SinglePartitioned => agg_filter_expr, |
465 | | AggregateMode::Final | AggregateMode::FinalPartitioned => { |
466 | 17 | vec![None; agg.aggr_expr.len()] |
467 | | } |
468 | | }; |
469 | | |
470 | | // Instantiate the accumulators, one per aggregate expression |
471 | 70 | let accumulators: Vec<_> = aggregate_exprs |
472 | 70 | .iter() |
473 | 70 | .map(create_group_accumulator) |
474 | 70 | .collect::<Result<_>>()?0 ; |
475 | | |
476 | 70 | let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); |
477 | 70 | let spill_expr = group_schema |
478 | 70 | .fields |
479 | 70 | .into_iter() |
480 | 70 | .enumerate() |
481 | 84 | .map(|(idx, field)| PhysicalSortExpr { |
482 | 84 | expr: Arc::new(Column::new(field.name().as_str(), idx)) as _, |
483 | 84 | options: SortOptions::default(), |
484 | 84 | }) |
485 | 70 | .collect(); |
486 | 70 | |
487 | 70 | let name = format!("GroupedHashAggregateStream[{partition}]"); |
488 | 70 | let reservation = MemoryConsumer::new(name) |
489 | 70 | .with_can_spill(true) |
490 | 70 | .register(context.memory_pool()); |
491 | 70 | let (ordering, _) = agg |
492 | 70 | .properties() |
493 | 70 | .equivalence_properties() |
494 | 70 | .find_longest_permutation(&agg_group_by.output_exprs()); |
495 | 70 | let group_ordering = GroupOrdering::try_new( |
496 | 70 | &group_schema, |
497 | 70 | &agg.input_order_mode, |
498 | 70 | ordering.as_slice(), |
499 | 70 | )?0 ; |
500 | | |
501 | 70 | let group_values = new_group_values(group_schema)?0 ; |
502 | 70 | timer.done(); |
503 | 70 | |
504 | 70 | let exec_state = ExecutionState::ReadingInput; |
505 | 70 | |
506 | 70 | let spill_state = SpillState { |
507 | 70 | spills: vec![], |
508 | 70 | spill_expr, |
509 | 70 | spill_schema: Arc::clone(&agg_schema), |
510 | 70 | is_stream_merging: false, |
511 | 70 | merging_aggregate_arguments, |
512 | 70 | merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()), |
513 | 70 | }; |
514 | | |
515 | | // Skipping aggregation is supported if: |
516 | | // - aggregation mode is Partial |
517 | | // - input is not ordered by GROUP BY expressions, |
518 | | // since Final mode expects unique group values as its input |
519 | | // - all accumulators support input batch to intermediate |
520 | | // aggregate state conversion |
521 | | // - there is only one GROUP BY expression set |
522 | 70 | let skip_aggregation_probe = if agg.mode == AggregateMode::Partial |
523 | 53 | && matches!0 (group_ordering, GroupOrdering::None) |
524 | 53 | && accumulators |
525 | 53 | .iter() |
526 | 53 | .all(|acc| acc.supports_convert_to_state()) |
527 | 53 | && agg_group_by.is_single() |
528 | | { |
529 | 44 | let options = &context.session_config().options().execution; |
530 | 44 | let probe_rows_threshold = |
531 | 44 | options.skip_partial_aggregation_probe_rows_threshold; |
532 | 44 | let probe_ratio_threshold = |
533 | 44 | options.skip_partial_aggregation_probe_ratio_threshold; |
534 | 44 | let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics) |
535 | 44 | .counter("skipped_aggregation_rows", partition); |
536 | 44 | Some(SkipAggregationProbe::new( |
537 | 44 | probe_rows_threshold, |
538 | 44 | probe_ratio_threshold, |
539 | 44 | skipped_aggregation_rows, |
540 | 44 | )) |
541 | | } else { |
542 | 26 | None |
543 | | }; |
544 | | |
545 | 70 | Ok(GroupedHashAggregateStream { |
546 | 70 | schema: agg_schema, |
547 | 70 | input, |
548 | 70 | mode: agg.mode, |
549 | 70 | accumulators, |
550 | 70 | aggregate_arguments, |
551 | 70 | filter_expressions, |
552 | 70 | group_by: agg_group_by, |
553 | 70 | reservation, |
554 | 70 | group_values, |
555 | 70 | current_group_indices: Default::default(), |
556 | 70 | exec_state, |
557 | 70 | baseline_metrics, |
558 | 70 | batch_size, |
559 | 70 | group_ordering, |
560 | 70 | input_done: false, |
561 | 70 | runtime: context.runtime_env(), |
562 | 70 | spill_state, |
563 | 70 | group_values_soft_limit: agg.limit, |
564 | 70 | skip_aggregation_probe, |
565 | 70 | }) |
566 | 70 | } |
567 | | } |
568 | | |
569 | | /// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if |
570 | | /// `groups_accumulator_supported()` reports that the aggregate supports one, or a |
571 | | /// [`GroupsAccumulatorAdapter`] built around a `create_accumulator` factory if not. |
572 | 70 | pub(crate) fn create_group_accumulator( |
573 | 70 | agg_expr: &AggregateFunctionExpr, |
574 | 70 | ) -> Result<Box<dyn GroupsAccumulator>> { |
575 | 70 | if agg_expr.groups_accumulator_supported() { |
576 | 30 | agg_expr.create_groups_accumulator() |
577 | | } else { |
578 | | // Note in the debug log when the slow path (the adapter) is used |
579 | 40 | debug!( |
580 | 0 | "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}", |
581 | 0 | agg_expr.name() |
582 | | ); |
583 | 40 | let agg_expr_captured = agg_expr.clone(); |
584 | 130 | let factory = move || agg_expr_captured.create_accumulator(); |
585 | 40 | Ok(Box::new(GroupsAccumulatorAdapter::new(factory))) |
586 | | } |
587 | 70 | } |
588 | | |
589 | | /// Evaluates to the `v` of an `Ok(v)` in `$RES`, or returns `Poll::Ready(Some(Err(e)))` from the enclosing function on `Err(e)` |
590 | | macro_rules! extract_ok { |
591 | | ($RES: expr) => {{ |
592 | | match $RES { |
593 | | Ok(v) => v, |
594 | | Err(e) => return Poll::Ready(Some(Err(e))), |
595 | | } |
596 | | }}; |
597 | | } |
598 | | |
599 | | impl Stream for GroupedHashAggregateStream { |
600 | | type Item = Result<RecordBatch>; |
601 | | |
602 | 277 | fn poll_next( |
603 | 277 | mut self: std::pin::Pin<&mut Self>, |
604 | 277 | cx: &mut Context<'_>, |
605 | 277 | ) -> Poll<Option<Self::Item>> { |
606 | 277 | let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); |
607 | | |
608 | | loop { |
609 | 477 | match &self.exec_state { |
610 | | ExecutionState::ReadingInput => 'reading_input: { |
611 | 285 | match ready!86 (self.input.poll_next_unpin(cx)) { |
612 | | // New batch to aggregate in partial aggregation operator |
613 | 129 | Some(Ok(batch72 )) if self.mode == AggregateMode::Partial => { |
614 | 72 | let timer = elapsed_compute.timer(); |
615 | 72 | let input_rows = batch.num_rows(); |
616 | | |
617 | | // Do the grouping |
618 | 72 | extract_ok!1 (self.group_aggregate_batch(batch)); |
619 | | |
620 | 71 | self.update_skip_aggregation_probe(input_rows); |
621 | 71 | |
622 | 71 | // If we can begin emitting rows, do so, |
623 | 71 | // otherwise keep consuming input |
624 | 71 | assert!(!self.input_done); |
625 | | |
626 | | // If the number of group values equals or exceeds the soft limit, |
627 | | // emit all groups and switch to producing output |
628 | 71 | if self.hit_soft_group_limit() { |
629 | 0 | timer.done(); |
630 | 0 | extract_ok!(self.set_input_done_and_produce_output()); |
631 | | // make sure the exec_state just set is not overwritten below |
632 | 0 | break 'reading_input; |
633 | 71 | } |
634 | | |
635 | 71 | if let Some(to_emit0 ) = self.group_ordering.emit_to() { |
636 | 0 | let batch = extract_ok!(self.emit(to_emit, false)); |
637 | 0 | self.exec_state = ExecutionState::ProducingOutput(batch); |
638 | 0 | timer.done(); |
639 | 0 | // make sure the exec_state just set is not overwritten below |
640 | 0 | break 'reading_input; |
641 | 71 | } |
642 | | |
643 | 71 | extract_ok!0 (self.emit_early_if_necessary()); |
644 | | |
645 | 71 | extract_ok!0 (self.switch_to_skip_aggregation()); |
646 | | |
647 | 71 | timer.done(); |
648 | | } |
649 | | |
650 | | // New batch to aggregate in terminal aggregation operator |
651 | | // (Final/FinalPartitioned/Single/SinglePartitioned) |
652 | 57 | Some(Ok(batch)) => { |
653 | 57 | let timer = elapsed_compute.timer(); |
654 | | |
655 | | // Make sure we have enough capacity for `batch`, otherwise spill |
656 | 57 | extract_ok!0 (self.spill_previous_if_necessary(&batch)); |
657 | | |
658 | | // Do the grouping |
659 | 57 | extract_ok!0 (self.group_aggregate_batch(batch)); |
660 | | |
661 | | // If we can begin emitting rows, do so, |
662 | | // otherwise keep consuming input |
663 | 57 | assert!(!self.input_done); |
664 | | |
665 | | // If the number of group values equals or exceeds the soft limit, |
666 | | // emit all groups and switch to producing output |
667 | 57 | if self.hit_soft_group_limit() { |
668 | 0 | timer.done(); |
669 | 0 | extract_ok!(self.set_input_done_and_produce_output()); |
670 | | // make sure the exec_state just set is not overwritten below |
671 | 0 | break 'reading_input; |
672 | 57 | } |
673 | | |
674 | 57 | if let Some(to_emit8 ) = self.group_ordering.emit_to() { |
675 | 8 | let batch = extract_ok!0 (self.emit(to_emit, false)); |
676 | 8 | self.exec_state = ExecutionState::ProducingOutput(batch); |
677 | 8 | timer.done(); |
678 | 8 | // make sure the exec_state just set is not overwritten below |
679 | 8 | break 'reading_input; |
680 | 49 | } |
681 | 49 | |
682 | 49 | timer.done(); |
683 | | } |
684 | | |
685 | | // Found error from input stream |
686 | 0 | Some(Err(e)) => { |
687 | 0 | // inner had error, return to caller |
688 | 0 | return Poll::Ready(Some(Err(e))); |
689 | | } |
690 | | |
691 | | // Found end from input stream |
692 | | None => { |
693 | | // inner is done, emit all rows and switch to producing output |
694 | 70 | extract_ok!0 (self.set_input_done_and_produce_output()); |
695 | | } |
696 | | } |
697 | | } |
698 | | |
699 | | ExecutionState::SkippingAggregation => { |
700 | 4 | match ready!0 (self.input.poll_next_unpin(cx)) { |
701 | 2 | Some(Ok(batch)) => { |
702 | 2 | let _timer = elapsed_compute.timer(); |
703 | 2 | if let Some(probe) = self.skip_aggregation_probe.as_mut() { |
704 | 2 | probe.record_skipped(&batch); |
705 | 2 | }0 |
706 | 2 | let states = self.transform_to_states(batch)?0 ; |
707 | 2 | return Poll::Ready(Some(Ok( |
708 | 2 | states.record_output(&self.baseline_metrics) |
709 | 2 | ))); |
710 | | } |
711 | 0 | Some(Err(e)) => { |
712 | 0 | // inner had error, return to caller |
713 | 0 | return Poll::Ready(Some(Err(e))); |
714 | | } |
715 | 2 | None => { |
716 | 2 | // inner is done, switching to `Done` state |
717 | 2 | self.exec_state = ExecutionState::Done; |
718 | 2 | } |
719 | | } |
720 | | } |
721 | | |
722 | 120 | ExecutionState::ProducingOutput(batch) => { |
723 | 120 | // slice off a part of the batch, if needed |
724 | 120 | let output_batch; |
725 | 120 | let size = self.batch_size; |
726 | 120 | (self.exec_state, output_batch) = if batch.num_rows() <= size { |
727 | | ( |
728 | 104 | if self.input_done { |
729 | 66 | ExecutionState::Done |
730 | | } |
731 | | // In Partial aggregation, we also need to check |
732 | | // if we should trigger partial skipping |
733 | 38 | else if self.mode == AggregateMode::Partial |
734 | 30 | && self.should_skip_aggregation() |
735 | | { |
736 | 2 | ExecutionState::SkippingAggregation |
737 | | } else { |
738 | 36 | ExecutionState::ReadingInput |
739 | | }, |
740 | 104 | batch.clone(), |
741 | | ) |
742 | | } else { |
743 | | // output first batch_size rows |
744 | 16 | let size = self.batch_size; |
745 | 16 | let num_remaining = batch.num_rows() - size; |
746 | 16 | let remaining = batch.slice(size, num_remaining); |
747 | 16 | let output = batch.slice(0, size); |
748 | 16 | (ExecutionState::ProducingOutput(remaining), output) |
749 | | }; |
750 | 120 | return Poll::Ready(Some(Ok( |
751 | 120 | output_batch.record_output(&self.baseline_metrics) |
752 | 120 | ))); |
753 | | } |
754 | | |
755 | | ExecutionState::Done => { |
756 | | // release the memory reservation since sending back output batch itself needs |
757 | | // some memory reservation, so make some room for it. |
758 | 68 | self.clear_all(); |
759 | 68 | let _ = self.update_memory_reservation(); |
760 | 68 | return Poll::Ready(None); |
761 | | } |
762 | | } |
763 | | } |
764 | 277 | } |
765 | | } |
766 | | |
767 | | impl RecordBatchStream for GroupedHashAggregateStream { |
768 | 178 | fn schema(&self) -> SchemaRef { |
769 | 178 | Arc::clone(&self.schema) |
770 | 178 | } |
771 | | } |
772 | | |
773 | | impl GroupedHashAggregateStream { |
774 | | /// Perform group-by aggregation for the given [`RecordBatch`]. |
775 | 129 | fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> { |
776 | | // Evaluate the grouping expressions |
777 | 129 | let group_by_values = if self.spill_state.is_stream_merging { |
778 | 12 | evaluate_group_by(&self.spill_state.merging_group_by, &batch)?0 |
779 | | } else { |
780 | 117 | evaluate_group_by(&self.group_by, &batch)?0 |
781 | | }; |
782 | | |
783 | | // Evaluate the aggregation expressions. |
784 | 129 | let input_values = if self.spill_state.is_stream_merging { |
785 | 12 | evaluate_many(&self.spill_state.merging_aggregate_arguments, &batch)?0 |
786 | | } else { |
787 | 117 | evaluate_many(&self.aggregate_arguments, &batch)?0 |
788 | | }; |
789 | | |
790 | | // Evaluate the filter expressions, if any, against the inputs |
791 | 129 | let filter_values = if self.spill_state.is_stream_merging { |
792 | 12 | let filter_expressions = vec![None; self.accumulators.len()]; |
793 | 12 | evaluate_optional(&filter_expressions, &batch)?0 |
794 | | } else { |
795 | 117 | evaluate_optional(&self.filter_expressions, &batch)?0 |
796 | | }; |
797 | | |
798 | 298 | for group_values169 in &group_by_values { |
799 | | // calculate the group indices for each input row |
800 | 169 | let starting_num_groups = self.group_values.len(); |
801 | 169 | self.group_values |
802 | 169 | .intern(group_values, &mut self.current_group_indices)?0 ; |
803 | 169 | let group_indices = &self.current_group_indices; |
804 | 169 | |
805 | 169 | // Update ordering information if necessary |
806 | 169 | let total_num_groups = self.group_values.len(); |
807 | 169 | if total_num_groups > starting_num_groups { |
808 | 128 | self.group_ordering.new_groups( |
809 | 128 | group_values, |
810 | 128 | group_indices, |
811 | 128 | total_num_groups, |
812 | 128 | )?0 ; |
813 | 41 | } |
814 | | |
815 | | // Gather the inputs to call the actual accumulator |
816 | 169 | let t = self |
817 | 169 | .accumulators |
818 | 169 | .iter_mut() |
819 | 169 | .zip(input_values.iter()) |
820 | 169 | .zip(filter_values.iter()); |
821 | | |
822 | 338 | for ((acc, values), opt_filter169 ) in t { |
823 | 169 | let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()0 ); |
824 | | |
825 | | // Call the appropriate method on each aggregator with |
826 | | // the entire input row and the relevant group indexes |
827 | 112 | match self.mode { |
828 | | AggregateMode::Partial |
829 | | | AggregateMode::Single |
830 | | | AggregateMode::SinglePartitioned |
831 | 112 | if !self.spill_state.is_stream_merging => |
832 | | { |
833 | 112 | acc.update_batch( |
834 | 112 | values, |
835 | 112 | group_indices, |
836 | 112 | opt_filter, |
837 | 112 | total_num_groups, |
838 | 112 | )?0 ; |
839 | | } |
840 | | _ => { |
841 | | // if aggregation is over intermediate states, |
842 | | // use merge |
843 | 57 | acc.merge_batch( |
844 | 57 | values, |
845 | 57 | group_indices, |
846 | 57 | opt_filter, |
847 | 57 | total_num_groups, |
848 | 57 | )?0 ; |
849 | | } |
850 | | } |
851 | | } |
852 | | } |
853 | | |
854 | 129 | match self.update_memory_reservation() { |
855 | | // Here we can ignore `insufficient_capacity_err` because we will spill later, |
856 | | // but at least one batch should fit in the memory |
857 | | Err(DataFusionError::ResourcesExhausted(_)) |
858 | 33 | if self.group_values.len() >= self.batch_size => |
859 | 32 | { |
860 | 32 | Ok(()) |
861 | | } |
862 | 97 | other => other, |
863 | | } |
864 | 129 | } |
865 | | |
866 | 371 | fn update_memory_reservation(&mut self) -> Result<()> { |
867 | 371 | let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>(); |
868 | 371 | self.reservation.try_resize( |
869 | 371 | acc + self.group_values.size() |
870 | 371 | + self.group_ordering.size() |
871 | 371 | + self.current_group_indices.allocated_size(), |
872 | 371 | ) |
873 | 371 | } |
874 | | |
875 | | /// Create an output RecordBatch with the group keys and |
876 | | /// accumulator states/values specified in emit_to |
877 | 112 | fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> { |
878 | 112 | let schema = if spilling { |
879 | 8 | Arc::clone(&self.spill_state.spill_schema) |
880 | | } else { |
881 | 104 | self.schema() |
882 | | }; |
883 | 112 | if self.group_values.is_empty() { |
884 | 2 | return Ok(RecordBatch::new_empty(schema)); |
885 | 110 | } |
886 | | |
887 | 110 | let mut output = self.group_values.emit(emit_to)?0 ; |
888 | 110 | if let EmitTo::First(n36 ) = emit_to { |
889 | 36 | self.group_ordering.remove_groups(n); |
890 | 74 | } |
891 | | |
892 | | // Next output each aggregate value |
893 | 110 | for acc in self.accumulators.iter_mut() { |
894 | 25 | match self.mode { |
895 | 77 | AggregateMode::Partial => output.extend(acc.state(emit_to)?0 ), |
896 | 8 | _ if spilling => { |
897 | 8 | // If spilling, output partial state because the spilled data will be |
898 | 8 | // merged and re-evaluated later. |
899 | 8 | output.extend(acc.state(emit_to)?0 ) |
900 | | } |
901 | | AggregateMode::Final |
902 | | | AggregateMode::FinalPartitioned |
903 | | | AggregateMode::Single |
904 | 25 | | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)?0 ), |
905 | | } |
906 | | } |
907 | | |
908 | | // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is |
909 | | // over the target memory size after emission, we can emit again rather than returning Err. |
910 | 110 | let _ = self.update_memory_reservation(); |
911 | 110 | let batch = RecordBatch::try_new(schema, output)?0 ; |
912 | 110 | Ok(batch) |
913 | 112 | } |
914 | | |
915 | | /// Optimistically, [`Self::group_aggregate_batch`] is allowed to exceed the memory target slightly |
916 | | /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to disk and clear the |
917 | | /// memory. Currently only [`GroupOrdering::None`] is supported for spilling. |
918 | 57 | fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> Result<()> { |
919 | 57 | // TODO: support group_ordering for spilling |
920 | 57 | if self.group_values.len() > 0 |
921 | 36 | && batch.num_rows() > 0 |
922 | 36 | && matches!8 (self.group_ordering, GroupOrdering::None) |
923 | 28 | && !self.spill_state.is_stream_merging |
924 | 28 | && self.update_memory_reservation().is_err() |
925 | | { |
926 | 4 | assert_ne!(self.mode, AggregateMode::Partial); |
927 | | // Use input batch (Partial mode) schema for spilling because |
928 | | // the spilled data will be merged and re-evaluated later. |
929 | 4 | self.spill_state.spill_schema = batch.schema(); |
930 | 4 | self.spill()?0 ; |
931 | 4 | self.clear_shrink(batch); |
932 | 53 | } |
933 | 57 | Ok(()) |
934 | 57 | } |
935 | | |
936 | | /// Emit all rows, sort them, and store them on disk. |
937 | 4 | fn spill(&mut self) -> Result<()> { |
938 | 4 | let emit = self.emit(EmitTo::All, true)?0 ; |
939 | 4 | let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?0 ; |
940 | 4 | let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?0 ; |
941 | | // TODO: slice large `sorted` and write to multiple files in parallel |
942 | 4 | spill_record_batch_by_size( |
943 | 4 | &sorted, |
944 | 4 | spillfile.path().into(), |
945 | 4 | sorted.schema(), |
946 | 4 | self.batch_size, |
947 | 4 | )?0 ; |
948 | 4 | self.spill_state.spills.push(spillfile); |
949 | 4 | Ok(()) |
950 | 4 | } |
951 | | |
952 | | /// Clear memory and shrink capacities to the size of the batch. |
953 | 76 | fn clear_shrink(&mut self, batch: &RecordBatch) { |
954 | 76 | self.group_values.clear_shrink(batch); |
955 | 76 | self.current_group_indices.clear(); |
956 | 76 | self.current_group_indices.shrink_to(batch.num_rows()); |
957 | 76 | } |
958 | | |
959 | | /// Clear memory and shrink capacities to zero. |
960 | 72 | fn clear_all(&mut self) { |
961 | 72 | let s = self.schema(); |
962 | 72 | self.clear_shrink(&RecordBatch::new_empty(s)); |
963 | 72 | } |
964 | | |
965 | | /// Emit if the used memory exceeds the target for partial aggregation. |
966 | | /// Currently only [`GroupOrdering::None`] is supported for early emitting. |
967 | | /// TODO: support group_ordering for early emitting |
968 | 71 | fn emit_early_if_necessary(&mut self) -> Result<()> { |
969 | 71 | if self.group_values.len() >= self.batch_size |
970 | 32 | && matches!0 (self.group_ordering, GroupOrdering::None) |
971 | 32 | && self.update_memory_reservation().is_err() |
972 | | { |
973 | 28 | assert_eq!(self.mode, AggregateMode::Partial); |
974 | 28 | let n = self.group_values.len() / self.batch_size * self.batch_size; |
975 | 28 | let batch = self.emit(EmitTo::First(n), false)?0 ; |
976 | 28 | self.exec_state = ExecutionState::ProducingOutput(batch); |
977 | 43 | } |
978 | 71 | Ok(()) |
979 | 71 | } |
980 | | |
981 | | /// At this point, all the inputs are read and there are some spills. |
982 | | /// Emit the remaining rows and create a batch. |
983 | | /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully |
984 | | /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`]. |
985 | 4 | fn update_merged_stream(&mut self) -> Result<()> { |
986 | 4 | let batch = self.emit(EmitTo::All, true)?0 ; |
987 | | // clear up memory for streaming_merge |
988 | 4 | self.clear_all(); |
989 | 4 | self.update_memory_reservation()?0 ; |
990 | 4 | let mut streams: Vec<SendableRecordBatchStream> = vec![]; |
991 | 4 | let expr = self.spill_state.spill_expr.clone(); |
992 | 4 | let schema = batch.schema(); |
993 | 4 | streams.push(Box::pin(RecordBatchStreamAdapter::new( |
994 | 4 | Arc::clone(&schema), |
995 | 4 | futures::stream::once(futures::future::lazy(move |_| { |
996 | 4 | sort_batch(&batch, &expr, None) |
997 | 4 | })), |
998 | 4 | ))); |
999 | 4 | for spill in self.spill_state.spills.drain(..) { |
1000 | 4 | let stream = read_spill_as_stream(spill, Arc::clone(&schema), 2)?0 ; |
1001 | 4 | streams.push(stream); |
1002 | | } |
1003 | 4 | self.spill_state.is_stream_merging = true; |
1004 | 4 | self.input = StreamingMergeBuilder::new() |
1005 | 4 | .with_streams(streams) |
1006 | 4 | .with_schema(schema) |
1007 | 4 | .with_expressions(&self.spill_state.spill_expr) |
1008 | 4 | .with_metrics(self.baseline_metrics.clone()) |
1009 | 4 | .with_batch_size(self.batch_size) |
1010 | 4 | .with_reservation(self.reservation.new_empty()) |
1011 | 4 | .build()?0 ; |
1012 | 4 | self.input_done = false; |
1013 | 4 | self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); |
1014 | 4 | Ok(()) |
1015 | 4 | } |
1016 | | |
1017 | | /// Returns true if there is a soft groups limit and the number of distinct |
1018 | | /// groups we have seen has reached or exceeded that limit |
1019 | 128 | fn hit_soft_group_limit(&self) -> bool { |
1020 | 128 | let Some(group_values_soft_limit0 ) = self.group_values_soft_limit else { |
1021 | 128 | return false; |
1022 | | }; |
1023 | 0 | group_values_soft_limit <= self.group_values.len() |
1024 | 128 | } |
1025 | | |
1026 | | /// common function for signalling end of processing of the input stream |
1027 | 70 | fn set_input_done_and_produce_output(&mut self) -> Result<()> { |
1028 | 70 | self.input_done = true; |
1029 | 70 | self.group_ordering.input_done(); |
1030 | 70 | let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); |
1031 | 70 | let timer = elapsed_compute.timer(); |
1032 | 70 | self.exec_state = if self.spill_state.spills.is_empty() { |
1033 | 66 | let batch = self.emit(EmitTo::All, false)?0 ; |
1034 | 66 | ExecutionState::ProducingOutput(batch) |
1035 | | } else { |
1036 | | // If spill files exist, stream-merge them. |
1037 | 4 | self.update_merged_stream()?0 ; |
1038 | 4 | ExecutionState::ReadingInput |
1039 | | }; |
1040 | 70 | timer.done(); |
1041 | 70 | Ok(()) |
1042 | 70 | } |
1043 | | |
1044 | | /// Updates skip aggregation probe state. |
1045 | | /// |
1046 | | /// Notice: It should only be called in Partial aggregation |
1047 | 71 | fn update_skip_aggregation_probe(&mut self, input_rows: usize) { |
1048 | 71 | if let Some(probe51 ) = self.skip_aggregation_probe.as_mut() { |
1049 | | // Skip aggregation probe is not supported if stream has any spills, |
1050 | | // currently spilling is not supported for Partial aggregation |
1051 | 51 | assert!(self.spill_state.spills.is_empty()); |
1052 | 51 | probe.update_state(input_rows, self.group_values.len()); |
1053 | 20 | }; |
1054 | 71 | } |
1055 | | |
1056 | | /// In case the probe indicates that aggregation may be |
1057 | | /// skipped, forces stream to produce currently accumulated output. |
1058 | | /// |
1059 | | /// Notice: It should only be called in Partial aggregation |
1060 | 71 | fn switch_to_skip_aggregation(&mut self) -> Result<()> { |
1061 | 71 | if let Some(probe51 ) = self.skip_aggregation_probe.as_mut() { |
1062 | 51 | if probe.should_skip() { |
1063 | 2 | let batch = self.emit(EmitTo::All, false)?0 ; |
1064 | 2 | self.exec_state = ExecutionState::ProducingOutput(batch); |
1065 | 49 | } |
1066 | 20 | } |
1067 | | |
1068 | 71 | Ok(()) |
1069 | 71 | } |
1070 | | |
1071 | | /// Returns true if the aggregation probe indicates that aggregation |
1072 | | /// should be skipped. |
1073 | | /// |
1074 | | /// Notice: It should only be called in Partial aggregation |
1075 | 30 | fn should_skip_aggregation(&self) -> bool { |
1076 | 30 | self.skip_aggregation_probe |
1077 | 30 | .as_ref() |
1078 | 30 | .is_some_and(|probe| probe.should_skip()26 ) |
1079 | 30 | } |
1080 | | |
1081 | | /// Transforms input batch to intermediate aggregate state, without grouping it |
1082 | 2 | fn transform_to_states(&self, batch: RecordBatch) -> Result<RecordBatch> { |
1083 | 2 | let mut group_values = evaluate_group_by(&self.group_by, &batch)?0 ; |
1084 | 2 | let input_values = evaluate_many(&self.aggregate_arguments, &batch)?0 ; |
1085 | 2 | let filter_values = evaluate_optional(&self.filter_expressions, &batch)?0 ; |
1086 | | |
1087 | 2 | if group_values.len() != 1 { |
1088 | 0 | return internal_err!("group_values expected to have single element"); |
1089 | 2 | } |
1090 | 2 | let mut output = group_values.swap_remove(0); |
1091 | 2 | |
1092 | 2 | let iter = self |
1093 | 2 | .accumulators |
1094 | 2 | .iter() |
1095 | 2 | .zip(input_values.iter()) |
1096 | 2 | .zip(filter_values.iter()); |
1097 | | |
1098 | 4 | for ((acc, values), opt_filter2 ) in iter { |
1099 | 2 | let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()0 ); |
1100 | 2 | output.extend(acc.convert_to_state(values, opt_filter)?0 ); |
1101 | | } |
1102 | | |
1103 | 2 | let states_batch = RecordBatch::try_new(self.schema(), output)?0 ; |
1104 | | |
1105 | 2 | Ok(states_batch) |
1106 | 2 | } |
1107 | | } |