Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/average.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines `Avg` & `Mean` aggregate & accumulators
19
20
use arrow::array::{
21
    self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType,
22
    AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23
};
24
25
use arrow::compute::sum;
26
use arrow::datatypes::{
27
    i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
28
    Float64Type, UInt64Type,
29
};
30
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
31
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32
use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type};
33
use datafusion_expr::utils::format_state_name;
34
use datafusion_expr::Volatility::Immutable;
35
use datafusion_expr::{
36
    Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
37
};
38
39
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
40
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
41
    filtered_null_mask, set_nulls,
42
};
43
44
use datafusion_functions_aggregate_common::utils::DecimalAverager;
45
use log::debug;
46
use std::any::Any;
47
use std::fmt::Debug;
48
use std::sync::Arc;
49
50
make_udaf_expr_and_func!(
51
    Avg,
52
    avg,
53
    expression,
54
    "Returns the avg of a group of values.",
55
    avg_udaf
56
);
57
58
#[derive(Debug)]
59
pub struct Avg {
60
    signature: Signature,
61
    aliases: Vec<String>,
62
}
63
64
impl Avg {
65
1
    pub fn new() -> Self {
66
1
        Self {
67
1
            signature: Signature::user_defined(Immutable),
68
1
            aliases: vec![String::from("mean")],
69
1
        }
70
1
    }
71
}
72
73
impl Default for Avg {
74
1
    fn default() -> Self {
75
1
        Self::new()
76
1
    }
77
}
78
79
impl AggregateUDFImpl for Avg {
80
0
    fn as_any(&self) -> &dyn Any {
81
0
        self
82
0
    }
83
84
14
    fn name(&self) -> &str {
85
14
        "avg"
86
14
    }
87
88
7
    fn signature(&self) -> &Signature {
89
7
        &self.signature
90
7
    }
91
92
7
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
93
7
        avg_return_type(self.name(), &arg_types[0])
94
7
    }
95
96
1
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
97
1
        if acc_args.is_distinct {
98
0
            return exec_err!("avg(DISTINCT) aggregations are not available");
99
1
        }
100
        use DataType::*;
101
102
1
        let data_type = acc_args.exprs[0].data_type(acc_args.schema)
?0
;
103
        // instantiate specialized accumulator based for the type
104
1
        match (&data_type, acc_args.return_type) {
105
1
            (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
106
            (
107
0
                Decimal128(sum_precision, sum_scale),
108
0
                Decimal128(target_precision, target_scale),
109
0
            ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
110
0
                sum: None,
111
0
                count: 0,
112
0
                sum_scale: *sum_scale,
113
0
                sum_precision: *sum_precision,
114
0
                target_precision: *target_precision,
115
0
                target_scale: *target_scale,
116
0
            })),
117
118
            (
119
0
                Decimal256(sum_precision, sum_scale),
120
0
                Decimal256(target_precision, target_scale),
121
0
            ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
122
0
                sum: None,
123
0
                count: 0,
124
0
                sum_scale: *sum_scale,
125
0
                sum_precision: *sum_precision,
126
0
                target_precision: *target_precision,
127
0
                target_scale: *target_scale,
128
0
            })),
129
0
            _ => exec_err!(
130
0
                "AvgAccumulator for ({} --> {})",
131
0
                &data_type,
132
0
                acc_args.return_type
133
0
            ),
134
        }
135
1
    }
136
137
25
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
138
25
        Ok(vec![
139
25
            Field::new(
140
25
                format_state_name(args.name, "count"),
141
25
                DataType::UInt64,
142
25
                true,
143
25
            ),
144
25
            Field::new(
145
25
                format_state_name(args.name, "sum"),
146
25
                args.input_types[0].clone(),
147
25
                true,
148
25
            ),
149
25
        ])
150
25
    }
151
152
14
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
153
0
        matches!(
154
14
            args.return_type,
155
            DataType::Float64 | DataType::Decimal128(_, _)
156
        )
157
14
    }
158
159
14
    fn create_groups_accumulator(
160
14
        &self,
161
14
        args: AccumulatorArgs,
162
14
    ) -> Result<Box<dyn GroupsAccumulator>> {
163
        use DataType::*;
164
165
14
        let data_type = args.exprs[0].data_type(args.schema)
?0
;
166
        // instantiate specialized accumulator based for the type
167
14
        match (&data_type, args.return_type) {
168
            (Float64, Float64) => {
169
14
                Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
170
14
                    &data_type,
171
14
                    args.return_type,
172
14
                    |sum: f64, count: u64| 
Ok(sum / count as f64)12
,
173
14
                )))
