Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/row_hash.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Hash aggregation
19
20
use std::sync::Arc;
21
use std::task::{Context, Poll};
22
use std::vec;
23
24
use crate::aggregates::group_values::{new_group_values, GroupValues};
25
use crate::aggregates::order::GroupOrderingFull;
26
use crate::aggregates::{
27
    evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode,
28
    PhysicalGroupBy,
29
};
30
use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
31
use crate::sorts::sort::sort_batch;
32
use crate::sorts::streaming_merge::StreamingMergeBuilder;
33
use crate::spill::{read_spill_as_stream, spill_record_batch_by_size};
34
use crate::stream::RecordBatchStreamAdapter;
35
use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr};
36
use crate::{RecordBatchStream, SendableRecordBatchStream};
37
38
use arrow::array::*;
39
use arrow::datatypes::SchemaRef;
40
use arrow_schema::SortOptions;
41
use datafusion_common::{internal_err, DataFusionError, Result};
42
use datafusion_execution::disk_manager::RefCountedTempFile;
43
use datafusion_execution::memory_pool::proxy::VecAllocExt;
44
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
45
use datafusion_execution::runtime_env::RuntimeEnv;
46
use datafusion_execution::TaskContext;
47
use datafusion_expr::{EmitTo, GroupsAccumulator};
48
use datafusion_physical_expr::expressions::Column;
49
use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr};
50
51
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
52
use futures::ready;
53
use futures::stream::{Stream, StreamExt};
54
use log::debug;
55
56
use super::order::GroupOrdering;
57
use super::AggregateExec;
58
59
#[derive(Debug, Clone)]
60
/// This object tracks the aggregation phase (input/output)
61
pub(crate) enum ExecutionState {
62
    ReadingInput,
63
    /// When producing output, the remaining rows to output are stored
64
    /// here and are sliced off as needed in batch_size chunks
65
    ProducingOutput(RecordBatch),
66
    /// Produce intermediate aggregate state for each input row without
67
    /// aggregation.
68
    ///
69
    /// See "partial aggregation" discussion on [`GroupedHashAggregateStream`]
70
    SkippingAggregation,
71
    /// All input has been consumed and all groups have been emitted
72
    Done,
73
}
74
75
/// This encapsulates the spilling state
76
struct SpillState {
77
    // ========================================================================
78
    // PROPERTIES:
79
    // These fields are initialized at the start and remain constant throughout
80
    // the execution.
81
    // ========================================================================
82
    /// Sorting expression for spilling batches
83
    spill_expr: Vec<PhysicalSortExpr>,
84
85
    /// Schema for spilling batches
86
    spill_schema: SchemaRef,
87
88
    /// aggregate_arguments for merging spilled data
89
    merging_aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
90
91
    /// GROUP BY expressions for merging spilled data
92
    merging_group_by: PhysicalGroupBy,
93
94
    // ========================================================================
95
    // STATES:
96
    // Fields changes during execution. Can be buffer, or state flags that
97
    // influence the execution in parent `GroupedHashAggregateStream`
98
    // ========================================================================
99
    /// If data has previously been spilled, the locations of the
100
    /// spill files (in Arrow IPC format)
101
    spills: Vec<RefCountedTempFile>,
102
103
    /// true when streaming merge is in progress
104
    is_stream_merging: bool,
105
}
106
107
/// Tracks if the aggregate should skip partial aggregations
108
///
109
/// See "partial aggregation" discussion on [`GroupedHashAggregateStream`]
110
struct SkipAggregationProbe {
111
    // ========================================================================
112
    // PROPERTIES:
113
    // These fields are initialized at the start and remain constant throughout
114
    // the execution.
115
    // ========================================================================
116
    /// Aggregation ratio check performed when the number of input rows exceeds
117
    /// this threshold (from `SessionConfig`)
118
    probe_rows_threshold: usize,
119
    /// Maximum ratio of `num_groups` to `input_rows` for continuing aggregation
120
    /// (from `SessionConfig`). If the ratio exceeds this value, aggregation
121
    /// is skipped and input rows are directly converted to output
122
    probe_ratio_threshold: f64,
123
124
    // ========================================================================
125
    // STATES:
126
    // Fields changes during execution. Can be buffer, or state flags that
127
    // influence the exeuction in parent `GroupedHashAggregateStream`
128
    // ========================================================================
129
    /// Number of processed input rows (updated during probing)
130
    input_rows: usize,
131
    /// Number of total group values for `input_rows` (updated during probing)
132
    num_groups: usize,
133
134
    /// Flag indicating further data aggregation may be skipped (decision made
135
    /// when probing complete)
136
    should_skip: bool,
137
    /// Flag indicating further updates of `SkipAggregationProbe` state won't
138
    /// make any effect (set either while probing or on probing completion)
139
    is_locked: bool,
140
141
    /// Number of rows where state was output without aggregation.
142
    ///
143
    /// * If 0, all input rows were aggregated (should_skip was always false)
144
    ///
145
    /// * if greater than zero, the number of rows which were output directly
146
    ///   without aggregation
147
    skipped_aggregation_rows: metrics::Count,
148
}
149
150
impl SkipAggregationProbe {
151
44
    fn new(
152
44
        probe_rows_threshold: usize,
153
44
        probe_ratio_threshold: f64,
154
44
        skipped_aggregation_rows: metrics::Count,
155
44
    ) -> Self {
156
44
        Self {
157
44
            input_rows: 0,
158
44
            num_groups: 0,
159
44
            probe_rows_threshold,
160
44
            probe_ratio_threshold,
161
44
            should_skip: false,
162
44
            is_locked: false,
163
44
            skipped_aggregation_rows,
164
44
        }
165
44
    }
166
167
    /// Updates `SkipAggregationProbe` state:
168
    /// - increments the number of input rows
169
    /// - replaces the number of groups with the new value
170
    /// - on `probe_rows_threshold` exceeded calculates
171
    ///   aggregation ratio and sets `should_skip` flag
172
    /// - if `should_skip` is set, locks further state updates
173
51
    fn update_state(&mut self, input_rows: usize, num_groups: usize) {
174
51
        if self.is_locked {
175
0
            return;
176
51
        }
177
51
        self.input_rows += input_rows;
178
51
        self.num_groups = num_groups;
179
51
        if self.input_rows >= self.probe_rows_threshold {
180
2
            self.should_skip = self.num_groups as f64 / self.input_rows as f64
181
2
                >= self.probe_ratio_threshold;
182
2
            self.is_locked = true;
183
49
        }
184
51
    }
185
186
77
    fn should_skip(&self) -> bool {
187
77
        self.should_skip
188
77
    }
189
190
    /// Record the number of rows that were output directly without aggregation
191
2
    fn record_skipped(&mut self, batch: &RecordBatch) {
192
2
        self.skipped_aggregation_rows.add(batch.num_rows());
193
2
    }
194
}
195
196
/// HashTable based Grouping Aggregator
197
///
198
/// # Design Goals
199
///
200
/// This structure is designed so that updating the aggregates can be
201
/// vectorized (done in a tight loop) without allocations. The
202
/// accumulator state is *not* managed by this operator (e.g in the
203
/// hash table) and instead is delegated to the individual
204
/// accumulators which have type specialized inner loops that perform
205
/// the aggregation.
206
///
207
/// # Architecture
208
///
209
/// ```text
210
///
211
///     Assigns a consecutive group           internally stores aggregate values
212
///     index for each unique set                     for all groups
213
///         of group values
214
///
215
///         ┌────────────┐              ┌──────────────┐       ┌──────────────┐
216
///         │ ┌────────┐ │              │┌────────────┐│       │┌────────────┐│
217
///         │ │  "A"   │ │              ││accumulator ││       ││accumulator ││
218
///         │ ├────────┤ │              ││     0      ││       ││     N      ││
219
///         │ │  "Z"   │ │              ││ ┌────────┐ ││       ││ ┌────────┐ ││
220
///         │ └────────┘ │              ││ │ state  │ ││       ││ │ state  │ ││
221
///         │            │              ││ │┌─────┐ │ ││  ...  ││ │┌─────┐ │ ││
222
///         │    ...     │              ││ │├─────┤ │ ││       ││ │├─────┤ │ ││
223
///         │            │              ││ │└─────┘ │ ││       ││ │└─────┘ │ ││
224
///         │            │              ││ │        │ ││       ││ │        │ ││
225
///         │ ┌────────┐ │              ││ │  ...   │ ││       ││ │  ...   │ ││
226
///         │ │  "Q"   │ │              ││ │        │ ││       ││ │        │ ││
227
///         │ └────────┘ │              ││ │┌─────┐ │ ││       ││ │┌─────┐ │ ││
228
///         │            │              ││ │└─────┘ │ ││       ││ │└─────┘ │ ││
229
///         └────────────┘              ││ └────────┘ ││       ││ └────────┘ ││
230
///                                     │└────────────┘│       │└────────────┘│
231
///                                     └──────────────┘       └──────────────┘
232
///
233
///         group_values                             accumulators
234
///
235
///  ```
236
///
237
/// For example, given a query like `COUNT(x), SUM(y) ... GROUP BY z`,
238
/// [`group_values`] will store the distinct values of `z`. There will
239
/// be one accumulator for `COUNT(x)`, specialized for the data type
240
/// of `x` and one accumulator for `SUM(y)`, specialized for the data
241
/// type of `y`.
242
///
243
/// # Discussion
244
///
245
/// [`group_values`] does not store any aggregate state inline. It only
246
/// assigns "group indices", one for each (distinct) group value. The
247
/// accumulators manage the in-progress aggregate state for each
248
/// group, with the group values themselves are stored in
249
/// [`group_values`] at the corresponding group index.
250
///
251
/// The accumulator state (e.g partial sums) is managed by and stored
252
/// by a [`GroupsAccumulator`] accumulator. There is one accumulator
253
/// per aggregate expression (COUNT, AVG, etc) in the
254
/// stream. Internally, each `GroupsAccumulator` manages the state for
255
/// multiple groups, and is passed `group_indexes` during update. Note
256
/// The accumulator state is not managed by this operator (e.g in the
257
/// hash table).
258
///
259
/// [`group_values`]: Self::group_values
260
///
261
/// # Partial Aggregate and multi-phase grouping
262
///
263
/// As described on [`Accumulator::state`], this operator is used in the context
264
/// "multi-phase" grouping when the mode is [`AggregateMode::Partial`].
265
///
266
/// An important optimization for multi-phase partial aggregation is to skip
267
/// partial aggregation when it is not effective enough to warrant the memory or
268
/// CPU cost, as is often the case for queries many distinct groups (high
269
/// cardinality group by). Memory is particularly important because each Partial
270
/// aggregator must store the intermediate state for each group.
271
///
272
/// If the ratio of the number of groups to the number of input rows exceeds a
273
/// threshold, and [`GroupsAccumulator::supports_convert_to_state`] is
274
/// supported, this operator will stop applying Partial aggregation and directly
275
/// pass the input rows to the next aggregation phase.
276
///
277
/// [`Accumulator::state`]: datafusion_expr::Accumulator::state
278
///
279
/// # Spilling (to disk)
280
///
281
/// The sizes of group values and accumulators can become large. Before that causes out of memory,
282
/// this hash aggregator outputs partial states early for partial aggregation or spills to local
283
/// disk using Arrow IPC format for final aggregation. For every input [`RecordBatch`], the memory
284
/// manager checks whether the new input size meets the memory configuration. If not, outputting or
285
/// spilling happens. For outputting, the final aggregation takes care of re-grouping. For spilling,
286
/// later stream-merge sort on reading back the spilled data does re-grouping. Note the rows cannot
287
/// be grouped once spilled onto disk, the read back data needs to be re-grouped again. In addition,
288
/// re-grouping may cause out of memory again. Thus, re-grouping has to be a sort based aggregation.
289
///
290
/// ```text
291
/// Partial Aggregation [batch_size = 2] (max memory = 3 rows)
292
///
293
///  INPUTS        PARTIALLY AGGREGATED (UPDATE BATCH)   OUTPUTS
294
/// ┌─────────┐    ┌─────────────────┐                  ┌─────────────────┐
295
/// │ a │ b   │    │ a │    AVG(b)   │                  │ a │    AVG(b)   │
296
/// │---│-----│    │   │[count]│[sum]│                  │   │[count]│[sum]│
297
/// │ 3 │ 3.0 │ ─▶ │---│-------│-----│                  │---│-------│-----│
298
/// │ 2 │ 2.0 │    │ 2 │ 1     │ 2.0 │ ─▶ early emit ─▶ │ 2 │ 1     │ 2.0 │
299
/// └─────────┘    │ 3 │ 2     │ 7.0 │               │  │ 3 │ 2     │ 7.0 │
300
/// ┌─────────┐ ─▶ │ 4 │ 1     │ 8.0 │               │  └─────────────────┘
301
/// │ 3 │ 4.0 │    └─────────────────┘               └▶ ┌─────────────────┐
302
/// │ 4 │ 8.0 │    ┌─────────────────┐                  │ 4 │ 1     │ 8.0 │
303
/// └─────────┘    │ a │    AVG(b)   │               ┌▶ │ 1 │ 1     │ 1.0 │
304
/// ┌─────────┐    │---│-------│-----│               │  └─────────────────┘
305
/// │ 1 │ 1.0 │ ─▶ │ 1 │ 1     │ 1.0 │ ─▶ early emit ─▶ ┌─────────────────┐
306
/// │ 3 │ 2.0 │    │ 3 │ 1     │ 2.0 │                  │ 3 │ 1     │ 2.0 │
307
/// └─────────┘    └─────────────────┘                  └─────────────────┘
308
///
309
///
310
/// Final Aggregation [batch_size = 2] (max memory = 3 rows)
311
///
312
/// PARTIALLY INPUTS       FINAL AGGREGATION (MERGE BATCH)       RE-GROUPED (SORTED)
313
/// ┌─────────────────┐    [keep using the partial schema]       [Real final aggregation
314
/// │ a │    AVG(b)   │    ┌─────────────────┐                    output]
315
/// │   │[count]│[sum]│    │ a │    AVG(b)   │                   ┌────────────┐
316
/// │---│-------│-----│ ─▶ │   │[count]│[sum]│                   │ a │ AVG(b) │
317
/// │ 3 │ 3     │ 3.0 │    │---│-------│-----│ ─▶ spill ─┐       │---│--------│
318
/// │ 2 │ 2     │ 1.0 │    │ 2 │ 2     │ 1.0 │           │       │ 1 │    4.0 │
319
/// └─────────────────┘    │ 3 │ 4     │ 8.0 │           ▼       │ 2 │    1.0 │
320
/// ┌─────────────────┐ ─▶ │ 4 │ 1     │ 7.0 │     Streaming  ─▶ └────────────┘
321
/// │ 3 │ 1     │ 5.0 │    └─────────────────┘     merge sort ─▶ ┌────────────┐
322
/// │ 4 │ 1     │ 7.0 │    ┌─────────────────┐            ▲      │ a │ AVG(b) │
323
/// └─────────────────┘    │ a │    AVG(b)   │            │      │---│--------│
324
/// ┌─────────────────┐    │---│-------│-----│ ─▶ memory ─┘      │ 3 │    2.0 │
325
/// │ 1 │ 2     │ 8.0 │ ─▶ │ 1 │ 2     │ 8.0 │                   │ 4 │    7.0 │
326
/// │ 2 │ 2     │ 3.0 │    │ 2 │ 2     │ 3.0 │                   └────────────┘
327
/// └─────────────────┘    └─────────────────┘
328
/// ```
329
pub(crate) struct GroupedHashAggregateStream {
330
    // ========================================================================
331
    // PROPERTIES:
332
    // These fields are initialized at the start and remain constant throughout
333
    // the execution.
334
    // ========================================================================
335
    schema: SchemaRef,
336
    input: SendableRecordBatchStream,
337
    mode: AggregateMode,
338
339
    /// Arguments to pass to each accumulator.
340
    ///
341
    /// The arguments in `accumulator[i]` is passed `aggregate_arguments[i]`
342
    ///
343
    /// The argument to each accumulator is itself a `Vec` because
344
    /// some aggregates such as `CORR` can accept more than one
345
    /// argument.
346
    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
347
348
    /// Optional filter expression to evaluate, one for each for
349
    /// accumulator. If present, only those rows for which the filter
350
    /// evaluate to true should be included in the aggregate results.
351
    ///
352
    /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`,
353
    /// the filter expression is  `x > 100`.
354
    filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
355
356
    /// GROUP BY expressions
357
    group_by: PhysicalGroupBy,
358
359
    /// max rows in output RecordBatches
360
    batch_size: usize,
361
362
    /// Optional soft limit on the number of `group_values` in a batch
363
    /// If the number of `group_values` in a single batch exceeds this value,
364
    /// the `GroupedHashAggregateStream` operation immediately switches to
365
    /// output mode and emits all groups.
366
    group_values_soft_limit: Option<usize>,
367
368
    // ========================================================================
369
    // STATE FLAGS:
370
    // These fields will be updated during the execution. And control the flow of
371
    // the execution.
372
    // ========================================================================
373
    /// Tracks if this stream is generating input or output
374
    exec_state: ExecutionState,
375
376
    /// Have we seen the end of the input
377
    input_done: bool,
378
379
    // ========================================================================
380
    // STATE BUFFERS:
381
    // These fields will accumulate intermediate results during the execution.
382
    // ========================================================================
383
    /// An interning store of group keys
384
    group_values: Box<dyn GroupValues>,
385
386
    /// scratch space for the current input [`RecordBatch`] being
387
    /// processed. Reused across batches here to avoid reallocations
388
    current_group_indices: Vec<usize>,
389
390
    /// Accumulators, one for each `AggregateFunctionExpr` in the query
391
    ///
392
    /// For example, if the query has aggregates, `SUM(x)`,
393
    /// `COUNT(y)`, there will be two accumulators, each one
394
    /// specialized for that particular aggregate and its input types
395
    accumulators: Vec<Box<dyn GroupsAccumulator>>,
396
397
    // ========================================================================
398
    // TASK-SPECIFIC STATES:
399
    // Inner states groups together properties, states for a specific task.
400
    // ========================================================================
401
    /// Optional ordering information, that might allow groups to be
402
    /// emitted from the hash table prior to seeing the end of the
403
    /// input
404
    group_ordering: GroupOrdering,
405
406
    /// The spill state object
407
    spill_state: SpillState,
408
409
    /// Optional probe for skipping data aggregation, if supported by
410
    /// current stream.
411
    skip_aggregation_probe: Option<SkipAggregationProbe>,
412
413
    // ========================================================================
414
    // EXECUTION RESOURCES:
415
    // Fields related to managing execution resources and monitoring performance.
416
    // ========================================================================
417
    /// The memory reservation for this grouping
418
    reservation: MemoryReservation,
419
420
    /// Execution metrics
421
    baseline_metrics: BaselineMetrics,
422
423
    /// The [`RuntimeEnv`] associated with the [`TaskContext`] argument
424
    runtime: Arc<RuntimeEnv>,
425
}
426
427
impl GroupedHashAggregateStream {
428
    /// Create a new GroupedHashAggregateStream
429
70
    pub fn new(
430
70
        agg: &AggregateExec,
431
70
        context: Arc<TaskContext>,
432
70
        partition: usize,
433
70
    ) -> Result<Self> {
434
70
        debug!(
"Creating GroupedHashAggregateStream"0
);
435
70
        let agg_schema = Arc::clone(&agg.schema);
436
70
        let agg_group_by = agg.group_by.clone();
437
70
        let agg_filter_expr = agg.filter_expr.clone();
438
70
439
70
        let batch_size = context.session_config().batch_size();
440
70
        let input = agg.input.execute(partition, Arc::clone(&context))
?0
;
441
70
        let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
442
70
443
70
        let timer = baseline_metrics.elapsed_compute().timer();
444
70
445
70
        let aggregate_exprs = agg.aggr_expr.clone();
446
447
        // arguments for each aggregate, one vec of expressions per
448
        // aggregate
449
70
        let aggregate_arguments = aggregates::aggregate_expressions(
450
70
            &agg.aggr_expr,
451
70
            &agg.mode,
452
70
            agg_group_by.expr.len(),
453
70
        )
?0
;
454
        // arguments for aggregating spilled data is the same as the one for final aggregation
455
70
        let merging_aggregate_arguments = aggregates::aggregate_expressions(
456
70
            &agg.aggr_expr,
457
70
            &AggregateMode::Final,
458
70
            agg_group_by.expr.len(),
459
70
        )
?0
;
460
461
70
        let filter_expressions = match agg.mode {
462
            AggregateMode::Partial
463
            | AggregateMode::Single
464
53
            | AggregateMode::SinglePartitioned => agg_filter_expr,
465
            AggregateMode::Final | AggregateMode::FinalPartitioned => {
466
17
                vec![None; agg.aggr_expr.len()]
467
            }
468
        };
469
470
        // Instantiate the accumulators
471
70
        let accumulators: Vec<_> = aggregate_exprs
472
70
            .iter()
473
70
            .map(create_group_accumulator)
474
70
            .collect::<Result<_>>()
?0
;
475
476
70
        let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
477
70
        let spill_expr = group_schema
478
70
            .fields
479
70
            .into_iter()
480
70
            .enumerate()
481
84
            .map(|(idx, field)| PhysicalSortExpr {
482
84
                expr: Arc::new(Column::new(field.name().as_str(), idx)) as _,
483
84
                options: SortOptions::default(),
484
84
            })
485
70
            .collect();
486
70
487
70
        let name = format!("GroupedHashAggregateStream[{partition}]");
488
70
        let reservation = MemoryConsumer::new(name)
489
70
            .with_can_spill(true)
490
70
            .register(context.memory_pool());
491
70
        let (ordering, _) = agg
492
70
            .properties()
493
70
            .equivalence_properties()
494
70
            .find_longest_permutation(&agg_group_by.output_exprs());
495
70
        let group_ordering = GroupOrdering::try_new(
496
70
            &group_schema,
497
70
            &agg.input_order_mode,
498
70
            ordering.as_slice(),
499
70
        )
?0
;
500
501
70
        let group_values = new_group_values(group_schema)
?0
;
502
70
        timer.done();
503
70
504
70
        let exec_state = ExecutionState::ReadingInput;
505
70
506
70
        let spill_state = SpillState {
507
70
            spills: vec![],
508
70
            spill_expr,
509
70
            spill_schema: Arc::clone(&agg_schema),
510
70
            is_stream_merging: false,
511
70
            merging_aggregate_arguments,
512
70
            merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
513
70
        };
514
515
        // Skip aggregation is supported if:
516
        // - aggregation mode is Partial
517
        // - input is not ordered by GROUP BY expressions,
518
        //   since Final mode expects unique group values as its input
519
        // - all accumulators support input batch to intermediate
520
        //   aggregate state conversion
521
        // - there is only one GROUP BY expressions set
522
70
        let skip_aggregation_probe = if agg.mode == AggregateMode::Partial
523
53
            && 
matches!0
(group_ordering, GroupOrdering::None)
524
53
            && accumulators
525
53
                .iter()
526
53
                .all(|acc| acc.supports_convert_to_state())
527
53
            && agg_group_by.is_single()
528
        {
529
44
            let options = &context.session_config().options().execution;
530
44
            let probe_rows_threshold =
531
44
                options.skip_partial_aggregation_probe_rows_threshold;
532
44
            let probe_ratio_threshold =
533
44
                options.skip_partial_aggregation_probe_ratio_threshold;
534
44
            let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics)
535
44
                .counter("skipped_aggregation_rows", partition);
536
44
            Some(SkipAggregationProbe::new(
537
44
                probe_rows_threshold,
538
44
                probe_ratio_threshold,
539
44
                skipped_aggregation_rows,
540
44
            ))
541
        } else {
542
26
            None
543
        };
544
545
70
        Ok(GroupedHashAggregateStream {
546
70
            schema: agg_schema,
547
70
            input,
548
70
            mode: agg.mode,
549
70
            accumulators,
550
70
            aggregate_arguments,
551
70
            filter_expressions,
552
70
            group_by: agg_group_by,
553
70
            reservation,
554
70
            group_values,
555
70
            current_group_indices: Default::default(),
556
70
            exec_state,
557
70
            baseline_metrics,
558
70
            batch_size,
559
70
            group_ordering,
560
70
            input_done: false,
561
70
            runtime: context.runtime_env(),
562
70
            spill_state,
563
70
            group_values_soft_limit: agg.limit,
564
70
            skip_aggregation_probe,
565
70
        })
566
70
    }
567
}
568
569
/// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if
570
/// that is supported by the aggregate, or a
571
/// [`GroupsAccumulatorAdapter`] if not.
572
70
pub(crate) fn create_group_accumulator(
573
70
    agg_expr: &AggregateFunctionExpr,
574
70
) -> Result<Box<dyn GroupsAccumulator>> {
575
70
    if agg_expr.groups_accumulator_supported() {
576
30
        agg_expr.create_groups_accumulator()
577
    } else {
578
        // Note in the log when the slow path is used
579
40
        debug!(
580
0
            "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}",
581
0
            agg_expr.name()
582
        );
583
40
        let agg_expr_captured = agg_expr.clone();
584
130
        let factory = move || agg_expr_captured.create_accumulator();
585
40
        Ok(Box::new(GroupsAccumulatorAdapter::new(factory)))
586
    }
587
70
}
588
589
/// Unwraps `Ok(value)` into `value`. On `Err(e)`, it returns
/// `Poll::Ready(Some(Err(e)))` from the enclosing function.
///
/// Intended for `poll_next`-style functions whose item type is a `Result`.
/// `Poll` must be in scope at the call site.
macro_rules! extract_ok {
    ($RES: expr) => {{
        match $RES {
            Err(err) => return Poll::Ready(Some(Err(err))),
            Ok(value) => value,
        }
    }};
}
598
599
impl Stream for GroupedHashAggregateStream {
600
    type Item = Result<RecordBatch>;
601
602
277
    fn poll_next(
603
277
        mut self: std::pin::Pin<&mut Self>,
604
277
        cx: &mut Context<'_>,
605
277
    ) -> Poll<Option<Self::Item>> {
606
277
        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
607
608
        loop {
609
477
            match &self.exec_state {
610
                ExecutionState::ReadingInput => 'reading_input: {
611
285
                    match 
ready!86
(self.input.poll_next_unpin(cx)) {
612
                        // New batch to aggregate in partial aggregation operator
613
129
                        Some(Ok(
batch72
)) if self.mode == AggregateMode::Partial => {
614
72
                            let timer = elapsed_compute.timer();
615
72
                            let input_rows = batch.num_rows();
616
617
                            // Do the grouping
618
72
                            
extract_ok!1
(self.group_aggregate_batch(batch));
619
620
71
                            self.update_skip_aggregation_probe(input_rows);
621
71
622
71
                            // If we can begin emitting rows, do so,
623
71
                            // otherwise keep consuming input
624
71
                            assert!(!self.input_done);
625
626
                            // If the number of group values equals or exceeds the soft limit,
627
                            // emit all groups and switch to producing output
628
71
                            if self.hit_soft_group_limit() {
629
0
                                timer.done();
630
0
                                extract_ok!(self.set_input_done_and_produce_output());
631
                                // make sure the exec_state just set is not overwritten below
632
0
                                break 'reading_input;
633
71
                            }
634
635
71
                            if let Some(
to_emit0
) = self.group_ordering.emit_to() {
636
0
                                let batch = extract_ok!(self.emit(to_emit, false));
637
0
                                self.exec_state = ExecutionState::ProducingOutput(batch);
638
0
                                timer.done();
639
0
                                // make sure the exec_state just set is not overwritten below
640
0
                                break 'reading_input;
641
71
                            }
642
643
71
                            
extract_ok!0
(self.emit_early_if_necessary());
644
645
71
                            
extract_ok!0
(self.switch_to_skip_aggregation());
646
647
71
                            timer.done();
648
                        }
649
650
                        // New batch to aggregate in terminal aggregation operator
651
                        // (Final/FinalPartitioned/Single/SinglePartitioned)
652
57
                        Some(Ok(batch)) => {
653
57
                            let timer = elapsed_compute.timer();
654
655
                            // Make sure we have enough capacity for `batch`, otherwise spill
656
57
                            
extract_ok!0
(self.spill_previous_if_necessary(&batch));
657
658
                            // Do the grouping
659
57
                            
extract_ok!0
(self.group_aggregate_batch(batch));
660
661
                            // If we can begin emitting rows, do so,
662
                            // otherwise keep consuming input
663
57
                            assert!(!self.input_done);
664
665
                            // If the number of group values equals or exceeds the soft limit,
666
                            // emit all groups and switch to producing output
667
57
                            if self.hit_soft_group_limit() {
668
0
                                timer.done();
669
0
                                extract_ok!(self.set_input_done_and_produce_output());
670
                                // make sure the exec_state just set is not overwritten below
671
0
                                break 'reading_input;
672
57
                            }
673
674
57
                            if let Some(
to_emit8
) = self.group_ordering.emit_to() {
675
8
                                let batch = 
extract_ok!0
(self.emit(to_emit, false));
676
8
                                self.exec_state = ExecutionState::ProducingOutput(batch);
677
8
                                timer.done();
678
8
                                // make sure the exec_state just set is not overwritten below
679
8
                                break 'reading_input;
680
49
                            }
681
49
682
49
                            timer.done();
683
                        }
684
685
                        // Found error from input stream
686
0
                        Some(Err(e)) => {
687
0
                            // inner had error, return to caller
688
0
                            return Poll::Ready(Some(Err(e)));
689
                        }
690
691
                        // Found end from input stream
692
                        None => {
693
                            // inner is done, emit all rows and switch to producing output
694
70
                            
extract_ok!0
(self.set_input_done_and_produce_output());
695
                        }
696
                    }
697
                }
698
699
                ExecutionState::SkippingAggregation => {
700
4
                    match 
ready!0
(self.input.poll_next_unpin(cx)) {
701
2
                        Some(Ok(batch)) => {
702
2
                            let _timer = elapsed_compute.timer();
703
2
                            if let Some(probe) = self.skip_aggregation_probe.as_mut() {
704
2
                                probe.record_skipped(&batch);
705
2
                            }
0
706
2
                            let states = self.transform_to_states(batch)
?0
;
707
2
                            return Poll::Ready(Some(Ok(
708
2
                                states.record_output(&self.baseline_metrics)
709
2
                            )));
710
                        }
711
0
                        Some(Err(e)) => {
712
0
                            // inner had error, return to caller
713
0
                            return Poll::Ready(Some(Err(e)));
714
                        }
715
2
                        None => {
716
2
                            // inner is done, switching to `Done` state
717
2
                            self.exec_state = ExecutionState::Done;
718
2
                        }
719
                    }
720
                }
721
722
120
                ExecutionState::ProducingOutput(batch) => {
723
120
                    // slice off a part of the batch, if needed
724
120
                    let output_batch;
725
120
                    let size = self.batch_size;
726
120
                    (self.exec_state, output_batch) = if batch.num_rows() <= size {
727
                        (
728
104
                            if self.input_done {
729
66
                                ExecutionState::Done
730
                            }
731
                            // In Partial aggregation, we also need to check
732
                            // if we should trigger partial skipping
733
38
                            else if self.mode == AggregateMode::Partial
734
30
                                && self.should_skip_aggregation()
735
                            {
736
2
                                ExecutionState::SkippingAggregation
737
                            } else {
738
36
                                ExecutionState::ReadingInput
739
                            },
740
104
                            batch.clone(),
741
                        )
742
                    } else {
743
                        // output first batch_size rows
744
16
                        let size = self.batch_size;
745
16
                        let num_remaining = batch.num_rows() - size;
746
16
                        let remaining = batch.slice(size, num_remaining);
747
16
                        let output = batch.slice(0, size);
748
16
                        (ExecutionState::ProducingOutput(remaining), output)
749
                    };
750
120
                    return Poll::Ready(Some(Ok(
751
120
                        output_batch.record_output(&self.baseline_metrics)
752
120
                    )));
753
                }
754
755
                ExecutionState::Done => {
756
                    // release the memory reservation since sending back output batch itself needs
757
                    // some memory reservation, so make some room for it.
758
68
                    self.clear_all();
759
68
                    let _ = self.update_memory_reservation();
760
68
                    return Poll::Ready(None);
761
                }
762
            }
763
        }
764
277
    }
765
}
766
767
impl RecordBatchStream for GroupedHashAggregateStream {
768
178
    fn schema(&self) -> SchemaRef {
769
178
        Arc::clone(&self.schema)
770
178
    }
771
}
772
773
impl GroupedHashAggregateStream {
774
    /// Perform group-by aggregation for the given [`RecordBatch`].
775
129
    fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
776
        // Evaluate the grouping expressions
777
129
        let group_by_values = if self.spill_state.is_stream_merging {
778
12
            evaluate_group_by(&self.spill_state.merging_group_by, &batch)
?0
779
        } else {
780
117
            evaluate_group_by(&self.group_by, &batch)
?0
781
        };
782
783
        // Evaluate the aggregation expressions.
784
129
        let input_values = if self.spill_state.is_stream_merging {
785
12
            evaluate_many(&self.spill_state.merging_aggregate_arguments, &batch)
?0
786
        } else {
787
117
            evaluate_many(&self.aggregate_arguments, &batch)
?0
788
        };
789
790
        // Evaluate the filter expressions, if any, against the inputs
791
129
        let filter_values = if self.spill_state.is_stream_merging {
792
12
            let filter_expressions = vec![None; self.accumulators.len()];
793
12
            evaluate_optional(&filter_expressions, &batch)
?0
794
        } else {
795
117
            evaluate_optional(&self.filter_expressions, &batch)
?0
796
        };
797
798
298
        for 
group_values169
in &group_by_values {
799
            // calculate the group indices for each input row
800
169
            let starting_num_groups = self.group_values.len();
801
169
            self.group_values
802
169
                .intern(group_values, &mut self.current_group_indices)
?0
;
803
169
            let group_indices = &self.current_group_indices;
804
169
805
169
            // Update ordering information if necessary
806
169
            let total_num_groups = self.group_values.len();
807
169
            if total_num_groups > starting_num_groups {
808
128
                self.group_ordering.new_groups(
809
128
                    group_values,
810
128
                    group_indices,
811
128
                    total_num_groups,
812
128
                )
?0
;
813
41
            }
814
815
            // Gather the inputs to call the actual accumulator
816
169
            let t = self
817
169
                .accumulators
818
169
                .iter_mut()
819
169
                .zip(input_values.iter())
820
169
                .zip(filter_values.iter());
821
822
338
            for ((
acc, values), opt_filter169
) in t {
823
169
                let opt_filter = opt_filter.as_ref().map(|filter| 
filter.as_boolean()0
);
824
825
                // Call the appropriate method on each aggregator with
826
                // the entire input row and the relevant group indexes
827
112
                match self.mode {
828
                    AggregateMode::Partial
829
                    | AggregateMode::Single
830
                    | AggregateMode::SinglePartitioned
831
112
                        if !self.spill_state.is_stream_merging =>
832
                    {
833
112
                        acc.update_batch(
834
112
                            values,
835
112
                            group_indices,
836
112
                            opt_filter,
837
112
                            total_num_groups,
838
112
                        )
?0
;
839
                    }
840
                    _ => {
841
                        // if aggregation is over intermediate states,
842
                        // use merge
843
57
                        acc.merge_batch(
844
57
                            values,
845
57
                            group_indices,
846
57
                            opt_filter,
847
57
                            total_num_groups,
848
57
                        )
?0
;
849
                    }
850
                }
851
            }
852
        }
853
854
129
        match self.update_memory_reservation() {
855
            // Here we can ignore `insufficient_capacity_err` because we will spill later,
856
            // but at least one batch should fit in the memory
857
            Err(DataFusionError::ResourcesExhausted(_))
858
33
                if self.group_values.len() >= self.batch_size =>
859
32
            {
860
32
                Ok(())
861
            }
862
97
            other => other,
863
        }
864
129
    }
865
866
371
    fn update_memory_reservation(&mut self) -> Result<()> {
867
371
        let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>();
868
371
        self.reservation.try_resize(
869
371
            acc + self.group_values.size()
870
371
                + self.group_ordering.size()
871
371
                + self.current_group_indices.allocated_size(),
872
371
        )
873
371
    }
874
875
    /// Create an output RecordBatch with the group keys and
876
    /// accumulator states/values specified in emit_to
877
112
    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
878
112
        let schema = if spilling {
879
8
            Arc::clone(&self.spill_state.spill_schema)
880
        } else {
881
104
            self.schema()
882
        };
883
112
        if self.group_values.is_empty() {
884
2
            return Ok(RecordBatch::new_empty(schema));
885
110
        }
886
887
110
        let mut output = self.group_values.emit(emit_to)
?0
;
888
110
        if let EmitTo::First(
n36
) = emit_to {
889
36
            self.group_ordering.remove_groups(n);
890
74
        }
891
892
        // Next output each aggregate value
893
110
        for acc in self.accumulators.iter_mut() {
894
25
            match self.mode {
895
77
                AggregateMode::Partial => output.extend(acc.state(emit_to)
?0
),
896
8
                _ if spilling => {
897
8
                    // If spilling, output partial state because the spilled data will be
898
8
                    // merged and re-evaluated later.
899
8
                    output.extend(acc.state(emit_to)
?0
)
900
                }
901
                AggregateMode::Final
902
                | AggregateMode::FinalPartitioned
903
                | AggregateMode::Single
904
25
                | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)
?0
),
905
            }
906
        }
907
908
        // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is
909
        // over the target memory size after emission, we can emit again rather than returning Err.
910
110
        let _ = self.update_memory_reservation();
911
110
        let batch = RecordBatch::try_new(schema, output)
?0
;
912
110
        Ok(batch)
913
112
    }
914
915
    /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly
916
    /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to disk and clear the
917
    /// memory. Currently only [`GroupOrdering::None`] is supported for spilling.
918
57
    fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> Result<()> {
919
57
        // TODO: support group_ordering for spilling
920
57
        if self.group_values.len() > 0
921
36
            && batch.num_rows() > 0
922
36
            && 
matches!8
(self.group_ordering, GroupOrdering::None)
923
28
            && !self.spill_state.is_stream_merging
924
28
            && self.update_memory_reservation().is_err()
925
        {
926
4
            assert_ne!(self.mode, AggregateMode::Partial);
927
            // Use input batch (Partial mode) schema for spilling because
928
            // the spilled data will be merged and re-evaluated later.
929
4
            self.spill_state.spill_schema = batch.schema();
930
4
            self.spill()
?0
;
931
4
            self.clear_shrink(batch);
932
53
        }
933
57
        Ok(())
934
57
    }
935
936
    /// Emit all rows, sort them, and store them on disk.
937
4
    fn spill(&mut self) -> Result<()> {
938
4
        let emit = self.emit(EmitTo::All, true)
?0
;
939
4
        let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)
?0
;
940
4
        let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")
?0
;
941
        // TODO: slice large `sorted` and write to multiple files in parallel
942
4
        spill_record_batch_by_size(
943
4
            &sorted,
944
4
            spillfile.path().into(),
945
4
            sorted.schema(),
946
4
            self.batch_size,
947
4
        )
?0
;
948
4
        self.spill_state.spills.push(spillfile);
949
4
        Ok(())
950
4
    }
951
952
    /// Clear memory and shirk capacities to the size of the batch.
953
76
    fn clear_shrink(&mut self, batch: &RecordBatch) {
954
76
        self.group_values.clear_shrink(batch);
955
76
        self.current_group_indices.clear();
956
76
        self.current_group_indices.shrink_to(batch.num_rows());
957
76
    }
958
959
    /// Clear memory and shirk capacities to zero.
960
72
    fn clear_all(&mut self) {
961
72
        let s = self.schema();
962
72
        self.clear_shrink(&RecordBatch::new_empty(s));
963
72
    }
964
965
    /// Emit if the used memory exceeds the target for partial aggregation.
966
    /// Currently only [`GroupOrdering::None`] is supported for early emitting.
967
    /// TODO: support group_ordering for early emitting
968
71
    fn emit_early_if_necessary(&mut self) -> Result<()> {
969
71
        if self.group_values.len() >= self.batch_size
970
32
            && 
matches!0
(self.group_ordering, GroupOrdering::None)
971
32
            && self.update_memory_reservation().is_err()
972
        {
973
28
            assert_eq!(self.mode, AggregateMode::Partial);
974
28
            let n = self.group_values.len() / self.batch_size * self.batch_size;
975
28
            let batch = self.emit(EmitTo::First(n), false)
?0
;
976
28
            self.exec_state = ExecutionState::ProducingOutput(batch);
977
43
        }
978
71
        Ok(())
979
71
    }
980
981
    /// At this point, all the inputs are read and there are some spills.
982
    /// Emit the remaining rows and create a batch.
983
    /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
984
    /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
985
4
    fn update_merged_stream(&mut self) -> Result<()> {
986
4
        let batch = self.emit(EmitTo::All, true)
?0
;
987
        // clear up memory for streaming_merge
988
4
        self.clear_all();
989
4
        self.update_memory_reservation()
?0
;
990
4
        let mut streams: Vec<SendableRecordBatchStream> = vec![];
991
4
        let expr = self.spill_state.spill_expr.clone();
992
4
        let schema = batch.schema();
993
4
        streams.push(Box::pin(RecordBatchStreamAdapter::new(
994
4
            Arc::clone(&schema),
995
4
            futures::stream::once(futures::future::lazy(move |_| {
996
4
                sort_batch(&batch, &expr, None)
997
4
            })),
998
4
        )));
999
4
        for spill in self.spill_state.spills.drain(..) {
1000
4
            let stream = read_spill_as_stream(spill, Arc::clone(&schema), 2)
?0
;
1001
4
            streams.push(stream);
1002
        }
1003
4
        self.spill_state.is_stream_merging = true;
1004
4
        self.input = StreamingMergeBuilder::new()
1005
4
            .with_streams(streams)
1006
4
            .with_schema(schema)
1007
4
            .with_expressions(&self.spill_state.spill_expr)
1008
4
            .with_metrics(self.baseline_metrics.clone())
1009
4
            .with_batch_size(self.batch_size)
1010
4
            .with_reservation(self.reservation.new_empty())
1011
4
            .build()
?0
;
1012
4
        self.input_done = false;
1013
4
        self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
1014
4
        Ok(())
1015
4
    }
1016
1017
    /// returns true if there is a soft groups limit and the number of distinct
1018
    /// groups we have seen is over that limit
1019
128
    fn hit_soft_group_limit(&self) -> bool {
1020
128
        let Some(
group_values_soft_limit0
) = self.group_values_soft_limit else {
1021
128
            return false;
1022
        };
1023
0
        group_values_soft_limit <= self.group_values.len()
1024
128
    }
1025
1026
    /// common function for signalling end of processing of the input stream
1027
70
    fn set_input_done_and_produce_output(&mut self) -> Result<()> {
1028
70
        self.input_done = true;
1029
70
        self.group_ordering.input_done();
1030
70
        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
1031
70
        let timer = elapsed_compute.timer();
1032
70
        self.exec_state = if self.spill_state.spills.is_empty() {
1033
66
            let batch = self.emit(EmitTo::All, false)
?0
;
1034
66
            ExecutionState::ProducingOutput(batch)
1035
        } else {
1036
            // If spill files exist, stream-merge them.
1037
4
            self.update_merged_stream()
?0
;
1038
4
            ExecutionState::ReadingInput
1039
        };
1040
70
        timer.done();
1041
70
        Ok(())
1042
70
    }
1043
1044
    /// Updates skip aggregation probe state.
1045
    ///
1046
    /// Notice: It should only be called in Partial aggregation
1047
71
    fn update_skip_aggregation_probe(&mut self, input_rows: usize) {
1048
71
        if let Some(
probe51
) = self.skip_aggregation_probe.as_mut() {
1049
            // Skip aggregation probe is not supported if stream has any spills,
1050
            // currently spilling is not supported for Partial aggregation
1051
51
            assert!(self.spill_state.spills.is_empty());
1052
51
            probe.update_state(input_rows, self.group_values.len());
1053
20
        };
1054
71
    }
1055
1056
    /// In case the probe indicates that aggregation may be
1057
    /// skipped, forces stream to produce currently accumulated output.
1058
    ///
1059
    /// Notice: It should only be called in Partial aggregation
1060
71
    fn switch_to_skip_aggregation(&mut self) -> Result<()> {
1061
71
        if let Some(
probe51
) = self.skip_aggregation_probe.as_mut() {
1062
51
            if probe.should_skip() {
1063
2
                let batch = self.emit(EmitTo::All, false)
?0
;
1064
2
                self.exec_state = ExecutionState::ProducingOutput(batch);
1065
49
            }
1066
20
        }
1067
1068
71
        Ok(())
1069
71
    }
1070
1071
    /// Returns true if the aggregation probe indicates that aggregation
1072
    /// should be skipped.
1073
    ///
1074
    /// Notice: It should only be called in Partial aggregation
1075
30
    fn should_skip_aggregation(&self) -> bool {
1076
30
        self.skip_aggregation_probe
1077
30
            .as_ref()
1078
30
            .is_some_and(|probe| 
probe.should_skip()26
)
1079
30
    }
1080
1081
    /// Transforms input batch to intermediate aggregate state, without grouping it
1082
2
    fn transform_to_states(&self, batch: RecordBatch) -> Result<RecordBatch> {
1083
2
        let mut group_values = evaluate_group_by(&self.group_by, &batch)
?0
;
1084
2
        let input_values = evaluate_many(&self.aggregate_arguments, &batch)
?0
;
1085
2
        let filter_values = evaluate_optional(&self.filter_expressions, &batch)
?0
;
1086
1087
2
        if group_values.len() != 1 {
1088
0
            return internal_err!("group_values expected to have single element");
1089
2
        }
1090
2
        let mut output = group_values.swap_remove(0);
1091
2
1092
2
        let iter = self
1093
2
            .accumulators
1094
2
            .iter()
1095
2
            .zip(input_values.iter())
1096
2
            .zip(filter_values.iter());
1097
1098
4
        for ((
acc, values), opt_filter2
) in iter {
1099
2
            let opt_filter = opt_filter.as_ref().map(|filter| 
filter.as_boolean()0
);
1100
2
            output.extend(acc.convert_to_state(values, opt_filter)
?0
);
1101
        }
1102
1103
2
        let states_batch = RecordBatch::try_new(self.schema(), output)
?0
;
1104
1105
2
        Ok(states_batch)
1106
2
    }
1107
}