Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/aggregate.rs
Line | Count | Source
-----+-------+----------------------------------------------------------------------
   1 |       | // Licensed to the Apache Software Foundation (ASF) under one
   2 |       | // or more contributor license agreements.  See the NOTICE file
   3 |       | // distributed with this work for additional information
   4 |       | // regarding copyright ownership.  The ASF licenses this file
   5 |       | // to you under the Apache License, Version 2.0 (the
   6 |       | // "License"); you may not use this file except in compliance
   7 |       | // with the License.  You may obtain a copy of the License at
   8 |       | //
   9 |       | //   http://www.apache.org/licenses/LICENSE-2.0
  10 |       | //
  11 |       | // Unless required by applicable law or agreed to in writing,
  12 |       | // software distributed under the License is distributed on an
  13 |       | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14 |       | // KIND, either express or implied.  See the License for the
  15 |       | // specific language governing permissions and limitations
  16 |       | // under the License.
  17 |       |
  18 |       | pub(crate) mod groups_accumulator {
  19 |       |     #[allow(unused_imports)]
  20 |       |     pub(crate) mod accumulate {
  21 |       |         pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
  22 |       |     }
  23 |       |     pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{
  24 |       |         accumulate::NullState, GroupsAccumulatorAdapter,
  25 |       |     };
  26 |       | }
  27 |       | pub(crate) mod stats {
  28 |       |     pub use datafusion_functions_aggregate_common::stats::StatsType;
  29 |       | }
  30 |       | pub mod utils {
  31 |       |     pub use datafusion_functions_aggregate_common::utils::{
  32 |       |         adjust_output_array, get_accum_scalar_values_as_arrays, get_sort_options,
  33 |       |         ordering_fields, DecimalAverager, Hashable,
  34 |       |     };
  35 |       | }
  36 |       |
  37 |       | use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
  38 |       | use datafusion_common::ScalarValue;
  39 |       | use datafusion_common::{internal_err, not_impl_err, Result};
  40 |       | use datafusion_expr::AggregateUDF;
  41 |       | use datafusion_expr::ReversedUDAF;
  42 |       | use datafusion_expr_common::accumulator::Accumulator;
  43 |       | use datafusion_expr_common::type_coercion::aggregates::check_arg_count;
  44 |       | use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
  45 |       | use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
  46 |       | use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
  47 |       | use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
  48 |       | use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
  49 |       | use datafusion_physical_expr_common::utils::reverse_order_bys;
  50 |       |
  51 |       | use datafusion_expr_common::groups_accumulator::GroupsAccumulator;
  52 |       | use std::fmt::Debug;
  53 |       | use std::sync::Arc;
  54 |       |
  55 |       | /// Builder for physical [`AggregateFunctionExpr`]
  56 |       | ///
  57 |       | /// `AggregateFunctionExpr` contains the information necessary to call
  58 |       | /// an aggregate expression.
  59 |       | #[derive(Debug, Clone)]
  60 |       | pub struct AggregateExprBuilder {
  61 |       |     fun: Arc<AggregateUDF>,
  62 |       |     /// Physical expressions of the aggregate function
  63 |       |     args: Vec<Arc<dyn PhysicalExpr>>,
  64 |       |     alias: Option<String>,
  65 |       |     /// Arrow Schema for the aggregate function
  66 |       |     schema: SchemaRef,
  67 |       |     /// The physical order by expressions
  68 |       |     ordering_req: LexOrdering,
  69 |       |     /// Whether to ignore null values
  70 |       |     ignore_nulls: bool,
  71 |       |     /// Whether this is a distinct aggregate function
  72 |       |     is_distinct: bool,
  73 |       |     /// Whether the expression is reversed
  74 |       |     is_reversed: bool,
  75 |       | }
  76 |       |
  77 |       | impl AggregateExprBuilder {
  78 |    36 |     pub fn new(fun: Arc<AggregateUDF>, args: Vec<Arc<dyn PhysicalExpr>>) -> Self {
  79 |    36 |         Self {
  80 |    36 |             fun,
  81 |    36 |             args,
  82 |    36 |             alias: None,
  83 |    36 |             schema: Arc::new(Schema::empty()),
  84 |    36 |             ordering_req: vec![],
  85 |    36 |             ignore_nulls: false,
  86 |    36 |             is_distinct: false,
  87 |    36 |             is_reversed: false,
  88 |    36 |         }
  89 |    36 |     }
  90 |       |
  91 |    36 |     pub fn build(self) -> Result<AggregateFunctionExpr> {
  92 |    36 |         let Self {
  93 |    36 |             fun,
  94 |    36 |             args,
  95 |    36 |             alias,
  96 |    36 |             schema,
  97 |    36 |             ordering_req,
  98 |    36 |             ignore_nulls,
  99 |    36 |             is_distinct,
 100 |    36 |             is_reversed,
 101 |    36 |         } = self;
 102 |    36 |         if args.is_empty() {
 103 |     0 |             return internal_err!("args should not be empty");
 104 |    36 |         }
 105 |    36 |
 106 |    36 |         let mut ordering_fields = vec![];
 107 |    36 |
 108 |    36 |         if !ordering_req.is_empty() {
 109 |    16 |             let ordering_types = ordering_req
 110 |    16 |                 .iter()
 111 |    22 |                 .map(|e| e.expr.data_type(&schema))
 112 |    16 |                 .collect::<Result<Vec<_>>>()?;
 113 |       |
 114 |    16 |             ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types);
 115 |    20 |         }
 116 |       |
 117 |    36 |         let input_exprs_types = args
 118 |    36 |             .iter()
 119 |    36 |             .map(|arg| arg.data_type(&schema))
 120 |    36 |             .collect::<Result<Vec<_>>>()?;
 121 |       |
 122 |    36 |         check_arg_count(
 123 |    36 |             fun.name(),
 124 |    36 |             &input_exprs_types,
 125 |    36 |             &fun.signature().type_signature,
 126 |    36 |         )?;
 127 |       |
 128 |    36 |         let data_type = fun.return_type(&input_exprs_types)?;
 129 |    36 |         let is_nullable = fun.is_nullable();
 130 |    36 |         let name = match alias {
 131 |     0 |             None => return internal_err!("alias should be provided"),
 132 |    36 |             Some(alias) => alias,
 133 |    36 |         };
 134 |    36 |
 135 |    36 |         Ok(AggregateFunctionExpr {
 136 |    36 |             fun: Arc::unwrap_or_clone(fun),
 137 |    36 |             args,
 138 |    36 |             data_type,
 139 |    36 |             name,
 140 |    36 |             schema: Arc::unwrap_or_clone(schema),
 141 |    36 |             ordering_req,
 142 |    36 |             ignore_nulls,
 143 |    36 |             ordering_fields,
 144 |    36 |             is_distinct,
 145 |    36 |             input_types: input_exprs_types,
 146 |    36 |             is_reversed,
 147 |    36 |             is_nullable,
 148 |    36 |         })
 149 |    36 |     }
 150 |       |
 151 |    36 |     pub fn alias(mut self, alias: impl Into<String>) -> Self {
 152 |    36 |         self.alias = Some(alias.into());
 153 |    36 |         self
 154 |    36 |     }
 155 |       |
 156 |    36 |     pub fn schema(mut self, schema: SchemaRef) -> Self {
 157 |    36 |         self.schema = schema;
 158 |    36 |         self
 159 |    36 |     }
 160 |       |
 161 |    17 |     pub fn order_by(mut self, order_by: LexOrdering) -> Self {
 162 |    17 |         self.ordering_req = order_by;
 163 |    17 |         self
 164 |    17 |     }
 165 |       |
 166 |     0 |     pub fn reversed(mut self) -> Self {
 167 |     0 |         self.is_reversed = true;
 168 |     0 |         self
 169 |     0 |     }
 170 |       |
 171 |     3 |     pub fn with_reversed(mut self, is_reversed: bool) -> Self {
 172 |     3 |         self.is_reversed = is_reversed;
 173 |     3 |         self
 174 |     3 |     }
 175 |       |
 176 |     0 |     pub fn distinct(mut self) -> Self {
 177 |     0 |         self.is_distinct = true;
 178 |     0 |         self
 179 |     0 |     }
 180 |       |
 181 |     3 |     pub fn with_distinct(mut self, is_distinct: bool) -> Self {
 182 |     3 |         self.is_distinct = is_distinct;
 183 |     3 |         self
 184 |     3 |     }
 185 |       |
 186 |     0 |     pub fn ignore_nulls(mut self) -> Self {
 187 |     0 |         self.ignore_nulls = true;
 188 |     0 |         self
 189 |     0 |     }
 190 |       |
 191 |     5 |     pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self {
 192 |     5 |         self.ignore_nulls = ignore_nulls;
 193 |     5 |         self
 194 |     5 |     }
 195 |       | }
 196 |       |
 197 |       | /// Physical aggregate expression of a UDAF.
 198 |       | #[derive(Debug, Clone)]
 199 |       | pub struct AggregateFunctionExpr {
 200 |       |     fun: AggregateUDF,
 201 |       |     args: Vec<Arc<dyn PhysicalExpr>>,
 202 |       |     /// Output / return type of this aggregate
 203 |       |     data_type: DataType,
 204 |       |     name: String,
 205 |       |     schema: Schema,
 206 |       |     // The physical order by expressions
 207 |       |     ordering_req: LexOrdering,
 208 |       |     // Whether to ignore null values
 209 |       |     ignore_nulls: bool,
 210 |       |     // fields used for order sensitive aggregation functions
 211 |       |     ordering_fields: Vec<Field>,
 212 |       |     is_distinct: bool,
 213 |       |     is_reversed: bool,
 214 |       |     input_types: Vec<DataType>,
 215 |       |     is_nullable: bool,
 216 |       | }
 217 |       |
 218 |       | impl AggregateFunctionExpr {
 219 |       |     /// Return the `AggregateUDF` used by this `AggregateFunctionExpr`
 220 |     3 |     pub fn fun(&self) -> &AggregateUDF {
 221 |     3 |         &self.fun
 222 |     3 |     }
 223 |       |
 224 |       |     /// expressions that are passed to the Accumulator.
 225 |       |     /// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many.
 226 |    64 |     pub fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
 227 |    64 |         self.args.clone()
 228 |    64 |     }
 229 |       |
 230 |       |     /// Human readable name such as `"MIN(c2)"`.
 231 |     4 |     pub fn name(&self) -> &str {
 232 |     4 |         &self.name
 233 |     4 |     }
 234 |       |
 235 |       |     /// Return if the aggregation is distinct
 236 |     0 |     pub fn is_distinct(&self) -> bool {
 237 |     0 |         self.is_distinct
 238 |     0 |     }
 239 |       |
 240 |       |     /// Return if the aggregation ignores nulls
 241 |     0 |     pub fn ignore_nulls(&self) -> bool {
 242 |     0 |         self.ignore_nulls
 243 |     0 |     }
 244 |       |
 245 |       |     /// Return if the aggregation is reversed
 246 |     0 |     pub fn is_reversed(&self) -> bool {
 247 |     0 |         self.is_reversed
 248 |     0 |     }
 249 |       |
 250 |       |     /// Return if the aggregation is nullable
 251 |     0 |     pub fn is_nullable(&self) -> bool {
 252 |     0 |         self.is_nullable
 253 |     0 |     }
 254 |       |
 255 |       |     /// the field of the final result of this aggregation.
 256 |    31 |     pub fn field(&self) -> Field {
 257 |    31 |         Field::new(&self.name, self.data_type.clone(), self.is_nullable)
 258 |    31 |     }
 259 |       |
 260 |       |     /// the accumulator used to accumulate values from the expressions.
 261 |       |     /// the accumulator expects the same number of arguments as `expressions` and must
 262 |       |     /// return states with the same description as `state_fields`
 263 |   132 |     pub fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
 264 |   132 |         let acc_args = AccumulatorArgs {
 265 |   132 |             return_type: &self.data_type,
 266 |   132 |             schema: &self.schema,
 267 |   132 |             ignore_nulls: self.ignore_nulls,
 268 |   132 |             ordering_req: &self.ordering_req,
 269 |   132 |             is_distinct: self.is_distinct,
 270 |   132 |             name: &self.name,
 271 |   132 |             is_reversed: self.is_reversed,
 272 |   132 |             exprs: &self.args,
 273 |   132 |         };
 274 |   132 |
 275 |   132 |         self.fun.accumulator(acc_args)
 276 |   132 |     }
 277 |       |
 278 |       |     /// The fields of the intermediate state used by this aggregation.
 279 |   112 |     pub fn state_fields(&self) -> Result<Vec<Field>> {
 280 |   112 |         let args = StateFieldsArgs {
 281 |   112 |             name: &self.name,
 282 |   112 |             input_types: &self.input_types,
 283 |   112 |             return_type: &self.data_type,
 284 |   112 |             ordering_fields: &self.ordering_fields,
 285 |   112 |             is_distinct: self.is_distinct,
 286 |   112 |         };
 287 |   112 |
 288 |   112 |         self.fun.state_fields(args)
 289 |   112 |     }
 290 |       |
 291 |       |     /// Order by requirements for the aggregate function
 292 |       |     /// By default it is `None` (there is no requirement)
 293 |       |     /// Order-sensitive aggregators, such as `FIRST_VALUE(x ORDER BY y)` should implement this
 294 |    64 |     pub fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
 295 |    64 |         if self.ordering_req.is_empty() {
 296 |    23 |             return None;
 297 |    41 |         }
 298 |    41 |
 299 |    41 |         if !self.order_sensitivity().is_insensitive() {
 300 |    41 |             return Some(&self.ordering_req);
 301 |     0 |         }
 302 |     0 |
 303 |     0 |         None
 304 |    64 |     }
 305 |       |
 306 |       |     /// Indicates whether aggregator can produce the correct result with any
 307 |       |     /// arbitrary input ordering. By default, we assume that aggregate expressions
 308 |       |     /// are order insensitive.
 309 |    95 |     pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
 310 |    95 |         if !self.ordering_req.is_empty() {
 311 |       |             // If there is requirement, use the sensitivity of the implementation
 312 |    70 |             self.fun.order_sensitivity()
 313 |       |         } else {
 314 |       |             // If no requirement, aggregator is order insensitive
 315 |    25 |             AggregateOrderSensitivity::Insensitive
 316 |       |         }
 317 |    95 |     }
 318 |       |
 319 |       |     /// Sets the indicator whether ordering requirements of the aggregator are
 320 |       |     /// satisfied by its input. If this is not the case, aggregators with order
 321 |       |     /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
 322 |       |     /// the correct result with possibly more work internally.
 323 |       |     ///
 324 |       |     /// # Returns
 325 |       |     ///
 326 |       |     /// Returns `Ok(Some(updated_expr))` if the process completes successfully.
 327 |       |     /// If the expression can benefit from existing input ordering, but does
 328 |       |     /// not implement the method, returns an error. Order insensitive and hard
 329 |       |     /// requirement aggregators return `Ok(None)`.
 330 |     0 |     pub fn with_beneficial_ordering(
 331 |     0 |         self,
 332 |     0 |         beneficial_ordering: bool,
 333 |     0 |     ) -> Result<Option<AggregateFunctionExpr>> {
 334 |     0 |         let Some(updated_fn) = self
 335 |     0 |             .fun
 336 |     0 |             .clone()
 337 |     0 |             .with_beneficial_ordering(beneficial_ordering)?
 338 |       |         else {
 339 |     0 |             return Ok(None);
 340 |       |         };
 341 |       |
 342 |     0 |         AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec())
 343 |     0 |             .order_by(self.ordering_req.to_vec())
 344 |     0 |             .schema(Arc::new(self.schema.clone()))
 345 |     0 |             .alias(self.name().to_string())
 346 |     0 |             .with_ignore_nulls(self.ignore_nulls)
 347 |     0 |             .with_distinct(self.is_distinct)
 348 |     0 |             .with_reversed(self.is_reversed)
 349 |     0 |             .build()
 350 |     0 |             .map(Some)
 351 |     0 |     }
 352 |       |
 353 |       |     /// Creates accumulator implementation that supports retract
 354 |     3 |     pub fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
 355 |     3 |         let args = AccumulatorArgs {
 356 |     3 |             return_type: &self.data_type,
 357 |     3 |             schema: &self.schema,
 358 |     3 |             ignore_nulls: self.ignore_nulls,
 359 |     3 |             ordering_req: &self.ordering_req,
 360 |     3 |             is_distinct: self.is_distinct,
 361 |     3 |             name: &self.name,
 362 |     3 |             is_reversed: self.is_reversed,
 363 |     3 |             exprs: &self.args,
 364 |     3 |         };
 365 |       |
 366 |     3 |         let accumulator = self.fun.create_sliding_accumulator(args)?;
 367 |       |
 368 |       |         // Accumulators whose window frame start is different
 369 |       |         // from `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to
 370 |       |         // implement the retract_batch method in order to run correctly
 371 |       |         // currently in DataFusion.
 372 |       |         //
 373 |       |         // If `retract_batch` is not present, there is no way
 374 |       |         // to calculate the result correctly. For example, the query
 375 |       |         //
 376 |       |         // ```sql
 377 |       |         // SELECT
 378 |       |         //  SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a
 379 |       |         // FROM
 380 |       |         //  t
 381 |       |         // ```
 382 |       |         //
 383 |       |         // 1. First sum value will be the sum of rows between `[0, 1)`,
 384 |       |         //
 385 |       |         // 2. Second sum value will be the sum of rows between `[0, 2)`
 386 |       |         //
 387 |       |         // 3. Third sum value will be the sum of rows between `[1, 3)`, etc.
 388 |       |         //
 389 |       |         // Since the accumulator keeps the running sum:
 390 |       |         //
 391 |       |         // 1. First sum we add to the state sum value between `[0, 1)`
 392 |       |         //
 393 |       |         // 2. Second sum we add to the state sum value between `[1, 2)`
 394 |       |         // (`[0, 1)` is already in the state sum, hence running sum will
 395 |       |         // cover `[0, 2)` range)
 396 |       |         //
 397 |       |         // 3. Third sum we add to the state sum value between `[2, 3)`
 398 |       |         // (`[0, 2)` is already in the state sum).  Also we need to
 399 |       |         // retract values between `[0, 1)`; this way we can obtain the sum
 400 |       |         // between `[1, 3)`, which is indeed the appropriate range.
 401 |       |         //
 402 |       |         // When we use `UNBOUNDED PRECEDING` in the query, the starting
 403 |       |         // index will always be 0 for the desired range, and hence the
 404 |       |         // `retract_batch` method will not be called. In this case
 405 |       |         // having retract_batch is not a requirement.
 406 |       |         //
 407 |       |         // This approach is a bit different than the window function
 408 |       |         // approach. In window functions (when they use a window frame)
 409 |       |         // they get all the desired range during evaluation.
 410 |     3 |         if !accumulator.supports_retract_batch() {
 411 |     0 |             return not_impl_err!(
 412 |     0 |                 "Aggregate can not be used as a sliding accumulator because \
 413 |     0 |                      `retract_batch` is not implemented: {}",
 414 |     0 |                 self.name
 415 |     0 |             );
 416 |     3 |         }
 417 |     3 |         Ok(accumulator)
 418 |     3 |     }
 419 |       |
 420 |       |     /// If the aggregate expression has a specialized
 421 |       |     /// [`GroupsAccumulator`] implementation. If this returns true,
 422 |       |     /// [`Self::create_groups_accumulator`] will be called.
 423 |    70 |     pub fn groups_accumulator_supported(&self) -> bool {
 424 |    70 |         let args = AccumulatorArgs {
 425 |    70 |             return_type: &self.data_type,
 426 |    70 |             schema: &self.schema,
 427 |    70 |             ignore_nulls: self.ignore_nulls,
 428 |    70 |             ordering_req: &self.ordering_req,
 429 |    70 |             is_distinct: self.is_distinct,
 430 |    70 |             name: &self.name,
 431 |    70 |             is_reversed: self.is_reversed,
 432 |    70 |             exprs: &self.args,
 433 |    70 |         };
 434 |    70 |         self.fun.groups_accumulator_supported(args)
 435 |    70 |     }
 436 |       |
 437 |       |     /// Return a specialized [`GroupsAccumulator`] that manages state
 438 |       |     /// for all groups.
 439 |       |     ///
 440 |       |     /// For maximum performance, a [`GroupsAccumulator`] should be
 441 |       |     /// implemented in addition to [`Accumulator`].
 442 |    30 |     pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
 443 |    30 |         let args = AccumulatorArgs {
 444 |    30 |             return_type: &self.data_type,
 445 |    30 |             schema: &self.schema,
 446 |    30 |             ignore_nulls: self.ignore_nulls,
 447 |    30 |             ordering_req: &self.ordering_req,
 448 |    30 |             is_distinct: self.is_distinct,
 449 |    30 |             name: &self.name,
 450 |    30 |             is_reversed: self.is_reversed,
 451 |    30 |             exprs: &self.args,
 452 |    30 |         };
 453 |    30 |         self.fun.create_groups_accumulator(args)
 454 |    30 |     }
 455 |       |
 456 |       |     /// Construct an expression that calculates the aggregate in reverse.
 457 |       |     /// Typically the "reverse" expression is itself (e.g. SUM, COUNT).
 458 |       |     /// For aggregates that do not support calculation in reverse,
 459 |       |     /// returns None (which is the default value).
 460 |     3 |     pub fn reverse_expr(&self) -> Option<AggregateFunctionExpr> {
 461 |     3 |         match self.fun.reverse_udf() {
 462 |     0 |             ReversedUDAF::NotSupported => None,
 463 |     0 |             ReversedUDAF::Identical => Some(self.clone()),
 464 |     3 |             ReversedUDAF::Reversed(reverse_udf) => {
 465 |     3 |                 let reverse_ordering_req = reverse_order_bys(&self.ordering_req);
 466 |     3 |                 let mut name = self.name().to_string();
 467 |     3 |                 // If the function is changed, we need to reverse order_by clause as well
 468 |     3 |                 // i.e. First(a order by b asc null first) -> Last(a order by b desc null last)
 469 |     3 |                 if self.fun().name() == reverse_udf.name() {
 470 |     3 |                 } else {
 471 |     0 |                     replace_order_by_clause(&mut name);
 472 |     0 |                 }
 473 |     3 |                 replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name());
 474 |     3 |
 475 |     3 |                 AggregateExprBuilder::new(reverse_udf, self.args.to_vec())
 476 |     3 |                     .order_by(reverse_ordering_req.to_vec())
 477 |     3 |                     .schema(Arc::new(self.schema.clone()))
 478 |     3 |                     .alias(name)
 479 |     3 |                     .with_ignore_nulls(self.ignore_nulls)
 480 |     3 |                     .with_distinct(self.is_distinct)
 481 |     3 |                     .with_reversed(!self.is_reversed)
 482 |     3 |                     .build()
 483 |     3 |                     .ok()
 484 |       |             }
 485 |       |         }
 486 |     3 |     }
 487 |       |
 488 |       |     /// Returns all expressions used in the [`AggregateFunctionExpr`].
 489 |       |     /// These expressions are (1) function arguments and (2) order by expressions.
 490 |     0 |     pub fn all_expressions(&self) -> AggregatePhysicalExpressions {
 491 |     0 |         let args = self.expressions();
 492 |     0 |         let order_bys = self.order_bys().unwrap_or(&[]);
 493 |     0 |         let order_by_exprs = order_bys
 494 |     0 |             .iter()
 495 |     0 |             .map(|sort_expr| Arc::clone(&sort_expr.expr))
 496 |     0 |             .collect::<Vec<_>>();
 497 |     0 |         AggregatePhysicalExpressions {
 498 |     0 |             args,
 499 |     0 |             order_by_exprs,
 500 |     0 |         }
 501 |     0 |     }
 502 |       |
 503 |       |     /// Rewrites the [`AggregateFunctionExpr`] with the new expressions given. The arguments should be consistent
 504 |       |     /// with the return value of the [`AggregateFunctionExpr::all_expressions`] method.
 505 |       |     /// Returns `Some(AggregateFunctionExpr)` if the rewrite is supported, otherwise returns `None`.
 506 |     0 |     pub fn with_new_expressions(
 507 |     0 |         &self,
 508 |     0 |         _args: Vec<Arc<dyn PhysicalExpr>>,
 509 |     0 |         _order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
 510 |     0 |     ) -> Option<AggregateFunctionExpr> {
 511 |     0 |         None
 512 |     0 |     }
 513 |       |
 514 |       |     /// If this function is max, return (output_field, true);
 515 |       |     /// if the function is min, return (output_field, false);
 516 |       |     /// otherwise return None (the default).
 517 |       |     ///
 518 |       |     /// output_field is the name of the column produced by this aggregate
 519 |       |     ///
 520 |       |     /// Note: this is used to choose specialized aggregate implementations in certain conditions
 521 |     0 |     pub fn get_minmax_desc(&self) -> Option<(Field, bool)> {
 522 |     0 |         self.fun.is_descending().map(|flag| (self.field(), flag))
 523 |     0 |     }
 524 |       |
 525 |       |     /// Returns the default value of the function given the input is Null.
 526 |       |     /// Most aggregate functions return Null if the input is Null,
 527 |       |     /// while `count` returns 0 if the input is Null
 528 |     0 |     pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
 529 |     0 |         self.fun.default_value(data_type)
 530 |     0 |     }
 531 |       | }
 532 |       |
 533 |       | /// Stores the physical expressions used inside the `AggregateExpr`.
 534 |       | pub struct AggregatePhysicalExpressions {
 535 |       |     /// Aggregate function arguments
 536 |       |     pub args: Vec<Arc<dyn PhysicalExpr>>,
 537 |       |     /// Order by expressions
 538 |       |     pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
 539 |       | }
 540 |       |
 541 |       | impl PartialEq for AggregateFunctionExpr {
 542 |     0 |     fn eq(&self, other: &Self) -> bool {
 543 |     0 |         self.name == other.name
 544 |     0 |             && self.data_type == other.data_type
 545 |     0 |             && self.fun == other.fun
 546 |     0 |             && self.args.len() == other.args.len()
 547 |     0 |             && self
 548 |     0 |                 .args
 549 |     0 |                 .iter()
 550 |     0 |                 .zip(other.args.iter())
 551 |     0 |                 .all(|(this_arg, other_arg)| this_arg.eq(other_arg))
 552 |     0 |     }
 553 |       | }
 554 |       |
 555 |     0 | fn replace_order_by_clause(order_by: &mut String) {
 556 |     0 |     let suffixes = [
 557 |     0 |         (" DESC NULLS FIRST]", " ASC NULLS LAST]"),
 558 |     0 |         (" ASC NULLS FIRST]", " DESC NULLS LAST]"),
 559 |     0 |         (" DESC NULLS LAST]", " ASC NULLS FIRST]"),
 560 |     0 |         (" ASC NULLS LAST]", " DESC NULLS FIRST]"),
 561 |     0 |     ];
 562 |       |
 563 |     0 |     if let Some(start) = order_by.find("ORDER BY [") {
 564 |     0 |         if let Some(end) = order_by[start..].find(']') {
 565 |     0 |             let order_by_start = start + 9;
 566 |     0 |             let order_by_end = start + end;
 567 |     0 |
 568 |     0 |             let column_order = &order_by[order_by_start..=order_by_end];
 569 |     0 |             for (suffix, replacement) in suffixes {
 570 |     0 |                 if column_order.ends_with(suffix) {
 571 |     0 |                     let new_order = column_order.replace(suffix, replacement);
 572 |     0 |                     order_by.replace_range(order_by_start..=order_by_end, &new_order);
 573 |     0 |                     break;
 574 |     0 |                 }
 575 |       |             }
 576 |     0 |         }
 577 |     0 |     }
 578 |     0 | }
 579 |       |
 580 |     3 | fn replace_fn_name_clause(aggr_name: &mut String, fn_name_old: &str, fn_name_new: &str) {
 581 |     3 |     *aggr_name = aggr_name.replace(fn_name_old, fn_name_new);
 582 |     3 | }
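
Editorial note on AggregateExprBuilder (lines 55-195 above): the builder is driven as a fluent chain, and build() requires both a non-empty argument list (line 102) and an alias (line 131). The sketch below is a minimal usage example, not code from the covered file; the import paths for sum_udaf, col, and the AggregateExprBuilder re-export are assumptions about the surrounding DataFusion crates at this version.

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
// Assumed paths; only AggregateExprBuilder itself is defined in aggregate.rs.
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::col;

fn build_sum_expr() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));

    // `new` takes the UDAF plus its argument expressions; `alias` must be set
    // before `build()`, otherwise build() returns "alias should be provided".
    let sum_a = AggregateExprBuilder::new(sum_udaf(), vec![col("a", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("SUM(a)")
        .build()?;

    // The resulting AggregateFunctionExpr describes its output field and can
    // create the Accumulator that the execution plan feeds with input batches.
    assert_eq!(sum_a.name(), "SUM(a)");
    let _acc = sum_a.create_accumulator()?;
    Ok(())
}

The same chain appears inside the file itself, e.g. in with_beneficial_ordering (lines 342-350) and reverse_expr (lines 475-483).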
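
Editorial note on the comment at lines 368-409: why sliding (bounded-start) window frames need retract_batch can be shown with a small standalone sketch of the add/retract bookkeeping. It uses plain integers instead of Arrow arrays and does not implement DataFusion's actual Accumulator trait; the frame-bound arithmetic below is illustrative only.

// Running aggregate that can both accumulate newly included rows and retract
// rows that have left the window frame (the update_batch / retract_batch split).
struct SlidingSum {
    sum: i64,
}

impl SlidingSum {
    fn update(&mut self, values: &[i64]) {
        self.sum += values.iter().sum::<i64>();
    }
    fn retract(&mut self, values: &[i64]) {
        self.sum -= values.iter().sum::<i64>();
    }
}

fn main() {
    // SUM(a) OVER (ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
    let a = [10_i64, 20, 30, 40];
    let mut acc = SlidingSum { sum: 0 };
    let (mut prev_start, mut prev_end) = (0_usize, 0_usize);

    for i in 0..a.len() {
        // Frame for row i: rows i - 1 ..= i + 1, clamped to the table bounds.
        let start = i.saturating_sub(1);
        let end = (i + 2).min(a.len());

        // Add the rows that entered the frame and retract the rows that left it.
        // Without retract(), the frame start could never advance and the running
        // sum for row 2 would still include row 0.
        acc.update(&a[prev_end..end]);
        acc.retract(&a[prev_start..start]);

        println!("sum_a for row {i}: {}", acc.sum); // prints 30, 60, 90, 70
        (prev_start, prev_end) = (start, end);
    }
}

With an UNBOUNDED PRECEDING frame start the retract() call would always receive an empty slice, which is why create_sliding_accumulator only rejects accumulators lacking retract_batch when the frame start can move (lines 402-405, 410-416 above).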