174
            }
175
            (
176
0
                Decimal128(_sum_precision, sum_scale),
177
0
                Decimal128(target_precision, target_scale),
178
            ) => {
179
0
                let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
180
0
                    *sum_scale,
181
0
                    *target_precision,
182
0
                    *target_scale,
183
0
                )?;
184
185
0
                let avg_fn =
186
0
                    move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
187
188
0
                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
189
0
                    &data_type,
190
0
                    args.return_type,
191
0
                    avg_fn,
192
0
                )))
193
            }
194
195
            (
196
0
                Decimal256(_sum_precision, sum_scale),
197
0
                Decimal256(target_precision, target_scale),
198
            ) => {
199
0
                let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
200
0
                    *sum_scale,
201
0
                    *target_precision,
202
0
                    *target_scale,
203
0
                )?;
204
205
0
                let avg_fn = move |sum: i256, count: u64| {
206
0
                    decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
207
0
                };
208
209
0
                Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
210
0
                    &data_type,
211
0
                    args.return_type,
212
0
                    avg_fn,
213
0
                )))
214
            }
215
216
0
            _ => not_impl_err!(
217
0
                "AvgGroupsAccumulator for ({} --> {})",
218
0
                &data_type,
219
0
                args.return_type
220
0
            ),
221
        }
222
14
    }
223
224
0
    fn aliases(&self) -> &[String] {
225
0
        &self.aliases
226
0
    }
227
228
0
    fn reverse_expr(&self) -> ReversedUDAF {
229
0
        ReversedUDAF::Identical
230
0
    }
231
232
0
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
233
0
        if arg_types.len() != 1 {
234
0
            return exec_err!("{} expects exactly one argument.", self.name());
235
0
        }
236
0
        coerce_avg_type(self.name(), arg_types)
237
0
    }
238
}
239
240
/// An accumulator to compute the average
241
#[derive(Debug, Default)]
242
pub struct AvgAccumulator {
243
    sum: Option<f64>,
244
    count: u64,
245
}
246
247
impl Accumulator for AvgAccumulator {
248
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
249
0
        let values = values[0].as_primitive::<Float64Type>();
250
0
        self.count += (values.len() - values.null_count()) as u64;
251
0
        if let Some(x) = sum(values) {
252
0
            let v = self.sum.get_or_insert(0.);
253
0
            *v += x;
254
0
        }
255
0
        Ok(())
256
0
    }
257
258
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
259
0
        Ok(ScalarValue::Float64(
260
0
            self.sum.map(|f| f / self.count as f64),
261
0
        ))
262
0
    }
263
264
0
    fn size(&self) -> usize {
265
0
        std::mem::size_of_val(self)
266
0
    }
267
268
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
269
0
        Ok(vec![
270
0
            ScalarValue::from(self.count),
271
0
            ScalarValue::Float64(self.sum),
272
0
        ])
273
0
    }
274
275
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
276
0
        // counts are summed
277
0
        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
278
279
        // sums are summed
280
0
        if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
281
0
            let v = self.sum.get_or_insert(0.);
282
0
            *v += x;
283
0
        }
284
0
        Ok(())
285
0
    }
286
0
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
287
0
        let values = values[0].as_primitive::<Float64Type>();
288
0
        self.count -= (values.len() - values.null_count()) as u64;
289
0
        if let Some(x) = sum(values) {
290
0
            self.sum = Some(self.sum.unwrap() - x);
291
0
        }
292
0
        Ok(())
293
0
    }
294
295
0
    fn supports_retract_batch(&self) -> bool {
296
0
        true
297
0
    }
298
}
299
300
/// An accumulator to compute the average for decimals
301
#[derive(Debug)]
302
struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
303
    sum: Option<T::Native>,
304
    count: u64,
305
    sum_scale: i8,
306
    sum_precision: u8,
307
    target_precision: u8,
308
    target_scale: i8,
309
}
310
311
impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
312
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
313
0
        let values = values[0].as_primitive::<T>();
314
0
        self.count += (values.len() - values.null_count()) as u64;
315
316
0
        if let Some(x) = sum(values) {
317
0
            let v = self.sum.get_or_insert(T::Native::default());
318
0
            self.sum = Some(v.add_wrapping(x));
319
0
        }
320
0
        Ok(())
321
0
    }
322
323
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
324
0
        let v = self
325
0
            .sum
