Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/mod.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Aggregation functionality
19
20
use std::any::Any;
21
use std::sync::Arc;
22
23
use super::{DisplayAs, ExecutionMode, ExecutionPlanProperties, PlanProperties};
24
use crate::aggregates::{
25
    no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
26
    topk_stream::GroupedTopKAggregateStream,
27
};
28
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
29
use crate::projection::get_field_metadata;
30
use crate::windows::get_ordered_partition_by_indices;
31
use crate::{
32
    DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
33
    SendableRecordBatchStream, Statistics,
34
};
35
36
use arrow::array::ArrayRef;
37
use arrow::datatypes::{Field, Schema, SchemaRef};
38
use arrow::record_batch::RecordBatch;
39
use datafusion_common::stats::Precision;
40
use datafusion_common::{internal_err, not_impl_err, Result};
41
use datafusion_execution::TaskContext;
42
use datafusion_expr::Accumulator;
43
use datafusion_physical_expr::{
44
    equivalence::{collapse_lex_req, ProjectionMapping},
45
    expressions::Column,
46
    physical_exprs_contains, EquivalenceProperties, LexOrdering, LexRequirement,
47
    PhysicalExpr, PhysicalSortRequirement,
48
};
49
50
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
51
use itertools::Itertools;
52
53
pub mod group_values;
54
mod no_grouping;
55
pub mod order;
56
mod row_hash;
57
mod topk;
58
mod topk_stream;
59
60
/// Hash aggregate modes
61
///
62
/// See [`Accumulator::state`] for background information on multi-phase
63
/// aggregation and how these modes are used.
64
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
65
pub enum AggregateMode {
66
    /// Partial aggregate that can be applied in parallel across input
67
    /// partitions.
68
    ///
69
    /// This is the first phase of a multi-phase aggregation.
70
    Partial,
71
    /// Final aggregate that produces a single partition of output by combining
72
    /// the output of multiple partial aggregates.
73
    ///
74
    /// This is the second phase of a multi-phase aggregation.
75
    Final,
76
    /// Final aggregate that works on pre-partitioned data.
77
    ///
78
    /// This requires the invariant that all rows with a particular
79
    /// grouping key are in the same partition, as is the case
80
    /// with hash repartitioning on the group keys. If a group key
81
    /// appeared in more than one partition, duplicate groups would be produced.
82
    FinalPartitioned,
83
    /// Applies the entire logical aggregation operation in a single operator,
84
    /// as opposed to Partial / Final modes which apply the logical aggregation using
85
    /// two operators.
86
    ///
87
    /// This mode requires that the input is a single partition (like Final)
88
    Single,
89
    /// Applies the entire logical aggregation operation in a single operator,
90
    /// as opposed to Partial / Final modes which apply the logical aggregation using
91
    /// two operators.
92
    ///
93
    /// This mode requires that the input is partitioned by group key (like
94
    /// FinalPartitioned)
95
    SinglePartitioned,
96
}
97
98
impl AggregateMode {
99
    /// Checks whether this aggregation step describes a "first stage" calculation.
100
    /// In other words, its input is not another aggregation result and the
101
    /// `merge_batch` method will not be called for these modes.
102
59
    pub fn is_first_stage(&self) -> bool {
103
59
        match self {
104
            AggregateMode::Partial
105
            | AggregateMode::Single
106
34
            | AggregateMode::SinglePartitioned => true,
107
25
            AggregateMode::Final | AggregateMode::FinalPartitioned => false,
108
        }
109
59
    }
110
}
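// A minimal standalone sketch (not the DataFusion API) of how `Partial` and
// `Final` compose for a two-phase COUNT: each input partition produces a
// partial state in parallel, and the final stage merges those states, in the
// spirit of the `Accumulator::state` / `merge_batch` contract referenced
// above. All names below are hypothetical.
fn partial_count(partition: &[i64]) -> u64 {
    // First phase (AggregateMode::Partial): runs once per input partition.
    partition.len() as u64
}

fn final_count(partial_states: &[u64]) -> u64 {
    // Second phase (AggregateMode::Final): merges the partial states.
    partial_states.iter().sum()
}

fn two_phase_count_example() {
    let partitions = [vec![1i64, 2], vec![3], vec![4, 5, 6]];
    let partials: Vec<u64> = partitions.iter().map(|p| partial_count(p)).collect();
    assert_eq!(final_count(&partials), 6);
}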
111
112
/// Represents the `GROUP BY` clause in the plan (including the more general GROUPING SETS)
113
/// In the case of a simple `GROUP BY a, b` clause, this will contain the expressions [a, b]
114
/// and a single group [false, false].
115
/// In the case of `GROUP BY GROUPING SETS/CUBE/ROLLUP` the planner will expand the expression
116
/// into multiple groups, using null expressions to align each group.
117
/// For example, with a group by clause `GROUP BY GROUPING SETS ((a,b),(a),(b))` the planner should
118
/// create a `PhysicalGroupBy` like
119
/// ```text
120
/// PhysicalGroupBy {
121
///     expr: [(col(a), a), (col(b), b)],
122
///     null_expr: [(NULL, a), (NULL, b)],
123
///     groups: [
124
///         [false, false], // (a,b)
125
///         [false, true],  // (a) <=> (a, NULL)
126
///         [true, false]   // (b) <=> (NULL, b)
127
///     ]
128
/// }
129
/// ```
130
#[derive(Clone, Debug, Default)]
131
pub struct PhysicalGroupBy {
132
    /// Distinct (Physical Expr, Alias) in the grouping set
133
    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
134
    /// Corresponding NULL expressions for expr
135
    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
136
    /// Null mask for each group in this grouping set. Each group is
137
    /// composed of either one of the group expressions in expr or a null
138
    /// expression in null_expr. If `groups[i][j]` is true, then the
139
    /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`.
140
    groups: Vec<Vec<bool>>,
141
}
142
143
impl PhysicalGroupBy {
144
    /// Create a new `PhysicalGroupBy`
145
1
    pub fn new(
146
1
        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
147
1
        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
148
1
        groups: Vec<Vec<bool>>,
149
1
    ) -> Self {
150
1
        Self {
151
1
            expr,
152
1
            null_expr,
153
1
            groups,
154
1
        }
155
1
    }
156
157
    /// Create a GROUPING SET with only a single group. This is the "standard"
158
    /// case when building a plan from an expression such as `GROUP BY a,b,c`
159
100
    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
160
100
        let num_exprs = expr.len();
161
100
        Self {
162
100
            expr,
163
100
            null_expr: vec![],
164
100
            groups: vec![vec![false; num_exprs]],
165
100
        }
166
100
    }
167
168
    /// Calculate, for each GROUP BY expression, whether it can be null in the output
169
50
    pub fn exprs_nullable(&self) -> Vec<bool> {
170
50
        let mut exprs_nullable = vec![false; self.expr.len()];
171
59
        for group in self.groups.iter() {
172
83
            group.iter().enumerate().for_each(|(index, is_null)| {
173
83
                if *is_null {
174
15
                    exprs_nullable[index] = true;
175
68
                }
176
83
            })
177
        }
178
50
        exprs_nullable
179
50
    }
180
181
    /// Returns the group expressions
182
4
    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
183
4
        &self.expr
184
4
    }
185
186
    /// Returns the null expressions
187
0
    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
188
0
        &self.null_expr
189
0
    }
190
191
    /// Returns the group null masks
192
0
    pub fn groups(&self) -> &[Vec<bool>] {
193
0
        &self.groups
194
0
    }
195
196
    /// Returns true if this `PhysicalGroupBy` has no group expressions
197
0
    pub fn is_empty(&self) -> bool {
198
0
        self.expr.is_empty()
199
0
    }
200
201
    /// Check whether this grouping set consists of a single group
202
62
    pub fn is_single(&self) -> bool {
203
62
        self.null_expr.is_empty()
204
62
    }
205
206
    /// Return the GROUP BY expressions as they occur in the input schema.
207
59
    pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
208
59
        self.expr
209
59
            .iter()
210
59
            .map(|(expr, _alias)| Arc::clone(expr))
211
59
            .collect()
212
59
    }
213
214
    /// Return grouping expressions as they occur in the output schema.
215
70
    pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
216
70
        self.expr
217
70
            .iter()
218
70
            .enumerate()
219
84
            .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _)
220
70
            .collect()
221
70
    }
222
}
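// A small standalone sketch of the `exprs_nullable` logic above, applied to
// the `GROUP BY GROUPING SETS ((a,b),(a),(b))` example from the
// `PhysicalGroupBy` docs: a column becomes nullable as soon as any group
// masks it out with NULL. The helper below is hypothetical.
fn nullable_from_masks(groups: &[Vec<bool>], num_exprs: usize) -> Vec<bool> {
    let mut nullable = vec![false; num_exprs];
    for group in groups {
        for (idx, is_null) in group.iter().enumerate() {
            if *is_null {
                nullable[idx] = true;
            }
        }
    }
    nullable
}

fn exprs_nullable_example() {
    // groups: [(a,b), (a, NULL), (NULL, b)] -- both columns end up nullable.
    let groups = vec![vec![false, false], vec![false, true], vec![true, false]];
    assert_eq!(nullable_from_masks(&groups, 2), vec![true, true]);
}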
223
224
impl PartialEq for PhysicalGroupBy {
225
0
    fn eq(&self, other: &PhysicalGroupBy) -> bool {
226
0
        self.expr.len() == other.expr.len()
227
0
            && self
228
0
                .expr
229
0
                .iter()
230
0
                .zip(other.expr.iter())
231
0
                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
232
0
            && self.null_expr.len() == other.null_expr.len()
233
0
            && self
234
0
                .null_expr
235
0
                .iter()
236
0
                .zip(other.null_expr.iter())
237
0
                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
238
0
            && self.groups == other.groups
239
0
    }
240
}
241
242
enum StreamType {
243
    AggregateStream(AggregateStream),
244
    GroupedHash(GroupedHashAggregateStream),
245
    GroupedPriorityQueue(GroupedTopKAggregateStream),
246
}
247
248
impl From<StreamType> for SendableRecordBatchStream {
249
72
    fn from(stream: StreamType) -> Self {
250
72
        match stream {
251
2
            StreamType::AggregateStream(stream) => Box::pin(stream),
252
70
            StreamType::GroupedHash(stream) => Box::pin(stream),
253
0
            StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
254
        }
255
72
    }
256
}
257
258
/// Hash aggregate execution plan
259
#[derive(Debug)]
260
pub struct AggregateExec {
261
    /// Aggregation mode (full, partial)
262
    mode: AggregateMode,
263
    /// Group by expressions
264
    group_by: PhysicalGroupBy,
265
    /// Aggregate expressions
266
    aggr_expr: Vec<AggregateFunctionExpr>,
267
    /// FILTER (WHERE clause) expression for each aggregate expression
268
    filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
269
    /// Set if the output of this aggregation is truncated by an upstream sort/limit clause
270
    limit: Option<usize>,
271
    /// Input plan, could be a partial aggregate or the input to the aggregate
272
    pub input: Arc<dyn ExecutionPlan>,
273
    /// Schema after the aggregate is applied
274
    schema: SchemaRef,
275
    /// Input schema before any aggregation is applied. For partial aggregate this will be the
276
    /// same as input.schema() but for the final aggregate it will be the same as the input
277
    /// to the partial aggregate, i.e., partial and final aggregates have the same `input_schema`.
278
    /// We need the input schema of partial aggregate to be able to deserialize aggregate
279
    /// expressions from protobuf for final aggregate.
280
    pub input_schema: SchemaRef,
281
    /// Execution metrics
282
    metrics: ExecutionPlanMetricsSet,
283
    required_input_ordering: Option<LexRequirement>,
284
    /// Describes how the input is ordered relative to the group by columns
285
    input_order_mode: InputOrderMode,
286
    cache: PlanProperties,
287
}
288
289
impl AggregateExec {
290
    /// Function used in the `OptimizeAggregateOrder` optimizer rule,
291
    /// where some parts take new values and the rest are cloned from the old plan.
292
    /// Rewrites this aggregate exec with new aggregate expressions.
293
0
    pub fn with_new_aggr_exprs(&self, aggr_expr: Vec<AggregateFunctionExpr>) -> Self {
294
0
        Self {
295
0
            aggr_expr,
296
0
            // clone the rest of the fields
297
0
            required_input_ordering: self.required_input_ordering.clone(),
298
0
            metrics: ExecutionPlanMetricsSet::new(),
299
0
            input_order_mode: self.input_order_mode.clone(),
300
0
            cache: self.cache.clone(),
301
0
            mode: self.mode,
302
0
            group_by: self.group_by.clone(),
303
0
            filter_expr: self.filter_expr.clone(),
304
0
            limit: self.limit,
305
0
            input: Arc::clone(&self.input),
306
0
            schema: Arc::clone(&self.schema),
307
0
            input_schema: Arc::clone(&self.input_schema),
308
0
        }
309
0
    }
310
311
0
    pub fn cache(&self) -> &PlanProperties {
312
0
        &self.cache
313
0
    }
314
315
    /// Create a new hash aggregate execution plan
316
49
    pub fn try_new(
317
49
        mode: AggregateMode,
318
49
        group_by: PhysicalGroupBy,
319
49
        aggr_expr: Vec<AggregateFunctionExpr>,
320
49
        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
321
49
        input: Arc<dyn ExecutionPlan>,
322
49
        input_schema: SchemaRef,
323
49
    ) -> Result<Self> {
324
49
        let schema = create_schema(
325
49
            &input.schema(),
326
49
            &group_by.expr,
327
49
            &aggr_expr,
328
49
            group_by.exprs_nullable(),
329
49
            mode,
330
49
        )?;
331
332
49
        let schema = Arc::new(schema);
333
49
        AggregateExec::try_new_with_schema(
334
49
            mode,
335
49
            group_by,
336
49
            aggr_expr,
337
49
            filter_expr,
338
49
            input,
339
49
            input_schema,
340
49
            schema,
341
49
        )
342
49
    }
343
344
    /// Create a new hash aggregate execution plan with the given schema.
345
    /// This constructor isn't part of the public API; it is used internally
346
    /// by DataFusion to enforce schema consistency when re-creating
347
    /// `AggregateExec`s inside optimization rules. Schema field names of an
348
    /// `AggregateExec` depend on the names of aggregate expressions. Since
349
    /// a rule may re-write aggregate expressions (e.g. reverse them) during
350
    /// initialization, field names may change inadvertently if one re-creates
351
    /// the schema in such cases.
352
    #[allow(clippy::too_many_arguments)]
353
50
    fn try_new_with_schema(
354
50
        mode: AggregateMode,
355
50
        group_by: PhysicalGroupBy,
356
50
        mut aggr_expr: Vec<AggregateFunctionExpr>,
357
50
        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
358
50
        input: Arc<dyn ExecutionPlan>,
359
50
        input_schema: SchemaRef,
360
50
        schema: SchemaRef,
361
50
    ) -> Result<Self> {
362
50
        // Make sure arguments are consistent in size
363
50
        if aggr_expr.len() != filter_expr.len() {
364
0
            return internal_err!("Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match", aggr_expr, filter_expr);
365
50
        }
366
50
367
50
        let input_eq_properties = input.equivalence_properties();
368
50
        // Get GROUP BY expressions:
369
50
        let groupby_exprs = group_by.input_exprs();
370
50
        // If existing ordering satisfies a prefix of the GROUP BY expressions,
371
50
        // prefix requirements with this section. In this case, aggregation will
372
50
        // work more efficiently.
373
50
        let indices = get_ordered_partition_by_indices(&groupby_exprs, &input);
374
50
        let mut new_requirement = LexRequirement::new(
375
50
            indices
376
50
                .iter()
377
50
                .map(|&idx| PhysicalSortRequirement {
378
1
                    expr: Arc::clone(&groupby_exprs[idx]),
379
1
                    options: None,
380
50
                })
381
50
                .collect::<Vec<_>>(),
382
50
        );
383
384
50
        let req = get_finer_aggregate_exprs_requirement(
385
50
            &mut aggr_expr,
386
50
            &group_by,
387
50
            input_eq_properties,
388
50
            &mode,
389
50
        )?;
390
50
        new_requirement.inner.extend(req);
391
50
        new_requirement = collapse_lex_req(new_requirement);
392
50
393
50
        // If our aggregation has grouping sets then our base grouping exprs will
394
50
        // be expanded based on the flags in `group_by.groups` where for each
395
50
        // group we swap the grouping expr for `null` if the flag is `true`
396
50
        // That means that each index in `indices` is valid if and only if
397
50
        // it is not null in every group
398
50
        let indices: Vec<usize> = indices
399
50
            .into_iter()
400
50
            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
401
50
            .collect();
402
403
50
        let input_order_mode = if indices.len() == groupby_exprs.len()
404
2
            && !indices.is_empty()
405
0
            && group_by.groups.len() == 1
406
        {
407
0
            InputOrderMode::Sorted
408
50
        } else if !indices.is_empty() {
409
0
            InputOrderMode::PartiallySorted(indices)
410
        } else {
411
50
            InputOrderMode::Linear
412
        };
413
414
        // Construct a mapping from the input expressions to the output expressions of the aggregation's GROUP BY
415
50
        let projection_mapping =
416
50
            ProjectionMapping::try_new(&group_by.expr, &input.schema())?;
417
418
50
        let required_input_ordering =
419
50
            (!new_requirement.is_empty()).then_some(new_requirement);
420
50
421
50
        let cache = Self::compute_properties(
422
50
            &input,
423
50
            Arc::clone(&schema),
424
50
            &projection_mapping,
425
50
            &mode,
426
50
            &input_order_mode,
427
50
        );
428
50
429
50
        Ok(AggregateExec {
430
50
            mode,
431
50
            group_by,
432
50
            aggr_expr,
433
50
            filter_expr,
434
50
            input,
435
50
            schema,
436
50
            input_schema,
437
50
            metrics: ExecutionPlanMetricsSet::new(),
438
50
            required_input_ordering,
439
50
            limit: None,
440
50
            input_order_mode,
441
50
            cache,
442
50
        })
443
50
    }
444
445
    /// Aggregation mode (full, partial)
446
0
    pub fn mode(&self) -> &AggregateMode {
447
0
        &self.mode
448
0
    }
449
450
    /// Set the `limit` of this AggExec
451
0
    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
452
0
        self.limit = limit;
453
0
        self
454
0
    }
455
    /// Grouping expressions
456
4
    pub fn group_expr(&self) -> &PhysicalGroupBy {
457
4
        &self.group_by
458
4
    }
459
460
    /// Grouping expressions as they occur in the output schema
461
0
    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
462
0
        self.group_by.output_exprs()
463
0
    }
464
465
    /// Aggregate expressions
466
0
    pub fn aggr_expr(&self) -> &[AggregateFunctionExpr] {
467
0
        &self.aggr_expr
468
0
    }
469
470
    /// FILTER (WHERE clause) expression for each aggregate expression
471
0
    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
472
0
        &self.filter_expr
473
0
    }
474
475
    /// Input plan
476
16
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
477
16
        &self.input
478
16
    }
479
480
    /// Get the input schema before any aggregates are applied
481
0
    pub fn input_schema(&self) -> SchemaRef {
482
0
        Arc::clone(&self.input_schema)
483
0
    }
484
485
    /// Soft limit on the number of output rows of the AggregateExec
486
0
    pub fn limit(&self) -> Option<usize> {
487
0
        self.limit
488
0
    }
489
490
72
    fn execute_typed(
491
72
        &self,
492
72
        partition: usize,
493
72
        context: Arc<TaskContext>,
494
72
    ) -> Result<StreamType> {
495
72
        // no group by at all
496
72
        if self.group_by.expr.is_empty() {
497
2
            return Ok(StreamType::AggregateStream(AggregateStream::new(
498
2
                self, context, partition,
499
2
            )?));
500
70
        }
501
502
        // grouping by an expression that has a sort/limit upstream
503
70
        if let Some(limit) = self.limit {
504
0
            if !self.is_unordered_unfiltered_group_by_distinct() {
505
                return Ok(StreamType::GroupedPriorityQueue(
506
0
                    GroupedTopKAggregateStream::new(self, context, partition, limit)?,
507
                ));
508
0
            }
509
70
        }
510
511
        // grouping by something else and we need to just materialize all results
512
70
        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
513
70
            self, context, partition,
514
70
        )?))
515
72
    }
516
517
    /// Finds the DataType and SortDirection for this Aggregate, if there is one
518
0
    pub fn get_minmax_desc(&self) -> Option<(Field, bool)> {
519
0
        let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
520
0
        agg_expr.get_minmax_desc()
521
0
    }
522
523
    /// Returns true if this Aggregate has a group-by with no required or explicit ordering,
524
    /// no filtering, and no aggregate expressions.
525
    /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule
526
    /// on an AggregateExec.
527
0
    pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
528
0
        // ensure there is a group by
529
0
        if self.group_expr().is_empty() {
530
0
            return false;
531
0
        }
532
0
        // ensure there are no aggregate expressions
533
0
        if !self.aggr_expr().is_empty() {
534
0
            return false;
535
0
        }
536
0
        // ensure there are no filters on aggregate expressions; the above check
537
0
        // may preclude this case
538
0
        if self.filter_expr().iter().any(|e| e.is_some()) {
539
0
            return false;
540
0
        }
541
0
        // ensure there are no order by expressions
542
0
        if self.aggr_expr().iter().any(|e| e.order_bys().is_some()) {
543
0
            return false;
544
0
        }
545
0
        // ensure there is no output ordering; can this rule be relaxed?
546
0
        if self.properties().output_ordering().is_some() {
547
0
            return false;
548
0
        }
549
0
        // ensure no ordering is required on the input
550
0
        if self.required_input_ordering()[0].is_some() {
551
0
            return false;
552
0
        }
553
0
        true
554
0
    }
555
556
    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
557
50
    pub fn compute_properties(
558
50
        input: &Arc<dyn ExecutionPlan>,
559
50
        schema: SchemaRef,
560
50
        projection_mapping: &ProjectionMapping,
561
50
        mode: &AggregateMode,
562
50
        input_order_mode: &InputOrderMode,
563
50
    ) -> PlanProperties {
564
50
        // Construct equivalence properties:
565
50
        let eq_properties = input
566
50
            .equivalence_properties()
567
50
            .project(projection_mapping, schema);
568
50
569
50
        // Get output partitioning:
570
50
        let input_partitioning = input.output_partitioning().clone();
571
50
        let output_partitioning = if mode.is_first_stage() {
572
            // First stage aggregation will not change the output partitioning,
573
            // but needs to respect aliases (e.g. mapping in the GROUP BY
574
            // expression).
575
25
            let input_eq_properties = input.equivalence_properties();
576
25
            input_partitioning.project(projection_mapping, input_eq_properties)
577
        } else {
578
25
            input_partitioning.clone()
579
        };
580
581
        // Determine execution mode:
582
50
        let mut exec_mode = input.execution_mode();
583
50
        if exec_mode == ExecutionMode::Unbounded
584
0
            && *input_order_mode == InputOrderMode::Linear
585
0
        {
586
0
            // Cannot run without breaking the pipeline
587
0
            exec_mode = ExecutionMode::PipelineBreaking;
588
50
        }
589
590
50
        PlanProperties::new(eq_properties, output_partitioning, exec_mode)
591
50
    }
592
593
0
    pub fn input_order_mode(&self) -> &InputOrderMode {
594
0
        &self.input_order_mode
595
0
    }
596
}
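// A simplified sketch of the execution-mode rule in `compute_properties`
// above: an aggregate over an unbounded input can only stream results if the
// input order lets it emit groups incrementally; with `InputOrderMode::Linear`
// it must buffer all input first, which breaks the pipeline. The types below
// are illustrative stand-ins, not the real enums.
#[derive(Debug, PartialEq)]
enum Mode {
    Bounded,
    Unbounded,
    PipelineBreaking,
}

fn aggregate_exec_mode(input_mode: Mode, input_order_is_linear: bool) -> Mode {
    if input_mode == Mode::Unbounded && input_order_is_linear {
        // Cannot run without breaking the pipeline.
        Mode::PipelineBreaking
    } else {
        input_mode
    }
}

fn exec_mode_example() {
    assert_eq!(aggregate_exec_mode(Mode::Unbounded, true), Mode::PipelineBreaking);
    assert_eq!(aggregate_exec_mode(Mode::Bounded, true), Mode::Bounded);
}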
597
598
impl DisplayAs for AggregateExec {
599
0
    fn fmt_as(
600
0
        &self,
601
0
        t: DisplayFormatType,
602
0
        f: &mut std::fmt::Formatter,
603
0
    ) -> std::fmt::Result {
604
0
        match t {
605
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
606
0
                write!(f, "AggregateExec: mode={:?}", self.mode)?;
607
0
                let g: Vec<String> = if self.group_by.is_single() {
608
0
                    self.group_by
609
0
                        .expr
610
0
                        .iter()
611
0
                        .map(|(e, alias)| {
612
0
                            let e = e.to_string();
613
0
                            if &e != alias {
614
0
                                format!("{e} as {alias}")
615
                            } else {
616
0
                                e
617
                            }
618
0
                        })
619
0
                        .collect()
620
                } else {
621
0
                    self.group_by
622
0
                        .groups
623
0
                        .iter()
624
0
                        .map(|group| {
625
0
                            let terms = group
626
0
                                .iter()
627
0
                                .enumerate()
628
0
                                .map(|(idx, is_null)| {
629
0
                                    if *is_null {
630
0
                                        let (e, alias) = &self.group_by.null_expr[idx];
631
0
                                        let e = e.to_string();
632
0
                                        if &e != alias {
633
0
                                            format!("{e} as {alias}")
634
                                        } else {
635
0
                                            e
636
                                        }
637
                                    } else {
638
0
                                        let (e, alias) = &self.group_by.expr[idx];
639
0
                                        let e = e.to_string();
640
0
                                        if &e != alias {
641
0
                                            format!("{e} as {alias}")
642
                                        } else {
643
0
                                            e
644
                                        }
645
                                    }
646
0
                                })
647
0
                                .collect::<Vec<String>>()
648
0
                                .join(", ");
649
0
                            format!("({terms})")
650
0
                        })
651
0
                        .collect()
652
                };
653
654
0
                write!(f, ", gby=[{}]", g.join(", "))?;
655
656
0
                let a: Vec<String> = self
657
0
                    .aggr_expr
658
0
                    .iter()
659
0
                    .map(|agg| agg.name().to_string())
660
0
                    .collect();
661
0
                write!(f, ", aggr=[{}]", a.join(", "))?;
662
0
                if let Some(limit) = self.limit {
663
0
                    write!(f, ", lim=[{limit}]")?;
664
0
                }
665
666
0
                if self.input_order_mode != InputOrderMode::Linear {
667
0
                    write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
668
0
                }
669
            }
670
        }
671
0
        Ok(())
672
0
    }
673
}
674
675
impl ExecutionPlan for AggregateExec {
676
0
    fn name(&self) -> &'static str {
677
0
        "AggregateExec"
678
0
    }
679
680
    /// Return a reference to Any that can be used for down-casting
681
0
    fn as_any(&self) -> &dyn Any {
682
0
        self
683
0
    }
684
685
154
    fn properties(&self) -> &PlanProperties {
686
154
        &self.cache
687
154
    }
688
689
0
    fn required_input_distribution(&self) -> Vec<Distribution> {
690
0
        match &self.mode {
691
            AggregateMode::Partial => {
692
0
                vec![Distribution::UnspecifiedDistribution]
693
            }
694
            AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
695
0
                vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
696
            }
697
            AggregateMode::Final | AggregateMode::Single => {
698
0
                vec![Distribution::SinglePartition]
699
            }
700
        }
701
0
    }
702
703
0
    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
704
0
        vec![self.required_input_ordering.clone()]
705
0
    }
706
707
0
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
708
0
        vec![&self.input]
709
0
    }
710
711
1
    fn with_new_children(
712
1
        self: Arc<Self>,
713
1
        children: Vec<Arc<dyn ExecutionPlan>>,
714
1
    ) -> Result<Arc<dyn ExecutionPlan>> {
715
1
        let mut me = AggregateExec::try_new_with_schema(
716
1
            self.mode,
717
1
            self.group_by.clone(),
718
1
            self.aggr_expr.clone(),
719
1
            self.filter_expr.clone(),
720
1
            Arc::clone(&children[0]),
721
1
            Arc::clone(&self.input_schema),
722
1
            Arc::clone(&self.schema),
723
1
        )?;
724
1
        me.limit = self.limit;
725
1
726
1
        Ok(Arc::new(me))
727
1
    }
728
729
70
    fn execute(
730
70
        &self,
731
70
        partition: usize,
732
70
        context: Arc<TaskContext>,
733
70
    ) -> Result<SendableRecordBatchStream> {
734
70
        self.execute_typed(partition, context)
735
70
            .map(|stream| stream.into())
736
70
    }
737
738
8
    fn metrics(&self) -> Option<MetricsSet> {
739
8
        Some(self.metrics.clone_inner())
740
8
    }
741
742
8
    fn statistics(&self) -> Result<Statistics> {
743
8
        // TODO stats: group expressions:
744
8
        // - once expressions are able to compute their own stats, use them here
745
8
        // - handle the case where we group by a column for which we have the `distinct` stat
746
8
        // TODO stats: aggr expression:
747
8
        // - aggregations sometimes also preserve invariants such as min, max...
748
8
        let column_statistics = Statistics::unknown_column(&self.schema());
749
8
        match self.mode {
750
8
            AggregateMode::Final | AggregateMode::FinalPartitioned
751
8
                if self.group_by.expr.is_empty() =>
752
            {
753
0
                Ok(Statistics {
754
0
                    num_rows: Precision::Exact(1),
755
0
                    column_statistics,
756
0
                    total_byte_size: Precision::Absent,
757
0
                })
758
            }
759
            _ => {
760
                // When the input row count is 0 or 1, we can adopt that statistic and keep its precision.
761
                // When it is larger than 1, we degrade the precision since the row count may decrease after aggregation.
762
8
                let num_rows = if let Some(value) =
763
8
                    self.input().statistics()?.num_rows.get_value()
764
                {
765
8
                    if *value > 1 {
766
8
                        self.input().statistics()?.num_rows.to_inexact()
767
0
                    } else if *value == 0 {
768
                        // Aggregation on an empty table creates a null row.
769
0
                        self.input()
770
0
                            .statistics()?
771
                            .num_rows
772
0
                            .add(&Precision::Exact(1))
773
                    } else {
774
                        // num_rows = 1 case
775
0
                        self.input().statistics()?.num_rows
776
                    }
777
                } else {
778
0
                    Precision::Absent
779
                };
780
8
                Ok(Statistics {
781
8
                    num_rows,
782
8
                    column_statistics,
783
8
                    total_byte_size: Precision::Absent,
784
8
                })
785
            }
786
        }
787
8
    }
788
}
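// A sketch of the row-count precision rule in `statistics` above, with a
// simplified stand-in for `Precision` (illustrative, not the real type): row
// counts of 0 or 1 survive exactly (aggregating an empty input still yields
// one null row), while anything larger is degraded to inexact because
// aggregation may shrink it.
#[derive(Debug, PartialEq)]
enum RowCount {
    Exact(usize),
    Inexact(usize),
    Absent,
}

fn aggregate_row_count(input_rows: Option<usize>) -> RowCount {
    match input_rows {
        Some(n) if n > 1 => RowCount::Inexact(n), // may shrink after grouping
        Some(0) => RowCount::Exact(1),            // empty input -> one null row
        Some(n) => RowCount::Exact(n),            // n == 1 stays exact
        None => RowCount::Absent,
    }
}

fn row_count_examples() {
    assert_eq!(aggregate_row_count(Some(100)), RowCount::Inexact(100));
    assert_eq!(aggregate_row_count(Some(0)), RowCount::Exact(1));
}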
789
790
50
fn create_schema(
791
50
    input_schema: &Schema,
792
50
    group_expr: &[(Arc<dyn PhysicalExpr>, String)],
793
50
    aggr_expr: &[AggregateFunctionExpr],
794
50
    group_expr_nullable: Vec<bool>,
795
50
    mode: AggregateMode,
796
50
) -> Result<Schema> {
797
50
    let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len());
798
59
    for (index, (expr, name)) in group_expr.iter().enumerate() {
799
59
        fields.push(
800
59
            Field::new(
801
59
                name,
802
59
                expr.data_type(input_schema)?,
803
                // In cases where we have multiple grouping sets, we will use NULL expressions in
804
                // order to align the grouping sets. So the field must be nullable even if the underlying
805
                // schema field is not.
806
59
                group_expr_nullable[index] || expr.nullable(input_schema)?,
807
            )
808
59
            .with_metadata(get_field_metadata(expr, input_schema).unwrap_or_default()),
809
        )
810
    }
811
812
50
    match mode {
813
        AggregateMode::Partial => {
814
            // in partial mode, the fields of the accumulator's state
815
49
            for expr in aggr_expr {
816
25
                fields.extend(expr.state_fields()?.iter().cloned())
817
            }
818
        }
819
        AggregateMode::Final
820
        | AggregateMode::FinalPartitioned
821
        | AggregateMode::Single
822
        | AggregateMode::SinglePartitioned => {
823
            // in final mode, the field with the final result of the accumulator
824
44
            for expr in aggr_expr {
825
18
                fields.push(expr.field())
826
            }
827
        }
828
    }
829
830
50
    Ok(Schema::new_with_metadata(
831
50
        fields,
832
50
        input_schema.metadata().clone(),
833
50
    ))
834
50
}
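// A sketch of the output-schema layout `create_schema` produces: the group-by
// columns come first, then one run of columns per aggregate -- the
// accumulator's state fields in Partial mode, or a single result field in the
// final/single modes. The field names below (e.g. "AVG(b)[count]") follow the
// convention visible in the test expectations later in this file, but the
// helper itself is hypothetical.
fn schema_field_names(
    group_cols: &[&str],
    aggs: &[(&str, Vec<&str>)], // (aggregate name, its state field names)
    partial: bool,
) -> Vec<String> {
    let mut fields: Vec<String> = group_cols.iter().map(|s| s.to_string()).collect();
    for (name, state_fields) in aggs {
        if partial {
            fields.extend(state_fields.iter().map(|s| s.to_string()));
        } else {
            fields.push(name.to_string());
        }
    }
    fields
}

fn schema_example() {
    let aggs = vec![("AVG(b)", vec!["AVG(b)[count]", "AVG(b)[sum]"])];
    assert_eq!(
        schema_field_names(&["a"], &aggs, true),
        vec!["a", "AVG(b)[count]", "AVG(b)[sum]"]
    );
    assert_eq!(schema_field_names(&["a"], &aggs, false), vec!["a", "AVG(b)"]);
}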
835
836
70
fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
837
70
    let group_fields = schema.fields()[0..group_count].to_vec();
838
70
    Arc::new(Schema::new(group_fields))
839
70
}
840
841
/// Determines the lexical ordering requirement for an aggregate expression.
842
///
843
/// # Parameters
844
///
845
/// - `aggr_expr`: A reference to an `AggregateFunctionExpr` representing the
846
///   aggregate expression.
847
/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
848
///   physical GROUP BY expression.
849
/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
850
///   mode of aggregation.
851
///
852
/// # Returns
853
///
854
/// A `LexOrdering` instance indicating the lexical ordering requirement for
855
/// the aggregate expression.
856
54
fn get_aggregate_expr_req(
857
54
    aggr_expr: &AggregateFunctionExpr,
858
54
    group_by: &PhysicalGroupBy,
859
54
    agg_mode: &AggregateMode,
860
54
) -> LexOrdering {
861
54
    // If the aggregate function's ordering requirement is not absolutely
862
54
    // necessary, or the aggregation is performing a "second stage" calculation,
863
54
    // then ignore the ordering requirement.
864
54
    if !aggr_expr.order_sensitivity().hard_requires() || !agg_mode.is_first_stage() {
865
45
        return vec![];
866
9
    }
867
9
868
9
    let mut req = aggr_expr.order_bys().unwrap_or_default().to_vec();
869
9
870
9
    // In non-first stage modes, we accumulate data (using `merge_batch`) from
871
9
    // different partitions (i.e. merge partial results). During this merge, we
872
9
    // consider the ordering of each partial result. Hence, we do not need to
873
9
    // use the ordering requirement in such modes as long as partial results are
874
9
    // generated with the correct ordering.
875
9
    if group_by.is_single() {
876
9
        // Remove all orderings that occur in the group by. These requirements
877
9
        // will definitely be satisfied -- Each group by expression will have
878
9
        // distinct values per group, hence all requirements are satisfied.
879
9
        let physical_exprs = group_by.input_exprs();
880
18
        req.retain(|sort_expr| {
881
18
            !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
882
18
        });
883
9
    }
884
9
    req
885
54
}
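// A sketch of the pruning step in `get_aggregate_expr_req` above: a required
// sort column that is itself a GROUP BY key can be dropped, because it is
// constant within each group. Columns are modeled as plain &strs; the helper
// is hypothetical.
fn prune_group_by_cols<'a>(required: &[&'a str], group_by: &[&str]) -> Vec<&'a str> {
    required
        .iter()
        .copied()
        .filter(|col| !group_by.contains(col))
        .collect()
}

fn prune_example() {
    // e.g. an aggregate with ORDER BY a, b under GROUP BY a still needs `b`.
    assert_eq!(prune_group_by_cols(&["a", "b"], &["a"]), vec!["b"]);
}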
886
887
/// Computes the finer ordering between the given existing ordering requirement
888
/// and that of the aggregate expression.
889
///
890
/// # Parameters
891
///
892
/// * `existing_req` - The existing lexical ordering that needs refinement.
893
/// * `aggr_expr` - A reference to an aggregate expression trait object.
894
/// * `group_by` - Information about the physical grouping (e.g group by expression).
895
/// * `eq_properties` - Equivalence properties relevant to the computation.
896
/// * `agg_mode` - The mode of aggregation (e.g., Partial, Final, etc.).
897
///
898
/// # Returns
899
///
900
/// An `Option<LexOrdering>` representing the computed finer lexical ordering,
901
/// or `None` if there is no finer ordering; e.g. the existing requirement and
902
/// the aggregator requirement are incompatible.
903
54
fn finer_ordering(
904
54
    existing_req: &LexOrdering,
905
54
    aggr_expr: &AggregateFunctionExpr,
906
54
    group_by: &PhysicalGroupBy,
907
54
    eq_properties: &EquivalenceProperties,
908
54
    agg_mode: &AggregateMode,
909
54
) -> Option<LexOrdering> {
910
54
    let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode);
911
54
    eq_properties.get_finer_ordering(existing_req, &aggr_req)
912
54
}
913
914
/// Concatenates the given slices.
915
0
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
916
0
    [lhs, rhs].concat()
917
0
}
918
919
/// Get the common requirement that satisfies all the aggregate expressions.
920
///
921
/// # Parameters
922
///
923
/// - `aggr_exprs`: A slice of `AggregateFunctionExpr` containing all the
924
///   aggregate expressions.
925
/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
926
///   physical GROUP BY expression.
927
/// - `eq_properties`: A reference to an `EquivalenceProperties` instance
928
///   representing equivalence properties for ordering.
929
/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
930
///   mode of aggregation.
931
///
932
/// # Returns
933
///
934
/// A `LexRequirement` instance, which is the requirement that satisfies all the
935
/// aggregate requirements. Returns an error in case of conflicting requirements.
936
51
pub fn get_finer_aggregate_exprs_requirement(
937
51
    aggr_exprs: &mut [AggregateFunctionExpr],
938
51
    group_by: &PhysicalGroupBy,
939
51
    eq_properties: &EquivalenceProperties,
940
51
    agg_mode: &AggregateMode,
941
51
) -> Result<LexRequirement> {
942
51
    let mut requirement = vec![];
943
51
    for aggr_expr in aggr_exprs.iter_mut() {
944
48
        if let Some(finer_ordering) =
945
48
            finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode)
946
        {
947
48
            if eq_properties.ordering_satisfy(&finer_ordering) {
948
                // Requirement is satisfied by existing ordering
949
45
                requirement = finer_ordering;
950
45
                continue;
951
3
            }
952
0
        }
953
3
        if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
954
3
        if let Some(finer_ordering) = finer_ordering(
955
3
                &requirement,
956
3
                &reverse_aggr_expr,
957
3
                group_by,
958
3
                eq_properties,
959
3
                agg_mode,
960
3
            ) {
961
1
                if eq_properties.ordering_satisfy(&finer_ordering) {
962
                    // Reverse requirement is satisfied by existing ordering.
963
                    // Hence reverse the aggregator
964
0
                    requirement = finer_ordering;
965
0
                    *aggr_expr = reverse_aggr_expr;
966
0
                    continue;
967
1
                }
968
2
            }
969
0
        }
970
3
        if let Some(finer_ordering) =
971
3
            finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode)
972
        {
973
            // There is a requirement that both satisfies existing requirement and current
974
            // aggregate requirement. Use updated requirement
975
3
            requirement = finer_ordering;
976
3
            continue;
977
0
        }
978
0
        if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
979
0
            if let Some(finer_ordering) = finer_ordering(
980
0
                &requirement,
981
0
                &reverse_aggr_expr,
982
0
                group_by,
983
0
                eq_properties,
984
0
                agg_mode,
985
0
            ) {
986
                // There is a requirement that both satisfies existing requirement and reverse
987
                // aggregate requirement. Use updated requirement
988
0
                requirement = finer_ordering;
989
0
                *aggr_expr = reverse_aggr_expr;
990
0
                continue;
991
0
            }
992
0
        }
993
994
        // Neither the existing requirement nor the current aggregate requirement satisfies the other; this means
995
        // the requirements are conflicting. Currently, we do not support
996
        // conflicting requirements.
997
0
        return not_impl_err!(
998
0
            "Conflicting ordering requirements in aggregate functions is not supported"
999
0
        );
1000
    }
1001
1002
51
    Ok(PhysicalSortRequirement::from_sort_exprs(&requirement))
1003
51
}
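// A sketch of what "finer" means in the loop above, in a simplified model
// (not the EquivalenceProperties API): two lexicographic requirements are
// compatible when one is a prefix of the other, and the finer one is the
// longer of the two; otherwise they conflict, which corresponds to the
// `not_impl_err!` branch.
fn finer<'a>(lhs: &[&'a str], rhs: &[&'a str]) -> Option<Vec<&'a str>> {
    let (short, long) = if lhs.len() <= rhs.len() { (lhs, rhs) } else { (rhs, lhs) };
    if long.starts_with(short) {
        Some(long.to_vec())
    } else {
        None
    }
}

fn finer_example() {
    assert_eq!(finer(&["a"], &["a", "b"]), Some(vec!["a", "b"]));
    assert_eq!(finer(&["a"], &["b"]), None); // conflicting requirements
}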
1004
1005
/// Returns physical expressions for arguments to evaluate against a batch.
1006
///
1007
/// The expressions are different depending on `mode`:
1008
/// * Partial: AggregateFunctionExpr::expressions
1009
/// * Final: columns of `AggregateFunctionExpr::state_fields()`
1010
142
pub fn aggregate_expressions(
1011
142
    aggr_expr: &[AggregateFunctionExpr],
1012
142
    mode: &AggregateMode,
1013
142
    col_idx_base: usize,
1014
142
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
1015
142
    match mode {
1016
        AggregateMode::Partial
1017
        | AggregateMode::Single
1018
55
        | AggregateMode::SinglePartitioned => Ok(aggr_expr
1019
55
            .iter()
1020
55
            .map(|agg| {
1021
55
                let mut result = agg.expressions();
1022
                // Append ordering requirements to expressions' results. This
1023
                // way order sensitive aggregators can satisfy requirement
1024
                // themselves.
1025
55
                if let Some(ordering_req) = agg.order_bys() {
1026
32
                    result.extend(ordering_req.iter().map(|item| Arc::clone(&item.expr)));
1027
32
                }
23
1028
55
                result
1029
55
            })
1030
55
            .collect()),
1031
        // In this mode, we build the merge expressions of the aggregation.
1032
        AggregateMode::Final | AggregateMode::FinalPartitioned => {
1033
87
            let mut col_idx_base = col_idx_base;
1034
87
            aggr_expr
1035
87
                .iter()
1036
87
                .map(|agg| {
1037
87
                    let exprs = merge_expressions(col_idx_base, agg)?;
1038
87
                    col_idx_base += exprs.len();
1039
87
                    Ok(exprs)
1040
87
                })
1041
87
                .collect()
1042
        }
1043
    }
1044
142
}
1045
1046
/// Uses `state_fields` to build a vec of physical column expressions required to merge the
1047
/// `AggregateFunctionExpr`'s accumulator state.
1048
///
1049
/// `index_base` is the starting physical column index for the next expanded state field.
1050
87
fn merge_expressions(
1051
87
    index_base: usize,
1052
87
    expr: &AggregateFunctionExpr,
1053
87
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
1054
87
    expr.state_fields().map(|fields| {
1055
87
        fields
1056
87
            .iter()
1057
87
            .enumerate()
1058
201
            .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _)
1059
87
            .collect()
1060
87
    })
1061
87
}
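// A sketch of the `col_idx_base` bookkeeping shared by `aggregate_expressions`
// and `merge_expressions` in the final modes: each aggregate's state fields
// occupy a contiguous run of input columns, so the base advances by each
// aggregate's state width. The helper is hypothetical.
fn state_column_ranges(state_widths: &[usize], col_idx_base: usize) -> Vec<std::ops::Range<usize>> {
    let mut base = col_idx_base;
    state_widths
        .iter()
        .copied()
        .map(|width| {
            let range = base..base + width;
            base += width;
            range
        })
        .collect()
}

fn merge_layout_example() {
    // One group column at index 0, then AVG (2 state columns) and COUNT (1):
    // AVG reads columns 1..3 and COUNT reads column 3.
    assert_eq!(state_column_ranges(&[2, 1], 1), vec![1..3, 3..4]);
}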
1062
1063
pub type AccumulatorItem = Box<dyn Accumulator>;
1064
1065
2
pub fn create_accumulators(
1066
2
    aggr_expr: &[AggregateFunctionExpr],
1067
2
) -> Result<Vec<AccumulatorItem>> {
1068
2
    aggr_expr
1069
2
        .iter()
1070
2
        .map(|expr| expr.create_accumulator())
1071
2
        .collect()
1072
2
}
1073
1074
/// Returns a vector of ArrayRefs, where each entry corresponds to either the
1075
/// final value (mode = Final, FinalPartitioned, Single or SinglePartitioned) or states (mode = Partial)
1076
0
pub fn finalize_aggregation(
1077
0
    accumulators: &mut [AccumulatorItem],
1078
0
    mode: &AggregateMode,
1079
0
) -> Result<Vec<ArrayRef>> {
1080
0
    match mode {
1081
        AggregateMode::Partial => {
1082
            // Build the vector of states
1083
0
            accumulators
1084
0
                .iter_mut()
1085
0
                .map(|accumulator| {
1086
0
                    accumulator.state().and_then(|e| {
1087
0
                        e.iter()
1088
0
                            .map(|v| v.to_array())
1089
0
                            .collect::<Result<Vec<ArrayRef>>>()
1090
0
                    })
1091
0
                })
1092
0
                .flatten_ok()
1093
0
                .collect()
1094
        }
1095
        AggregateMode::Final
1096
        | AggregateMode::FinalPartitioned
1097
        | AggregateMode::Single
1098
        | AggregateMode::SinglePartitioned => {
1099
            // Merge the state to the final value
1100
0
            accumulators
1101
0
                .iter_mut()
1102
0
                .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
1103
0
                .collect()
1104
        }
1105
    }
1106
0
}
1107
1108
/// Evaluates expressions against a record batch.
1109
131
fn evaluate(
1110
131
    expr: &[Arc<dyn PhysicalExpr>],
1111
131
    batch: &RecordBatch,
1112
131
) -> Result<Vec<ArrayRef>> {
1113
131
    expr.iter()
1114
245
        .map(|expr| {
1115
245
            expr.evaluate(batch)
1116
245
                .and_then(|v| v.into_array(batch.num_rows()))
1117
245
        })
1118
131
        .collect()
1119
131
}
1120
1121
/// Evaluates each set of expressions against a record batch.
1122
131
pub(crate) fn evaluate_many(
1123
131
    expr: &[Vec<Arc<dyn PhysicalExpr>>],
1124
131
    batch: &RecordBatch,
1125
131
) -> Result<Vec<Vec<ArrayRef>>> {
1126
131
    expr.iter().map(|expr| evaluate(expr, batch)).collect()
1127
131
}
1128
1129
131
fn evaluate_optional(
1130
131
    expr: &[Option<Arc<dyn PhysicalExpr>>],
1131
131
    batch: &RecordBatch,
1132
131
) -> Result<Vec<Option<ArrayRef>>> {
1133
131
    expr.iter()
1134
131
        .map(|expr| {
1135
131
            expr.as_ref()
1136
131
                .map(|expr| {
1137
0
                    expr.evaluate(batch)
1138
0
                        .and_then(|v| v.into_array(batch.num_rows()))
1139
131
                })
1140
131
                .transpose()
1141
131
        })
1142
131
        .collect()
1143
131
}
1144
1145
/// Evaluate a group by expression against a `RecordBatch`
1146
///
1147
/// Arguments:
1148
/// - `group_by`: the expression to evaluate
1149
/// - `batch`: the `RecordBatch` to evaluate against
1150
///
1151
/// Returns: A Vec of Vecs of Arrays of results
1152
/// The outer Vec contains one entry per grouping set
1153
/// The inner Vec contains the results per expression
1154
/// Each innermost Array contains the results per row
1155
131
pub(crate) fn evaluate_group_by(
1156
131
    group_by: &PhysicalGroupBy,
1157
131
    batch: &RecordBatch,
1158
131
) -> Result<Vec<Vec<ArrayRef>>> {
1159
131
    let exprs: Vec<ArrayRef> = group_by
1160
131
        .expr
1161
131
        .iter()
1162
163
        .map(|(expr, _)| {
1163
163
            let value = expr.evaluate(batch)?;
1164
163
            value.into_array(batch.num_rows())
1165
163
        })
1166
131
        .collect::<Result<Vec<_>>>()?;
1167
1168
131
    let null_exprs: Vec<ArrayRef> = group_by
1169
131
        .null_expr
1170
131
        .iter()
1171
131
        .map(|(expr, _)| {
1172
44
            let value = expr.evaluate(batch)?;
1173
44
            value.into_array(batch.num_rows())
1174
131
        })
1175
131
        .collect::<Result<Vec<_>>>()?;
1176
1177
131
    Ok(group_by
1178
131
        .groups
1179
131
        .iter()
1180
171
        .map(|group| {
1181
171
            group
1182
171
                .iter()
1183
171
                .enumerate()
1184
251
                .map(|(idx, is_null)| {
1185
251
                    if *is_null {
1186
56
                        Arc::clone(&null_exprs[idx])
1187
                    } else {
1188
195
                        Arc::clone(&exprs[idx])
1189
                    }
1190
251
                })
1191
171
                .collect()
1192
171
        })
1193
131
        .collect())
1194
131
}
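// A standalone sketch of the grouping-set expansion performed by
// `evaluate_group_by` above: for each group, position `idx` selects the
// evaluated expression column when the mask entry is false and the all-NULL
// column when it is true. Columns are modeled as Vec<Option<i32>> instead of
// ArrayRef; the helper is hypothetical.
fn expand_grouping_sets(
    exprs: &[Vec<Option<i32>>],
    null_exprs: &[Vec<Option<i32>>],
    groups: &[Vec<bool>],
) -> Vec<Vec<Vec<Option<i32>>>> {
    groups
        .iter()
        .map(|group| {
            group
                .iter()
                .enumerate()
                .map(|(idx, is_null)| {
                    if *is_null {
                        null_exprs[idx].clone()
                    } else {
                        exprs[idx].clone()
                    }
                })
                .collect()
        })
        .collect()
}

fn grouping_set_example() {
    let a = vec![Some(1), Some(2)];
    let nulls = vec![None, None];
    // One grouping set keeping `a`, one replacing it with NULLs.
    let out = expand_grouping_sets(&[a.clone()], &[nulls.clone()], &[vec![false], vec![true]]);
    assert_eq!(out, vec![vec![a], vec![nulls]]);
}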
1195
1196
#[cfg(test)]
1197
mod tests {
1198
    use std::task::{Context, Poll};
1199
1200
    use super::*;
1201
    use crate::coalesce_batches::CoalesceBatchesExec;
1202
    use crate::coalesce_partitions::CoalescePartitionsExec;
1203
    use crate::common;
1204
    use crate::expressions::col;
1205
    use crate::memory::MemoryExec;
1206
    use crate::test::assert_is_pending;
1207
    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
1208
    use crate::RecordBatchStream;
1209
1210
    use arrow::array::{Float64Array, UInt32Array};
1211
    use arrow::compute::{concat_batches, SortOptions};
1212
    use arrow::datatypes::{DataType, Int32Type};
1213
    use arrow_array::{
1214
        DictionaryArray, Float32Array, Int32Array, StructArray, UInt64Array,
1215
    };
1216
    use datafusion_common::{
1217
        assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
1218
        ScalarValue,
1219
    };
1220
    use datafusion_execution::config::SessionConfig;
1221
    use datafusion_execution::memory_pool::FairSpillPool;
1222
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1223
    use datafusion_functions_aggregate::array_agg::array_agg_udaf;
1224
    use datafusion_functions_aggregate::average::avg_udaf;
1225
    use datafusion_functions_aggregate::count::count_udaf;
1226
    use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
1227
    use datafusion_functions_aggregate::median::median_udaf;
1228
    use datafusion_functions_aggregate::sum::sum_udaf;
1229
    use datafusion_physical_expr::expressions::lit;
1230
    use datafusion_physical_expr::PhysicalSortExpr;
1231
1232
    use crate::common::collect;
1233
    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
1234
    use datafusion_physical_expr::expressions::Literal;
1235
    use datafusion_physical_expr::Partitioning;
1236
    use futures::{FutureExt, Stream};
1237
1238
    // Generate a schema which consists of 5 columns (a, b, c, d, e)
1239
1
    fn create_test_schema() -> Result<SchemaRef> {
1240
1
        let a = Field::new("a", DataType::Int32, true);
1241
1
        let b = Field::new("b", DataType::Int32, true);
1242
1
        let c = Field::new("c", DataType::Int32, true);
1243
1
        let d = Field::new("d", DataType::Int32, true);
1244
1
        let e = Field::new("e", DataType::Int32, true);
1245
1
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));
1246
1
1247
1
        Ok(schema)
1248
1
    }
1249
1250
    /// Some mock data to aggregate
1251
43
    fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
1252
43
        // define a schema.
1253
43
        let schema = Arc::new(Schema::new(vec![
1254
43
            Field::new("a", DataType::UInt32, false),
1255
43
            Field::new("b", DataType::Float64, false),
1256
43
        ]));
1257
43
1258
43
        // define data.
1259
43
        (
1260
43
            Arc::clone(&schema),
1261
43
            vec![
1262
43
                RecordBatch::try_new(
1263
43
                    Arc::clone(&schema),
1264
43
                    vec![
1265
43
                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
1266
43
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1267
43
                    ],
1268
43
                )
1269
43
                .unwrap(),
1270
43
                RecordBatch::try_new(
1271
43
                    schema,
1272
43
                    vec![
1273
43
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1274
43
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1275
43
                    ],
1276
43
                )
1277
43
                .unwrap(),
1278
43
            ],
1279
43
        )
1280
43
    }
1281
1282
    /// Generates some mock data for aggregate tests.
1283
8
    fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
1284
8
        // Define a schema:
1285
8
        let schema = Arc::new(Schema::new(vec![
1286
8
            Field::new("a", DataType::UInt32, false),
1287
8
            Field::new("b", DataType::Float64, false),
1288
8
        ]));
1289
8
1290
8
        // Generate data so that first and last value results are at 2nd and
1291
8
        // 3rd partitions.  With this construction, we guarantee we don't receive
1292
8
        // the expected result by accident, but merging actually works properly;
1293
8
        // i.e. it doesn't depend on the data insertion order.
1294
8
        (
1295
8
            Arc::clone(&schema),
1296
8
            vec![
1297
8
                RecordBatch::try_new(
1298
8
                    Arc::clone(&schema),
1299
8
                    vec![
1300
8
                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
1301
8
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1302
8
                    ],
1303
8
                )
1304
8
                .unwrap(),
1305
8
                RecordBatch::try_new(
1306
8
                    Arc::clone(&schema),
1307
8
                    vec![
1308
8
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1309
8
                        Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
1310
8
                    ],
1311
8
                )
1312
8
                .unwrap(),
1313
8
                RecordBatch::try_new(
1314
8
                    Arc::clone(&schema),
1315
8
                    vec![
1316
8
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1317
8
                        Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
1318
8
                    ],
1319
8
                )
1320
8
                .unwrap(),
1321
8
                RecordBatch::try_new(
1322
8
                    schema,
1323
8
                    vec![
1324
8
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1325
8
                        Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
1326
8
                    ],
1327
8
                )
1328
8
                .unwrap(),
1329
8
            ],
1330
8
        )
1331
8
    }
1332
1333
12
    fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
1334
12
        let session_config = SessionConfig::new().with_batch_size(batch_size);
1335
12
        let runtime = RuntimeEnvBuilder::default()
1336
12
            .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
1337
12
            .build_arc()
1338
12
            .unwrap();
1339
12
        let task_ctx = TaskContext::default()
1340
12
            .with_session_config(session_config)
1341
12
            .with_runtime(runtime);
1342
12
        Arc::new(task_ctx)
1343
12
    }
1344
1345
4
    async fn check_grouping_sets(
1346
4
        input: Arc<dyn ExecutionPlan>,
1347
4
        spill: bool,
1348
4
    ) -> Result<()> {
1349
4
        let input_schema = input.schema();
1350
1351
4
        let grouping_set = PhysicalGroupBy {
1352
4
            expr: vec![
1353
4
                (col("a", &input_schema)
?0
, "a".to_string()),
1354
4
                (col("b", &input_schema)
?0
, "b".to_string()),
1355
4
            ],
1356
4
            null_expr: vec![
1357
4
                (lit(ScalarValue::UInt32(None)), "a".to_string()),
1358
4
                (lit(ScalarValue::Float64(None)), "b".to_string()),
1359
4
            ],
1360
4
            groups: vec![
1361
4
                vec![false, true],  // (a, NULL)
1362
4
                vec![true, false],  // (NULL, b)
1363
4
                vec![false, false], // (a,b)
1364
4
            ],
1365
        };
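        // In each `groups` bitmap, `false` keeps the grouping expression and `true`
        // substitutes the matching `null_expr`, so the three bitmaps above encode
        // GROUPING SETS ((a), (b), (a, b)).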
1366
1367
4
        let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
1368
4
            .schema(Arc::clone(&input_schema))
1369
4
            .alias("COUNT(1)")
1370
4
            .build()?];
1371
1372
4
        let task_ctx = if spill {
1373
            // limit the max memory size so that spill mode emits partial aggregate results.
1374
2
            new_spill_ctx(4, 500)
1375
        } else {
1376
2
            Arc::new(TaskContext::default())
1377
        };
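        // Build the classic two-phase plan: a Partial aggregate over the raw input,
        // whose state is later merged by a Final aggregate behind a CoalescePartitionsExec.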
1378
1379
4
        let partial_aggregate = Arc::new(AggregateExec::try_new(
1380
4
            AggregateMode::Partial,
1381
4
            grouping_set.clone(),
1382
4
            aggregates.clone(),
1383
4
            vec![None],
1384
4
            input,
1385
4
            Arc::clone(&input_schema),
1386
4
        )?);
1387
1388
4
        let result =
1389
4
            common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
1390
1391
4
        let expected = if spill {
1392
            // In spill mode we test with limited memory; once memory usage exceeds
1393
            // the limit, the early-emit rule fires and partial aggregate results are emitted.
1394
2
            vec![
1395
2
                "+---+-----+-----------------+",
1396
2
                "| a | b   | COUNT(1)[count] |",
1397
2
                "+---+-----+-----------------+",
1398
2
                "|   | 1.0 | 1               |",
1399
2
                "|   | 1.0 | 1               |",
1400
2
                "|   | 2.0 | 1               |",
1401
2
                "|   | 2.0 | 1               |",
1402
2
                "|   | 3.0 | 1               |",
1403
2
                "|   | 3.0 | 1               |",
1404
2
                "|   | 4.0 | 1               |",
1405
2
                "|   | 4.0 | 1               |",
1406
2
                "| 2 |     | 1               |",
1407
2
                "| 2 |     | 1               |",
1408
2
                "| 2 | 1.0 | 1               |",
1409
2
                "| 2 | 1.0 | 1               |",
1410
2
                "| 3 |     | 1               |",
1411
2
                "| 3 |     | 2               |",
1412
2
                "| 3 | 2.0 | 2               |",
1413
2
                "| 3 | 3.0 | 1               |",
1414
2
                "| 4 |     | 1               |",
1415
2
                "| 4 |     | 2               |",
1416
2
                "| 4 | 3.0 | 1               |",
1417
2
                "| 4 | 4.0 | 2               |",
1418
2
                "+---+-----+-----------------+",
1419
2
            ]
1420
        } else {
1421
2
            vec![
1422
2
                "+---+-----+-----------------+",
1423
2
                "| a | b   | COUNT(1)[count] |",
1424
2
                "+---+-----+-----------------+",
1425
2
                "|   | 1.0 | 2               |",
1426
2
                "|   | 2.0 | 2               |",
1427
2
                "|   | 3.0 | 2               |",
1428
2
                "|   | 4.0 | 2               |",
1429
2
                "| 2 |     | 2               |",
1430
2
                "| 2 | 1.0 | 2               |",
1431
2
                "| 3 |     | 3               |",
1432
2
                "| 3 | 2.0 | 2               |",
1433
2
                "| 3 | 3.0 | 1               |",
1434
2
                "| 4 |     | 3               |",
1435
2
                "| 4 | 3.0 | 1               |",
1436
2
                "| 4 | 4.0 | 2               |",
1437
2
                "+---+-----+-----------------+",
1438
2
            ]
1439
        };
1440
4
        assert_batches_sorted_eq!(expected, &result);
1441
1442
4
        let groups = partial_aggregate.group_expr().expr().to_vec();
1443
4
1444
4
        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
1445
1446
4
        let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = groups
1447
4
            .iter()
1448
8
            .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
1449
4
            .collect::<Result<_>>()?;
1450
1451
4
        let final_grouping_set = PhysicalGroupBy::new_single(final_group);
1452
1453
4
        let task_ctx = if spill {
1454
2
            new_spill_ctx(4, 3160)
1455
        } else {
1456
2
            task_ctx
1457
        };
1458
1459
4
        let merged_aggregate = Arc::new(AggregateExec::try_new(
1460
4
            AggregateMode::Final,
1461
4
            final_grouping_set,
1462
4
            aggregates,
1463
4
            vec![None],
1464
4
            merge,
1465
4
            input_schema,
1466
4
        )?);
1467
1468
4
        let result =
1469
4
            common::collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
1470
4
        let batch = concat_batches(&result[0].schema(), &result)?;
1471
4
        assert_eq!(batch.num_columns(), 3);
1472
4
        assert_eq!(batch.num_rows(), 12);
1473
1474
4
        let expected = vec![
1475
4
            "+---+-----+----------+",
1476
4
            "| a | b   | COUNT(1) |",
1477
4
            "+---+-----+----------+",
1478
4
            "|   | 1.0 | 2        |",
1479
4
            "|   | 2.0 | 2        |",
1480
4
            "|   | 3.0 | 2        |",
1481
4
            "|   | 4.0 | 2        |",
1482
4
            "| 2 |     | 2        |",
1483
4
            "| 2 | 1.0 | 2        |",
1484
4
            "| 3 |     | 3        |",
1485
4
            "| 3 | 2.0 | 2        |",
1486
4
            "| 3 | 3.0 | 1        |",
1487
4
            "| 4 |     | 3        |",
1488
4
            "| 4 | 3.0 | 1        |",
1489
4
            "| 4 | 4.0 | 2        |",
1490
4
            "+---+-----+----------+",
1491
4
        ];
1492
4
1493
4
        assert_batches_sorted_eq!(&expected, &result);
1494
1495
4
        let metrics = merged_aggregate.metrics().unwrap();
1496
4
        let output_rows = metrics.output_rows().unwrap();
1497
4
        assert_eq!(12, output_rows);
1498
1499
4
        Ok(())
1500
4
    }
1501
1502
    /// Build the aggregates on the data from some_data() and check the results.
1503
4
    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
1504
4
        let input_schema = input.schema();
1505
1506
4
        let grouping_set = PhysicalGroupBy {
1507
4
            expr: vec![(col("a", &input_schema)
?0
, "a".to_string())],
1508
4
            null_expr: vec![],
1509
4
            groups: vec![vec![false]],
1510
        };
1511
1512
4
        let aggregates: Vec<AggregateFunctionExpr> =
1513
4
            vec![
1514
4
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
1515
4
                    .schema(Arc::clone(&input_schema))
1516
4
                    .alias("AVG(b)")
1517
4
                    .build()?,
1518
            ];
1519
1520
4
        let task_ctx = if spill {
1521
            // set to an appropriate value to trigger spill
1522
2
            new_spill_ctx(2, 1600)
1523
        } else {
1524
2
            Arc::new(TaskContext::default())
1525
        };
1526
1527
4
        let partial_aggregate = Arc::new(AggregateExec::try_new(
1528
4
            AggregateMode::Partial,
1529
4
            grouping_set.clone(),
1530
4
            aggregates.clone(),
1531
4
            vec![None],
1532
4
            input,
1533
4
            Arc::clone(&input_schema),
1534
4
        )?);
1535
1536
4
        let result =
1537
4
            common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
1538
1539
4
        let expected = if spill {
1540
2
            vec![
1541
2
                "+---+---------------+-------------+",
1542
2
                "| a | AVG(b)[count] | AVG(b)[sum] |",
1543
2
                "+---+---------------+-------------+",
1544
2
                "| 2 | 1             | 1.0         |",
1545
2
                "| 2 | 1             | 1.0         |",
1546
2
                "| 3 | 1             | 2.0         |",
1547
2
                "| 3 | 2             | 5.0         |",
1548
2
                "| 4 | 3             | 11.0        |",
1549
2
                "+---+---------------+-------------+",
1550
2
            ]
1551
        } else {
1552
2
            vec![
1553
2
                "+---+---------------+-------------+",
1554
2
                "| a | AVG(b)[count] | AVG(b)[sum] |",
1555
2
                "+---+---------------+-------------+",
1556
2
                "| 2 | 2             | 2.0         |",
1557
2
                "| 3 | 3             | 7.0         |",
1558
2
                "| 4 | 3             | 11.0        |",
1559
2
                "+---+---------------+-------------+",
1560
2
            ]
1561
        };
1562
4
        assert_batches_sorted_eq!(expected, &result);
1563
1564
4
        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
1565
1566
4
        let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = grouping_set
1567
4
            .expr
1568
4
            .iter()
1569
4
            .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
1570
4
            .collect::<Result<_>>()?;
1571
1572
4
        let final_grouping_set = PhysicalGroupBy::new_single(final_group);
1573
1574
4
        let merged_aggregate = Arc::new(AggregateExec::try_new(
1575
4
            AggregateMode::Final,
1576
4
            final_grouping_set,
1577
4
            aggregates,
1578
4
            vec![None],
1579
4
            merge,
1580
4
            input_schema,
1581
4
        )?);
1582
1583
4
        let task_ctx = if spill {
1584
            // enlarge memory limit to let the final aggregation finish
1585
2
            new_spill_ctx(2, 2600)
1586
        } else {
1587
2
            Arc::clone(&task_ctx)
1588
        };
1589
4
        let result = common::collect(merged_aggregate.execute(0, task_ctx)?).await?;
1590
4
        let batch = concat_batches(&result[0].schema(), &result)?;
1591
4
        assert_eq!(batch.num_columns(), 2);
1592
4
        assert_eq!(batch.num_rows(), 3);
1593
1594
4
        let expected = vec![
1595
4
            "+---+--------------------+",
1596
4
            "| a | AVG(b)             |",
1597
4
            "+---+--------------------+",
1598
4
            "| 2 | 1.0                |",
1599
4
            "| 3 | 2.3333333333333335 |", // 3, (2 + 3 + 2) / 3
1600
4
            "| 4 | 3.6666666666666665 |", // 4, (3 + 4 + 4) / 3
1601
4
            "+---+--------------------+",
1602
4
        ];
1603
4
1604
4
        assert_batches_sorted_eq!(&expected, &result);
1605
1606
4
        let metrics = merged_aggregate.metrics().unwrap();
1607
4
        let output_rows = metrics.output_rows().unwrap();
1608
4
        if spill {
1609
            // When spilling, the output_rows metric is the partial output size plus the final output size,
1610
            // because final aggregation starts while partial aggregation is still emitting rows.
1611
2
            assert_eq!(8, output_rows);
1612
        } else {
1613
2
            assert_eq!(3, output_rows);
1614
        }
1615
1616
4
        Ok(())
1617
4
    }
1618
1619
    /// A test source that can yield back to the runtime before returning its first item.
1620
1621
    #[derive(Debug)]
1622
    struct TestYieldingExec {
1623
        /// True if this exec should yield back to runtime the first time it is polled
1624
        pub yield_first: bool,
1625
        cache: PlanProperties,
1626
    }
1627
1628
    impl TestYieldingExec {
1629
9
        fn new(yield_first: bool) -> Self {
1630
9
            let schema = some_data().0;
1631
9
            let cache = Self::compute_properties(schema);
1632
9
            Self { yield_first, cache }
1633
9
        }
1634
1635
        /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
1636
9
        fn compute_properties(schema: SchemaRef) -> PlanProperties {
1637
9
            let eq_properties = EquivalenceProperties::new(schema);
1638
9
            PlanProperties::new(
1639
9
                eq_properties,
1640
9
                // Output Partitioning
1641
9
                Partitioning::UnknownPartitioning(1),
1642
9
                // Execution Mode
1643
9
                ExecutionMode::Bounded,
1644
9
            )
1645
9
        }
1646
    }
1647
1648
    impl DisplayAs for TestYieldingExec {
1649
0
        fn fmt_as(
1650
0
            &self,
1651
0
            t: DisplayFormatType,
1652
0
            f: &mut std::fmt::Formatter,
1653
0
        ) -> std::fmt::Result {
1654
0
            match t {
1655
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
1656
0
                    write!(f, "TestYieldingExec")
1657
0
                }
1658
0
            }
1659
0
        }
1660
    }
1661
1662
    impl ExecutionPlan for TestYieldingExec {
1663
0
        fn name(&self) -> &'static str {
1664
0
            "TestYieldingExec"
1665
0
        }
1666
1667
0
        fn as_any(&self) -> &dyn Any {
1668
0
            self
1669
0
        }
1670
1671
89
        fn properties(&self) -> &PlanProperties {
1672
89
            &self.cache
1673
89
        }
1674
1675
0
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1676
0
            vec![]
1677
0
        }
1678
1679
0
        fn with_new_children(
1680
0
            self: Arc<Self>,
1681
0
            _: Vec<Arc<dyn ExecutionPlan>>,
1682
0
        ) -> Result<Arc<dyn ExecutionPlan>> {
1683
0
            internal_err!("Children cannot be replaced in {self:?}")
1684
0
        }
1685
1686
18
        fn execute(
1687
18
            &self,
1688
18
            _partition: usize,
1689
18
            _context: Arc<TaskContext>,
1690
18
        ) -> Result<SendableRecordBatchStream> {
1691
18
            let stream = if self.yield_first {
1692
10
                TestYieldingStream::New
1693
            } else {
1694
8
                TestYieldingStream::Yielded
1695
            };
1696
1697
18
            Ok(Box::pin(stream))
1698
18
        }
1699
1700
0
        fn statistics(&self) -> Result<Statistics> {
1701
0
            let (_, batches) = some_data();
1702
0
            Ok(common::compute_record_batch_statistics(
1703
0
                &[batches],
1704
0
                &self.schema(),
1705
0
                None,
1706
0
            ))
1707
0
        }
1708
    }
1709
1710
    /// A stream over the demo data. If initialized as `New`, it first yields to the runtime before returning records.
1711
    enum TestYieldingStream {
1712
        New,
1713
        Yielded,
1714
        ReturnedBatch1,
1715
        ReturnedBatch2,
1716
    }
1717
1718
    impl Stream for TestYieldingStream {
1719
        type Item = Result<RecordBatch>;
1720
1721
60
        fn poll_next(
1722
60
            mut self: std::pin::Pin<&mut Self>,
1723
60
            cx: &mut Context<'_>,
1724
60
        ) -> Poll<Option<Self::Item>> {
1725
60
            match &*self {
1726
                TestYieldingStream::New => {
1727
10
                    *(self.as_mut()) = TestYieldingStream::Yielded;
1728
10
                    cx.waker().wake_by_ref();
1729
10
                    Poll::Pending
1730
                }
1731
                TestYieldingStream::Yielded => {
1732
18
                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
1733
18
                    Poll::Ready(Some(Ok(some_data().1[0].clone())))
1734
                }
1735
                TestYieldingStream::ReturnedBatch1 => {
1736
16
                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
1737
16
                    Poll::Ready(Some(Ok(some_data().1[1].clone())))
1738
                }
1739
16
                TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
1740
            }
1741
60
        }
1742
    }
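    // Note: in the `New` state the stream wakes its own waker before returning
    // `Poll::Pending`, the standard way to yield to the runtime exactly once
    // while guaranteeing it gets polled again.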
1743
1744
    impl RecordBatchStream for TestYieldingStream {
1745
0
        fn schema(&self) -> SchemaRef {
1746
0
            some_data().0
1747
0
        }
1748
    }
1749
1750
    //--- Tests ---//
1751
1752
    #[tokio::test]
1753
1
    async fn aggregate_source_not_yielding() -> Result<()> {
1754
1
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
1755
1
1756
1
        check_aggregates(input, false).await
1757
1
    }
1758
1759
    #[tokio::test]
1760
1
    async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
1761
1
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
1762
1
1763
1
        check_grouping_sets(input, false).await
1764
1
    }
1765
1766
    #[tokio::test]
1767
1
    async fn aggregate_source_with_yielding() -> Result<()> {
1768
1
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
1769
1
1770
2
        check_aggregates(input, false).await
1771
1
    }
1772
1773
    #[tokio::test]
1774
1
    async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
1775
1
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
1776
1
1777
2
        check_grouping_sets(input, false).await
1778
1
    }
1779
1780
    #[tokio::test]
1781
1
    async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
1782
1
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
1783
1
1784
1
        check_aggregates(input, true).await
1785
1
    }
1786
1787
    #[tokio::test]
1788
1
    async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
1789
1
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
1790
1
1791
1
        check_grouping_sets(input, true).await
1792
1
    }
1793
1794
    #[tokio::test]
1795
1
    async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
1796
1
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
1797
1
1798
2
        check_aggregates(input, true).await
1799
1
    }
1800
1801
    #[tokio::test]
1802
1
    async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
1803
1
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
1804
1
1805
2
        check_grouping_sets(input, true).await
1806
1
    }
1807
1808
    // Median(a)
1809
1
    fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
1810
1
        AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
1811
1
            .schema(schema)
1812
1
            .alias("MEDIAN(a)")
1813
1
            .build()
1814
1
    }
1815
1816
    #[tokio::test]
1817
1
    async fn test_oom() -> Result<()> {
1818
1
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
1819
1
        let input_schema = input.schema();
1820
1
1821
1
        let runtime = RuntimeEnvBuilder::default()
1822
1
            .with_memory_limit(1, 1.0)
1823
1
            .build_arc()?;
1824
1
        let task_ctx = TaskContext::default().with_runtime(runtime);
1825
1
        let task_ctx = Arc::new(task_ctx);
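        // With a 1-byte memory limit, any accumulator that allocates must fail
        // with a ResourcesExhausted error, which is asserted further below.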
1826
1
1827
1
        let groups_none = PhysicalGroupBy::default();
1828
1
        let groups_some = PhysicalGroupBy {
1829
1
            expr: vec![(col("a", &input_schema)
?0
, "a".to_string())],
1830
1
            null_expr: vec![],
1831
1
            groups: vec![vec![false]],
1832
1
        };
1833
1
1834
1
        // something that allocates within the aggregator
1835
1
        let aggregates_v0: Vec<AggregateFunctionExpr> =
1836
1
            vec![test_median_agg_expr(Arc::clone(&input_schema))?];
1837
1
1838
1
        // uses the fast path in `row_hash.rs`.
1839
1
        let aggregates_v2: Vec<AggregateFunctionExpr> =
1840
1
            vec![
1841
1
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
1842
1
                    .schema(Arc::clone(&input_schema))
1843
1
                    .alias("AVG(b)")
1844
1
                    .build()?,
1845
1
            ];
1846
1
1847
2
        for (version, groups, aggregates) in [
1848
1
            (0, groups_none, aggregates_v0),
1849
1
            (2, groups_some, aggregates_v2),
1850
1
        ] {
1851
2
            let n_aggr = aggregates.len();
1852
2
            let partial_aggregate = Arc::new(AggregateExec::try_new(
1853
2
                AggregateMode::Partial,
1854
2
                groups,
1855
2
                aggregates,
1856
2
                vec![None; n_aggr],
1857
2
                Arc::clone(&input),
1858
2
                Arc::clone(&input_schema),
1859
2
            )?);
1860
1
1861
2
            let stream = partial_aggregate.execute_typed(0, Arc::clone(&task_ctx))?;
1862
1
1863
1
            // ensure that we really got the version we wanted
1864
2
            match version {
1865
1
                0 => {
1866
1
                    assert!(matches!(stream, StreamType::AggregateStream(_)));
1867
1
                }
1868
1
                1 => {
1869
1
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
1870
1
                }
1871
1
                2 => {
1872
1
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
1873
1
                }
1874
1
                _ => panic!("Unknown version: {version}"),
1875
1
            }
1876
1
1877
2
            let stream: SendableRecordBatchStream = stream.into();
1878
2
            let err = common::collect(stream).await.unwrap_err();
1879
2
1880
2
            // error root cause traversal is a bit complicated, see #4172.
1881
2
            let err = err.find_root();
1882
2
            assert!(
1883
2
                matches!(err, DataFusionError::ResourcesExhausted(_)),
1884
1
                "Wrong error type: {err}",
1885
1
            );
1886
1
        }
1887
1
1888
1
        Ok(())
1889
1
    }
1890
1891
    #[tokio::test]
1892
1
    async fn test_drop_cancel_without_groups() -> Result<()> {
1893
1
        let task_ctx = Arc::new(TaskContext::default());
1894
1
        let schema =
1895
1
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
1896
1
1897
1
        let groups = PhysicalGroupBy::default();
1898
1
1899
1
        let aggregates: Vec<AggregateFunctionExpr> =
1900
1
            vec![
1901
1
                AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
1902
1
                    .schema(Arc::clone(&schema))
1903
1
                    .alias("AVG(a)")
1904
1
                    .build()?,
1905
1
            ];
1906
1
1907
1
        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
1908
1
        let refs = blocking_exec.refs();
1909
1
        let aggregate_exec = Arc::new(AggregateExec::try_new(
1910
1
            AggregateMode::Partial,
1911
1
            groups.clone(),
1912
1
            aggregates.clone(),
1913
1
            vec![None],
1914
1
            blocking_exec,
1915
1
            schema,
1916
1
        )?);
1917
1
1918
1
        let fut = crate::collect(aggregate_exec, task_ctx);
1919
1
        let mut fut = fut.boxed();
1920
1
1921
1
        assert_is_pending(&mut fut);
1922
1
        drop(fut);
1923
1
        assert_strong_count_converges_to_zero(refs).await;
1924
1
1925
1
        Ok(())
1926
1
    }
1927
1928
    #[tokio::test]
1929
1
    async fn test_drop_cancel_with_groups() -> Result<()> {
1930
1
        let task_ctx = Arc::new(TaskContext::default());
1931
1
        let schema = Arc::new(Schema::new(vec![
1932
1
            Field::new("a", DataType::Float64, true),
1933
1
            Field::new("b", DataType::Float64, true),
1934
1
        ]));
1935
1
1936
1
        let groups =
1937
1
            PhysicalGroupBy::new_single(vec![(col("a", &schema)
?0
, "a".to_string())]);
1938
1
1939
1
        let aggregates: Vec<AggregateFunctionExpr> =
1940
1
            vec![
1941
1
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
1942
1
                    .schema(Arc::clone(&schema))
1943
1
                    .alias("AVG(b)")
1944
1
                    .build()?,
1945
1
            ];
1946
1
1947
1
        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
1948
1
        let refs = blocking_exec.refs();
1949
1
        let aggregate_exec = Arc::new(AggregateExec::try_new(
1950
1
            AggregateMode::Partial,
1951
1
            groups,
1952
1
            aggregates.clone(),
1953
1
            vec![None],
1954
1
            blocking_exec,
1955
1
            schema,
1956
1
        )?);
1957
1
1958
1
        let fut = crate::collect(aggregate_exec, task_ctx);
1959
1
        let mut fut = fut.boxed();
1960
1
1961
1
        assert_is_pending(&mut fut);
1962
1
        drop(fut);
1963
1
        assert_strong_count_converges_to_zero(refs).await;
1964
1
1965
1
        Ok(())
1966
1
    }
1967
1968
    #[tokio::test]
1969
1
    async fn run_first_last_multi_partitions() -> Result<()> {
1970
3
        for use_coalesce_batches in [false, true] {
1971
6
            for is_first_acc in [false, true] {
1972
12
                for spill in [false, true] {
1973
8
                    first_last_multi_partitions(
1974
8
                        use_coalesce_batches,
1975
8
                        is_first_acc,
1976
8
                        spill,
1977
8
                        4200,
1978
8
                    )
1979
72
                    .await?
1980
1
                }
1981
1
            }
1982
1
        }
1983
1
        Ok(())
1984
1
    }
1985
1986
    // FIRST_VALUE(b ORDER BY b <SortOptions>)
1987
5
    fn test_first_value_agg_expr(
1988
5
        schema: &Schema,
1989
5
        sort_options: SortOptions,
1990
5
    ) -> Result<AggregateFunctionExpr> {
1991
5
        let ordering_req = [PhysicalSortExpr {
1992
5
            expr: col("b", schema)
?0
,
1993
5
            options: sort_options,
1994
        }];
1995
5
        let args = [col("b", schema)
?0
];
1996
1997
5
        AggregateExprBuilder::new(first_value_udaf(), args.to_vec())
1998
5
            .order_by(ordering_req.to_vec())
1999
5
            .schema(Arc::new(schema.clone()))
2000
5
            .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]"))
2001
5
            .build()
2002
5
    }
2003
2004
    // LAST_VALUE(b ORDER BY b <SortOptions>)
2005
5
    fn test_last_value_agg_expr(
2006
5
        schema: &Schema,
2007
5
        sort_options: SortOptions,
2008
5
    ) -> Result<AggregateFunctionExpr> {
2009
5
        let ordering_req = [PhysicalSortExpr {
2010
5
            expr: col("b", schema)
?0
,
2011
5
            options: sort_options,
2012
        }];
2013
5
        let args = [col("b", schema)
?0
];
2014
5
        AggregateExprBuilder::new(last_value_udaf(), args.to_vec())
2015
5
            .order_by(ordering_req.to_vec())
2016
5
            .schema(Arc::new(schema.clone()))
2017
5
            .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]"))
2018
5
            .build()
2019
5
    }
2020
2021
    // This function either constructs the physical plan below,
2022
    //
2023
    // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]",
2024
    // "  CoalesceBatchesExec: target_batch_size=1024",
2025
    // "    CoalescePartitionsExec",
2026
    // "      AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None",
2027
    // "        MemoryExec: partitions=4, partition_sizes=[1, 1, 1, 1]",
2028
    //
2029
    // or
2030
    //
2031
    // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]",
2032
    // "  CoalescePartitionsExec",
2033
    // "    AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None",
2034
    // "      MemoryExec: partitions=4, partition_sizes=[1, 1, 1, 1]",
2035
    //
2036
    // and checks whether the function `merge_batch` works correctly for
2037
    // FIRST_VALUE and LAST_VALUE functions.
2038
8
    async fn first_last_multi_partitions(
2039
8
        use_coalesce_batches: bool,
2040
8
        is_first_acc: bool,
2041
8
        spill: bool,
2042
8
        max_memory: usize,
2043
8
    ) -> Result<()> {
2044
8
        let task_ctx = if spill {
2045
4
            new_spill_ctx(2, max_memory)
2046
        } else {
2047
4
            Arc::new(TaskContext::default())
2048
        };
2049
2050
8
        let (schema, data) = some_data_v2();
2051
8
        let partition1 = data[0].clone();
2052
8
        let partition2 = data[1].clone();
2053
8
        let partition3 = data[2].clone();
2054
8
        let partition4 = data[3].clone();
2055
2056
8
        let groups =
2057
8
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
2058
8
2059
8
        let sort_options = SortOptions {
2060
8
            descending: false,
2061
8
            nulls_first: false,
2062
8
        };
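        // descending: false with nulls_first: false corresponds to `b ASC NULLS LAST`,
        // matching the ordering named in the aggregate aliases below.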
2063
8
        let aggregates: Vec<AggregateFunctionExpr> = if is_first_acc {
2064
4
            vec![test_first_value_agg_expr(&schema, sort_options)?]
2065
        } else {
2066
4
            vec![test_last_value_agg_expr(&schema, sort_options)?]
2067
        };
2068
2069
8
        let memory_exec = Arc::new(MemoryExec::try_new(
2070
8
            &[
2071
8
                vec![partition1],
2072
8
                vec![partition2],
2073
8
                vec![partition3],
2074
8
                vec![partition4],
2075
8
            ],
2076
8
            Arc::clone(&schema),
2077
8
            None,
2078
8
        )?);
2079
8
        let aggregate_exec = Arc::new(AggregateExec::try_new(
2080
8
            AggregateMode::Partial,
2081
8
            groups.clone(),
2082
8
            aggregates.clone(),
2083
8
            vec![None],
2084
8
            memory_exec,
2085
8
            Arc::clone(&schema),
2086
8
        )?);
2087
8
        let coalesce = if use_coalesce_batches {
2088
4
            let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec));
2089
4
            Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc<dyn ExecutionPlan>
2090
        } else {
2091
4
            Arc::new(CoalescePartitionsExec::new(aggregate_exec))
2092
4
                as Arc<dyn ExecutionPlan>
2093
        };
2094
8
        let aggregate_final = Arc::new(AggregateExec::try_new(
2095
8
            AggregateMode::Final,
2096
8
            groups,
2097
8
            aggregates.clone(),
2098
8
            vec![None],
2099
8
            coalesce,
2100
8
            schema,
2101
8
        )?) as Arc<dyn ExecutionPlan>;
2102
2103
72
        let result = crate::collect(aggregate_final, task_ctx).await?;
2104
8
        if is_first_acc {
2105
4
            let expected = [
2106
4
                "+---+--------------------------------------------+",
2107
4
                "| a | first_value(b) ORDER BY [b ASC NULLS LAST] |",
2108
4
                "+---+--------------------------------------------+",
2109
4
                "| 2 | 0.0                                        |",
2110
4
                "| 3 | 1.0                                        |",
2111
4
                "| 4 | 3.0                                        |",
2112
4
                "+---+--------------------------------------------+",
2113
4
            ];
2114
4
            assert_batches_eq!(expected, &result);
2115
        } else {
2116
4
            let expected = [
2117
4
                "+---+-------------------------------------------+",
2118
4
                "| a | last_value(b) ORDER BY [b ASC NULLS LAST] |",
2119
4
                "+---+-------------------------------------------+",
2120
4
                "| 2 | 3.0                                       |",
2121
4
                "| 3 | 5.0                                       |",
2122
4
                "| 4 | 6.0                                       |",
2123
4
                "+---+-------------------------------------------+",
2124
4
            ];
2125
4
            assert_batches_eq!(expected, &result);
2126
        };
2127
8
        Ok(())
2128
8
    }
2129
2130
    #[tokio::test]
2131
1
    async fn test_get_finest_requirements() -> Result<()> {
2132
1
        let test_schema = create_test_schema()?;
2133
1
2134
1
        // Assume columns a and b are aliases.
2135
1
        // Assume also that `a ASC` and `c DESC` describe the same global ordering for the table (since they are ordering equivalent).
2136
1
        let options1 = SortOptions {
2137
1
            descending: false,
2138
1
            nulls_first: false,
2139
1
        };
2140
1
        let col_a = &col("a", &test_schema)
?0
;
2141
1
        let col_b = &col("b", &test_schema)
?0
;
2142
1
        let col_c = &col("c", &test_schema)
?0
;
2143
1
        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
2144
1
        // Columns a and b are equal.
2145
1
        eq_properties.add_equal_conditions(col_a, col_b)?;
2146
1
        // Aggregate requirements are
2147
1
        // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively
2148
1
        let order_by_exprs = vec![
2149
1
            None,
2150
1
            Some(vec![PhysicalSortExpr {
2151
1
                expr: Arc::clone(col_a),
2152
1
                options: options1,
2153
1
            }]),
2154
1
            Some(vec![
2155
1
                PhysicalSortExpr {
2156
1
                    expr: Arc::clone(col_a),
2157
1
                    options: options1,
2158
1
                },
2159
1
                PhysicalSortExpr {
2160
1
                    expr: Arc::clone(col_b),
2161
1
                    options: options1,
2162
1
                },
2163
1
                PhysicalSortExpr {
2164
1
                    expr: Arc::clone(col_c),
2165
1
                    options: options1,
2166
1
                },
2167
1
            ]),
2168
1
            Some(vec![
2169
1
                PhysicalSortExpr {
2170
1
                    expr: Arc::clone(col_a),
2171
1
                    options: options1,
2172
1
                },
2173
1
                PhysicalSortExpr {
2174
1
                    expr: Arc::clone(col_b),
2175
1
                    options: options1,
2176
1
                },
2177
1
            ]),
2178
1
        ];
2179
1
2180
1
        let common_requirement = vec![
2181
1
            PhysicalSortExpr {
2182
1
                expr: Arc::clone(col_a),
2183
1
                options: options1,
2184
1
            },
2185
1
            PhysicalSortExpr {
2186
1
                expr: Arc::clone(col_c),
2187
1
                options: options1,
2188
1
            },
2189
1
        ];
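        // Since a and b are equal, the finest requirement [a ASC, b ASC, c ASC]
        // collapses to [a ASC, c ASC] once the redundant column is removed.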
2190
1
        let mut aggr_exprs = order_by_exprs
2191
1
            .into_iter()
2192
4
            .map(|order_by_expr| {
2193
4
                let ordering_req = order_by_expr.unwrap_or_default();
2194
4
                AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
2195
4
                    .alias("a")
2196
4
                    .order_by(ordering_req.to_vec())
2197
4
                    .schema(Arc::clone(&test_schema))
2198
4
                    .build()
2199
4
                    .unwrap()
2200
4
            })
2201
1
            .collect::<Vec<_>>();
2202
1
        let group_by = PhysicalGroupBy::new_single(vec![]);
2203
1
        let res = get_finer_aggregate_exprs_requirement(
2204
1
            &mut aggr_exprs,
2205
1
            &group_by,
2206
1
            &eq_properties,
2207
1
            &AggregateMode::Partial,
2208
1
        )?;
2209
1
        let res = PhysicalSortRequirement::to_sort_exprs(res);
2210
1
        assert_eq!(res, common_requirement);
2211
1
        Ok(())
2212
1
    }
2213
2214
    #[test]
2215
1
    fn test_agg_exec_same_schema() -> Result<()> {
2216
1
        let schema = Arc::new(Schema::new(vec![
2217
1
            Field::new("a", DataType::Float32, true),
2218
1
            Field::new("b", DataType::Float32, true),
2219
1
        ]));
2220
2221
1
        let col_a = col("a", &schema)
?0
;
2222
1
        let option_desc = SortOptions {
2223
1
            descending: true,
2224
1
            nulls_first: true,
2225
1
        };
2226
1
        let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);
2227
2228
1
        let aggregates: Vec<AggregateFunctionExpr> = vec![
2229
1
            test_first_value_agg_expr(&schema, option_desc)?,
2230
1
            test_last_value_agg_expr(&schema, option_desc)?,
2231
        ];
2232
1
        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2233
1
        let aggregate_exec = Arc::new(AggregateExec::try_new(
2234
1
            AggregateMode::Partial,
2235
1
            groups,
2236
1
            aggregates,
2237
1
            vec![None, None],
2238
1
            Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
2239
1
            schema,
2240
1
        )?);
2241
1
        let new_agg =
2242
1
            Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])
?0
;
2243
1
        assert_eq!(new_agg.schema(), aggregate_exec.schema());
2244
1
        Ok(())
2245
1
    }
2246
2247
    #[tokio::test]
2248
1
    async fn test_agg_exec_group_by_const() -> Result<()> {
2249
1
        let schema = Arc::new(Schema::new(vec![
2250
1
            Field::new("a", DataType::Float32, true),
2251
1
            Field::new("b", DataType::Float32, true),
2252
1
            Field::new("const", DataType::Int32, false),
2253
1
        ]));
2254
1
2255
1
        let col_a = col("a", &schema)
?0
;
2256
1
        let col_b = col("b", &schema)
?0
;
2257
1
        let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
2258
1
2259
1
        let groups = PhysicalGroupBy::new(
2260
1
            vec![
2261
1
                (col_a, "a".to_string()),
2262
1
                (col_b, "b".to_string()),
2263
1
                (const_expr, "const".to_string()),
2264
1
            ],
2265
1
            vec![
2266
1
                (
2267
1
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
2268
1
                    "a".to_string(),
2269
1
                ),
2270
1
                (
2271
1
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
2272
1
                    "b".to_string(),
2273
1
                ),
2274
1
                (
2275
1
                    Arc::new(Literal::new(ScalarValue::Int32(None))),
2276
1
                    "const".to_string(),
2277
1
                ),
2278
1
            ],
2279
1
            vec![
2280
1
                vec![false, true, true],
2281
1
                vec![true, false, true],
2282
1
                vec![true, true, false],
2283
1
            ],
2284
1
        );
2285
1
2286
1
        let aggregates: Vec<AggregateFunctionExpr> =
2287
1
            vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
2288
1
                .schema(Arc::clone(&schema))
2289
1
                .alias("1")
2290
1
                .build()?];
2291
1
2292
1
        let input_batches = (0..4)
2293
4
            .map(|_| {
2294
4
                let a = Arc::new(Float32Array::from(vec![0.; 8192]));
2295
4
                let b = Arc::new(Float32Array::from(vec![0.; 8192]));
2296
4
                let c = Arc::new(Int32Array::from(vec![1; 8192]));
2297
4
2298
4
                RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
2299
4
            })
2300
1
            .collect();
2301
1
2302
1
        let input = Arc::new(MemoryExec::try_new(
2303
1
            &[input_batches],
2304
1
            Arc::clone(&schema),
2305
1
            None,
2306
1
        )?);
2307
1
2308
1
        let aggregate_exec = Arc::new(AggregateExec::try_new(
2309
1
            AggregateMode::Partial,
2310
1
            groups,
2311
1
            aggregates.clone(),
2312
1
            vec![None],
2313
1
            input,
2314
1
            schema,
2315
1
        )?);
2316
1
2317
1
        let output =
2318
1
            collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
2319
1
2320
1
        let expected = [
2321
1
            "+-----+-----+-------+----------+",
2322
1
            "| a   | b   | const | 1[count] |",
2323
1
            "+-----+-----+-------+----------+",
2324
1
            "|     | 0.0 |       | 32768    |",
2325
1
            "| 0.0 |     |       | 32768    |",
2326
1
            "|     |     | 1     | 32768    |",
2327
1
            "+-----+-----+-------+----------+",
2328
1
        ];
2329
1
        assert_batches_sorted_eq!(expected, &output);
2330
1
2331
1
        Ok(())
2332
1
    }
2333
2334
    #[tokio::test]
2335
1
    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
2336
1
        let batch = RecordBatch::try_new(
2337
1
            Arc::new(Schema::new(vec![
2338
1
                Field::new(
2339
1
                    "labels".to_string(),
2340
1
                    DataType::Struct(
2341
1
                        vec![
2342
1
                            Field::new_dict(
2343
1
                                "a".to_string(),
2344
1
                                DataType::Dictionary(
2345
1
                                    Box::new(DataType::Int32),
2346
1
                                    Box::new(DataType::Utf8),
2347
1
                                ),
2348
1
                                true,
2349
1
                                0,
2350
1
                                false,
2351
1
                            ),
2352
1
                            Field::new_dict(
2353
1
                                "b".to_string(),
2354
1
                                DataType::Dictionary(
2355
1
                                    Box::new(DataType::Int32),
2356
1
                                    Box::new(DataType::Utf8),
2357
1
                                ),
2358
1
                                true,
2359
1
                                0,
2360
1
                                false,
2361
1
                            ),
2362
1
                        ]
2363
1
                        .into(),
2364
1
                    ),
2365
1
                    false,
2366
1
                ),
2367
1
                Field::new("value", DataType::UInt64, false),
2368
1
            ])),
2369
1
            vec![
2370
1
                Arc::new(StructArray::from(vec![
2371
1
                    (
2372
1
                        Arc::new(Field::new_dict(
2373
1
                            "a".to_string(),
2374
1
                            DataType::Dictionary(
2375
1
                                Box::new(DataType::Int32),
2376
1
                                Box::new(DataType::Utf8),
2377
1
                            ),
2378
1
                            true,
2379
1
                            0,
2380
1
                            false,
2381
1
                        )),
2382
1
                        Arc::new(
2383
1
                            vec![Some("a"), None, Some("a")]
2384
1
                                .into_iter()
2385
1
                                .collect::<DictionaryArray<Int32Type>>(),
2386
1
                        ) as ArrayRef,
2387
1
                    ),
2388
1
                    (
2389
1
                        Arc::new(Field::new_dict(
2390
1
                            "b".to_string(),
2391
1
                            DataType::Dictionary(
2392
1
                                Box::new(DataType::Int32),
2393
1
                                Box::new(DataType::Utf8),
2394
1
                            ),
2395
1
                            true,
2396
1
                            0,
2397
1
                            false,
2398
1
                        )),
2399
1
                        Arc::new(
2400
1
                            vec![Some("b"), Some("c"), Some("b")]
2401
1
                                .into_iter()
2402
1
                                .collect::<DictionaryArray<Int32Type>>(),
2403
1
                        ) as ArrayRef,
2404
1
                    ),
2405
1
                ])),
2406
1
                Arc::new(UInt64Array::from(vec![1, 1, 1])),
2407
1
            ],
2408
1
        )
2409
1
        .expect("Failed to create RecordBatch");
2410
1
2411
1
        let group_by = PhysicalGroupBy::new_single(vec![(
2412
1
            col("labels", &batch.schema())
?0
,
2413
1
            "labels".to_string(),
2414
1
        )]);
2415
1
2416
1
        let aggr_expr = vec![AggregateExprBuilder::new(
2417
1
            sum_udaf(),
2418
1
            vec![col("value", &batch.schema())
?0
],
2419
1
        )
2420
1
        .schema(Arc::clone(&batch.schema()))
2421
1
        .alias(String::from("SUM(value)"))
2422
1
        .build()?];
2423
1
2424
1
        let input = Arc::new(MemoryExec::try_new(
2425
1
            &[vec![batch.clone()]],
2426
1
            Arc::<arrow_schema::Schema>::clone(&batch.schema()),
2427
1
            None,
2428
1
        )?);
2429
1
        let aggregate_exec = Arc::new(AggregateExec::try_new(
2430
1
            AggregateMode::FinalPartitioned,
2431
1
            group_by,
2432
1
            aggr_expr,
2433
1
            vec![None],
2434
1
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2435
1
            batch.schema(),
2436
1
        )?);
2437
1
2438
1
        let session_config = SessionConfig::default();
2439
1
        let ctx = TaskContext::default().with_session_config(session_config);
2440
1
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2441
1
2442
1
        let expected = [
2443
1
            "+--------------+------------+",
2444
1
            "| labels       | SUM(value) |",
2445
1
            "+--------------+------------+",
2446
1
            "| {a: a, b: b} | 2          |",
2447
1
            "| {a: , b: c}  | 1          |",
2448
1
            "+--------------+------------+",
2449
1
        ];
2450
1
        assert_batches_eq!(expected, &output);
2451
1
2452
1
        Ok(())
2453
1
    }
2454
2455
    #[tokio::test]
2456
1
    async fn test_skip_aggregation_after_first_batch() -> Result<()> {
2457
1
        let schema = Arc::new(Schema::new(vec![
2458
1
            Field::new("key", DataType::Int32, true),
2459
1
            Field::new("val", DataType::Int32, true),
2460
1
        ]));
2461
1
2462
1
        let group_by =
2463
1
            PhysicalGroupBy::new_single(vec![(col("key", &schema)
?0
, "key".to_string())]);
2464
1
2465
1
        let aggr_expr =
2466
1
            vec![
2467
1
                AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
2468
1
                    .schema(Arc::clone(&schema))
2469
1
                    .alias(String::from("COUNT(val)"))
2470
1
                    .build()?,
2471
1
            ];
2472
1
2473
1
        let input_data = vec![
2474
1
            RecordBatch::try_new(
2475
1
                Arc::clone(&schema),
2476
1
                vec![
2477
1
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
2478
1
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2479
1
                ],
2480
1
            )
2481
1
            .unwrap(),
2482
1
            RecordBatch::try_new(
2483
1
                Arc::clone(&schema),
2484
1
                vec![
2485
1
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
2486
1
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2487
1
                ],
2488
1
            )
2489
1
            .unwrap(),
2490
1
        ];
2491
1
2492
1
        let input = Arc::new(MemoryExec::try_new(
2493
1
            &[input_data],
2494
1
            Arc::clone(&schema),
2495
1
            None,
2496
1
        )?);
2497
1
        let aggregate_exec = Arc::new(AggregateExec::try_new(
2498
1
            AggregateMode::Partial,
2499
1
            group_by,
2500
1
            aggr_expr,
2501
1
            vec![None],
2502
1
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2503
1
            schema,
2504
1
        )?);
2505
1
2506
1
        let mut session_config = SessionConfig::default();
2507
1
        session_config = session_config.set(
2508
1
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
2509
1
            &ScalarValue::Int64(Some(2)),
2510
1
        );
2511
1
        session_config = session_config.set(
2512
1
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
2513
1
            &ScalarValue::Float64(Some(0.1)),
2514
1
        );
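        // With a 2-row probe threshold and a 0.1 unique-ratio threshold, partial
        // aggregation is abandoned after the first batch, so the second batch is
        // emitted without being merged into existing groups (keys 2 and 3 repeat below).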
2515
1
2516
1
        let ctx = TaskContext::default().with_session_config(session_config);
2517
1
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2518
1
2519
1
        let expected = [
2520
1
            "+-----+-------------------+",
2521
1
            "| key | COUNT(val)[count] |",
2522
1
            "+-----+-------------------+",
2523
1
            "| 1   | 1                 |",
2524
1
            "| 2   | 1                 |",
2525
1
            "| 3   | 1                 |",
2526
1
            "| 2   | 1                 |",
2527
1
            "| 3   | 1                 |",
2528
1
            "| 4   | 1                 |",
2529
1
            "+-----+-------------------+",
2530
1
        ];
2531
1
        assert_batches_eq!(expected, &output);
2532
1
2533
1
        Ok(())
2534
1
    }
2535
2536
    #[tokio::test]
2537
1
    async fn test_skip_aggregation_after_threshold() -> Result<()> {
2538
1
        let schema = Arc::new(Schema::new(vec![
2539
1
            Field::new("key", DataType::Int32, true),
2540
1
            Field::new("val", DataType::Int32, true),
2541
1
        ]));
2542
1
2543
1
        let group_by =
2544
1
            PhysicalGroupBy::new_single(vec![(col("key", &schema)
?0
, "key".to_string())]);
2545
1
2546
1
        let aggr_expr =
2547
1
            vec![
2548
1
                AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
2549
1
                    .schema(Arc::clone(&schema))
2550
1
                    .alias(String::from("COUNT(val)"))
2551
1
                    .build()?,
2552
1
            ];
2553
1
2554
1
        let input_data = vec![
2555
1
            RecordBatch::try_new(
2556
1
                Arc::clone(&schema),
2557
1
                vec![
2558
1
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
2559
1
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2560
1
                ],
2561
1
            )
2562
1
            .unwrap(),
2563
1
            RecordBatch::try_new(
2564
1
                Arc::clone(&schema),
2565
1
                vec![
2566
1
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
2567
1
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2568
1
                ],
2569
1
            )
2570
1
            .unwrap(),
2571
1
            RecordBatch::try_new(
2572
1
                Arc::clone(&schema),
2573
1
                vec![
2574
1
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
2575
1
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2576
1
                ],
2577
1
            )
2578
1
            .unwrap(),
2579
1
        ];
2580
1
2581
1
        let input = Arc::new(MemoryExec::try_new(
2582
1
            &[input_data],
2583
1
            Arc::clone(&schema),
2584
1
            None,
2585
1
        )?);
2586
1
        let aggregate_exec = Arc::new(AggregateExec::try_new(
2587
1
            AggregateMode::Partial,
2588
1
            group_by,
2589
1
            aggr_expr,
2590
1
            vec![None],
2591
1
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2592
1
            schema,
2593
1
        )?);
2594
1
2595
1
        let mut session_config = SessionConfig::default();
2596
1
        session_config = session_config.set(
2597
1
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
2598
1
            &ScalarValue::Int64(Some(5)),
2599
1
        );
2600
1
        session_config = session_config.set(
2601
1
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
2602
1
            &ScalarValue::Float64(Some(0.1)),
2603
1
        );
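        // With a 5-row probe threshold, the first two batches (6 rows) are still
        // aggregated before skipping kicks in, so only the third batch passes
        // through unmerged in the expected output below.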
2604
1
2605
1
        let ctx = TaskContext::default().with_session_config(session_config);
2606
1
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2607
1
2608
1
        let expected = [
2609
1
            "+-----+-------------------+",
2610
1
            "| key | COUNT(val)[count] |",
2611
1
            "+-----+-------------------+",
2612
1
            "| 1   | 1                 |",
2613
1
            "| 2   | 2                 |",
2614
1
            "| 3   | 2                 |",
2615
1
            "| 4   | 1                 |",
2616
1
            "| 2   | 1                 |",
2617
1
            "| 3   | 1                 |",
2618
1
            "| 4   | 1                 |",
2619
1
            "+-----+-------------------+",
2620
1
        ];
2621
1
        assert_batches_eq!(expected, &output);
2622
1
2623
1
        Ok(())
2624
1
    }
2625
2626
    #[test]
2627
1
    fn group_exprs_nullable() -> Result<()> {
2628
1
        let input_schema = Arc::new(Schema::new(vec![
2629
1
            Field::new("a", DataType::Float32, false),
2630
1
            Field::new("b", DataType::Float32, false),
2631
1
        ]));
2632
2633
1
        let aggr_expr =
2634
1
            vec![
2635
1
                AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
2636
1
                    .schema(Arc::clone(&input_schema))
2637
1
                    .alias("COUNT(a)")
2638
1
                    .build()?,
2639
            ];
2640
2641
1
        let grouping_set = PhysicalGroupBy {
2642
1
            expr: vec![
2643
1
                (col("a", &input_schema)
?0
, "a".to_string()),
2644
1
                (col("b", &input_schema)
?0
, "b".to_string()),
2645
1
            ],
2646
1
            null_expr: vec![
2647
1
                (lit(ScalarValue::Float32(None)), "a".to_string()),
2648
1
                (lit(ScalarValue::Float32(None)), "b".to_string()),
2649
1
            ],
2650
1
            groups: vec![
2651
1
                vec![false, true],  // (a, NULL)
2652
1
                vec![false, false], // (a,b)
2653
1
            ],
2654
        };
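        // Because the first grouping set nulls out column b, the output field for b
        // must be nullable even though the input field is not; column a stays
        // non-nullable since no grouping set nulls it.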
2655
1
        let aggr_schema = create_schema(
2656
1
            &input_schema,
2657
1
            &grouping_set.expr,
2658
1
            &aggr_expr,
2659
1
            grouping_set.exprs_nullable(),
2660
1
            AggregateMode::Final,
2661
1
        )?;
2662
1
        let expected_schema = Schema::new(vec![
2663
1
            Field::new("a", DataType::Float32, false),
2664
1
            Field::new("b", DataType::Float32, true),
2665
1
            Field::new("COUNT(a)", DataType::Int64, false),
2666
1
        ]);
2667
1
        assert_eq!(aggr_schema, expected_schema);
2668
1
        Ok(())
2669
1
    }
2670
}