326
0
            .map(|v| {
327
0
                DecimalAverager::<T>::try_new(
328
0
                    self.sum_scale,
329
0
                    self.target_precision,
330
0
                    self.target_scale,
331
0
                )?
332
0
                .avg(v, T::Native::from_usize(self.count as usize).unwrap())
333
0
            })
334
0
            .transpose()?;
335
336
0
        ScalarValue::new_primitive::<T>(
337
0
            v,
338
0
            &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
339
0
        )
340
0
    }
341
342
0
    fn size(&self) -> usize {
343
0
        std::mem::size_of_val(self)
344
0
    }
345
346
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
347
0
        Ok(vec![
348
0
            ScalarValue::from(self.count),
349
0
            ScalarValue::new_primitive::<T>(
350
0
                self.sum,
351
0
                &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
352
0
            )?,
353
        ])
354
0
    }
355
356
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
357
0
        // counts are summed
358
0
        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
359
360
        // sums are summed
361
0
        if let Some(x) = sum(states[1].as_primitive::<T>()) {
362
0
            let v = self.sum.get_or_insert(T::Native::default());
363
0
            self.sum = Some(v.add_wrapping(x));
364
0
        }
365
0
        Ok(())
366
0
    }
367
0
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
368
0
        let values = values[0].as_primitive::<T>();
369
0
        self.count -= (values.len() - values.null_count()) as u64;
370
0
        if let Some(x) = sum(values) {
371
0
            self.sum = Some(self.sum.unwrap().sub_wrapping(x));
372
0
        }
373
0
        Ok(())
374
0
    }
375
376
0
    fn supports_retract_batch(&self) -> bool {
377
0
        true
378
0
    }
379
}
380
381
/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
382
/// Stores values as native types, and does overflow checking
383
///
384
/// F: Function that calculates the average value from a sum of
385
/// T::Native and a total count
386
#[derive(Debug)]
387
struct AvgGroupsAccumulator<T, F>
388
where
389
    T: ArrowNumericType + Send,
390
    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
391
{
392
    /// The type of the internal sum
393
    sum_data_type: DataType,
394
395
    /// The type of the returned sum
396
    return_data_type: DataType,
397
398
    /// Count per group (use u64 to make UInt64Array)
399
    counts: Vec<u64>,
400
401
    /// Sums per group, stored as the native type
402
    sums: Vec<T::Native>,
403
404
    /// Track nulls in the input / filters
405
    null_state: NullState,
406
407
    /// Function that computes the final average (value / count)
408
    avg_fn: F,
409
}
410
411
impl<T, F> AvgGroupsAccumulator<T, F>
412
where
413
    T: ArrowNumericType + Send,
414
    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
415
{
416
14
    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
417
14
        debug!(
418
0
            "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}",
419
0
            std::any::type_name::<T>()
420
        );
421
422
14
        Self {
423
14
            return_data_type: return_data_type.clone(),
424
14
            sum_data_type: sum_data_type.clone(),
425
14
            counts: vec![],
426
14
            sums: vec![],
427
14
            null_state: NullState::new(),
428
14
            avg_fn,
429
14
        }
430
14
    }
431
}
432
433
impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
434
where
435
    T: ArrowNumericType + Send,
436
    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
437
{
438
17
    fn update_batch(
439
17
        &mut self,
440
17
        values: &[ArrayRef],
441
17
        group_indices: &[usize],
442
17
        opt_filter: Option<&array::BooleanArray>,
443
17
        total_num_groups: usize,
444
17
    ) -> Result<()> {
445
17
        assert_eq!(values.len(), 1, 
"single argument to update_batch"0
);
446
17
        let values = values[0].as_primitive::<T>();
447
17
448
17
        // increment counts, update sums
449
17
        self.counts.resize(total_num_groups, 0);
450
17
        self.sums.resize(total_num_groups, T::default_value());
451
17
        self.null_state.accumulate(
452
17
            group_indices,
453
17
            values,
454
17
            opt_filter,
455
17
            total_num_groups,
456
68
            |group_index, new_value| {
457
68
                let sum = &mut self.sums[group_index];
458
68
                *sum = sum.add_wrapping(new_value);
459
68
460
68
                self.counts[group_index] += 1;
461
68
            },
462
17
        );
463
17
464
17
        Ok(())
465
17
    }
466
467
8
    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
468
8
        let counts = emit_to.take_needed(&mut self.counts);
469
8
        let sums = emit_to.take_needed(&mut self.sums);
470
8
        let nulls = self.null_state.build(emit_to);
471
8
472
8
        assert_eq!(nulls.len(), sums.len());
473
8
        assert_eq!(counts.len(), sums.len());
474
475
        // don't evaluate averages with null inputs to avoid errors on null values
476
477
8
        let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
478
0
            let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
479
0
                .with_data_type(self.return_data_type.clone());
480
0
            let iter = sums.into_iter().zip(counts).zip(nulls.iter());
481
482
0
            for ((sum, count), is_valid) in iter {
483
0
                if is_valid {
484
0
                    builder.append_value((self.avg_fn)(sum, count)?)
485
0
                } else {
486
0
                    builder.append_null();
487
0
                }
488
            }
489
0
            builder.finish()
490
        } else {
491
8
            let averages: Vec<T::Native> = sums
492
8
                .into_iter()
493
8
                .zip(counts.into_iter())
494
12
                .map(|(sum, count)| (self.avg_fn)(sum, count))
495
8
                .collect::<Result<Vec<_>>>()
?0
;
496
8
            PrimitiveArray::new(averages.into(), Some(nulls)) // no copy
497
8
                .with_data_type(self.return_data_type.clone())
498
        };
499
500
8
        Ok(Arc::new(array))
501
8
    }
502
503
    // return arrays for sums and counts
504
20
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
505
20
        let nulls = self.null_state.build(emit_to);
506
20
        let nulls = Some(nulls);
507
20
508
20
        let counts = emit_to.take_needed(&mut self.counts);
509
20
        let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy
510
20
511
20
        let sums = emit_to.take_needed(&mut self.sums);
512
20
        let sums = PrimitiveArray::<T>::new(sums.into(), nulls) // zero copy
513
20
            .with_data_type(self.sum_data_type.clone());
514
20
515
20
        Ok(vec![
516
20
            Arc::new(counts) as ArrayRef,
517
20
            Arc::new(sums) as ArrayRef,
518
20
        ])
519
20
    }
520
521
14
    fn merge_batch(
522
14
        &mut self,
523
14
        values: &[ArrayRef],
524
14
        group_indices: &[usize],
525
14
        opt_filter: Option<&array::BooleanArray>,
526
14
        total_num_groups: usize,
527
14
    ) -> Result<()> {
528
14
        assert_eq!(values.len(), 2, 
"two arguments to merge_batch"0
);
529
        // first batch is counts, second is partial sums
530
14
        let partial_counts = values[0].as_primitive::<UInt64Type>();
531
14
        let partial_sums = values[1].as_primitive::<T>();
532
14
        // update counts with partial counts
533
14
        self.counts.resize(total_num_groups, 0);
534
14
        self.null_state.accumulate(
535
14
            group_indices,
536
14
            partial_counts,
537
14
            opt_filter,
538
14
            total_num_groups,
539
26
            |group_index, partial_count| {
540
26
                self.counts[group_index] += partial_count;
541
26
            },
542
14
        );
543
14
544
14
        // update sums
545
14
        self.sums.resize(total_num_groups, T::default_value());
546
14
        self.null_state.accumulate(
547
14
            group_indices,
548
14
            partial_sums,
549
14
            opt_filter,
550
14
            total_num_groups,
551
26
            |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
552
26
                let sum = &mut self.sums[group_index];
553
26
                *sum = sum.add_wrapping(new_value);
554
26
            },
555
14
        );
556
14
557
14
        Ok(())
558
14
    }
559
560
0
    fn convert_to_state(
561
0
        &self,
562
0
        values: &[ArrayRef],
563
0
        opt_filter: Option<&BooleanArray>,
564
0
    ) -> Result<Vec<ArrayRef>> {
565
0
        let sums = values[0]
566
0
            .as_primitive::<T>()
567
0
            .clone()
568
0
            .with_data_type(self.sum_data_type.clone());
569
0
        let counts = UInt64Array::from_value(1, sums.len());
570
0
571
0
        let nulls = filtered_null_mask(opt_filter, &sums);
572
0
573
0
        // set nulls on the arrays
574
0
        let counts = set_nulls(counts, nulls.clone());
575
0
        let sums = set_nulls(sums, nulls);
576
0
577
0
        Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
578
0
    }
579
580
10
    fn supports_convert_to_state(&self) -> bool {
581
10
        true
582
10
    }
583
584
85
    fn size(&self) -> usize {
585
85
        self.counts.capacity() * std::mem::size_of::<u64>()
586
85
            + self.sums.capacity() * std::mem::size_of::<T>()
587
85
    }
588
}