Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/variance.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! [`VarianceSample`]: variance sample aggregations.
19
//! [`VariancePopulation`]: variance population aggregations.
20
21
use arrow::{
22
    array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
23
    buffer::NullBuffer,
24
    compute::kernels::cast,
25
    datatypes::{DataType, Field},
26
};
27
use std::sync::OnceLock;
28
use std::{fmt::Debug, sync::Arc};
29
30
use datafusion_common::{
31
    downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue,
32
};
33
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
34
use datafusion_expr::{
35
    function::{AccumulatorArgs, StateFieldsArgs},
36
    utils::format_state_name,
37
    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
38
    Volatility,
39
};
40
use datafusion_functions_aggregate_common::{
41
    aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
42
};
43
44
// Registers the sample-variance aggregate through `make_udaf_expr_and_func!`.
// NOTE(review): the macro definition is not visible in this file. Judging by
// the arguments it presumably generates an expression helper named
// `var_sample` (taking one `expression` argument, documented with the string
// below) and an accessor named `var_samp_udaf` for `VarianceSample` — confirm
// against the macro definition.
make_udaf_expr_and_func!(
45
    VarianceSample,
46
    var_sample,
47
    expression,
48
    "Computes the sample variance.",
49
    var_samp_udaf
50
);
51
52
// Registers the population-variance aggregate through `make_udaf_expr_and_func!`.
// NOTE(review): the macro definition is not visible in this file. Judging by
// the arguments it presumably generates an expression helper named `var_pop`
// (taking one `expression` argument, documented with the string below) and an
// accessor named `var_pop_udaf` for `VariancePopulation` — confirm against the
// macro definition.
make_udaf_expr_and_func!(
53
    VariancePopulation,
54
    var_pop,
55
    expression,
56
    "Computes the population variance.",
57
    var_pop_udaf
58
);
59
60
/// Aggregate UDF implementation for the sample variance.
///
/// Holds the accepted argument signature and the alternative names under
/// which the function is exposed.
pub struct VarianceSample {
61
    /// Argument signature returned by `signature()`.
    signature: Signature,
62
    /// Alternative names returned by `aliases()`.
    aliases: Vec<String>,
63
}
64
65
impl Debug for VarianceSample {
66
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
67
0
        f.debug_struct("VarianceSample")
68
0
            .field("name", &self.name())
69
0
            .field("signature", &self.signature)
70
0
            .finish()
71
0
    }
72
}
73
74
impl Default for VarianceSample {
75
0
    fn default() -> Self {
76
0
        Self::new()
77
0
    }
78
}
79
80
impl VarianceSample {
81
0
    pub fn new() -> Self {
82
0
        Self {
83
0
            aliases: vec![String::from("var_sample"), String::from("var_samp")],
84
0
            signature: Signature::coercible(
85
0
                vec![DataType::Float64],
86
0
                Volatility::Immutable,
87
0
            ),
88
0
        }
89
0
    }
90
}
91
92
impl AggregateUDFImpl for VarianceSample {
93
0
    fn as_any(&self) -> &dyn std::any::Any {
94
0
        self
95
0
    }
96
97
0
    fn name(&self) -> &str {
98
0
        "var"
99
0
    }
100
101
0
    fn signature(&self) -> &Signature {
102
0
        &self.signature
103
0
    }
104
105
0
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
106
0
        Ok(DataType::Float64)
107
0
    }
108
109
0
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
110
0
        let name = args.name;
111
0
        Ok(vec![
112
0
            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
113
0
            Field::new(format_state_name(name, "mean"), DataType::Float64, true),
114
0
            Field::new(format_state_name(name, "m2"), DataType::Float64, true),
115
0
        ])
116
0
    }
117
118
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
119
0
        if acc_args.is_distinct {
120
0
            return not_impl_err!("VAR(DISTINCT) aggregations are not available");
121
0
        }
122
0
123
0
        Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
124
0
    }
125
126
0
    fn aliases(&self) -> &[String] {
127
0
        &self.aliases
128
0
    }
129
130
0
    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
131
0
        !acc_args.is_distinct
132
0
    }
133
134
0
    fn create_groups_accumulator(
135
0
        &self,
136
0
        _args: AccumulatorArgs,
137
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
138
0
        Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample)))
139
0
    }
140
141
0
    fn documentation(&self) -> Option<&Documentation> {
142
0
        Some(get_variance_sample_doc())
143
0
    }
144
}
145
146
static VARIANCE_SAMPLE_DOC: OnceLock<Documentation> = OnceLock::new();
147
148
0
fn get_variance_sample_doc() -> &'static Documentation {
149
0
    VARIANCE_SAMPLE_DOC.get_or_init(|| {
150
0
        Documentation::builder()
151
0
            .with_doc_section(DOC_SECTION_GENERAL)
152
0
            .with_description(
153
0
                "Returns the statistical sample variance of a set of numbers.",
154
0
            )
155
0
            .with_syntax_example("var(expression)")
156
0
            .with_standard_argument("expression", "Numeric")
157
0
            .build()
158
0
            .unwrap()
159
0
    })
160
0
}
161
162
/// Aggregate UDF implementation for the population variance.
///
/// Holds the accepted argument signature and the alternative names under
/// which the function is exposed.
pub struct VariancePopulation {
163
    /// Argument signature returned by `signature()`.
    signature: Signature,
164
    /// Alternative names returned by `aliases()`.
    aliases: Vec<String>,
165
}
166
167
impl Debug for VariancePopulation {
168
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
169
0
        f.debug_struct("VariancePopulation")
170
0
            .field("name", &self.name())
171
0
            .field("signature", &self.signature)
172
0
            .finish()
173
0
    }
174
}
175
176
impl Default for VariancePopulation {
177
0
    fn default() -> Self {
178
0
        Self::new()
179
0
    }
180
}
181
182
impl VariancePopulation {
183
0
    pub fn new() -> Self {
184
0
        Self {
185
0
            aliases: vec![String::from("var_population")],
186
0
            signature: Signature::numeric(1, Volatility::Immutable),
187
0
        }
188
0
    }
189
}
190
191
impl AggregateUDFImpl for VariancePopulation {
192
0
    fn as_any(&self) -> &dyn std::any::Any {
193
0
        self
194
0
    }
195
196
0
    fn name(&self) -> &str {
197
0
        "var_pop"
198
0
    }
199
200
0
    fn signature(&self) -> &Signature {
201
0
        &self.signature
202
0
    }
203
204
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
205
0
        if !arg_types[0].is_numeric() {
206
0
            return plan_err!("Variance requires numeric input types");
207
0
        }
208
0
209
0
        Ok(DataType::Float64)
210
0
    }
211
212
0
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
213
0
        let name = args.name;
214
0
        Ok(vec![
215
0
            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
216
0
            Field::new(format_state_name(name, "mean"), DataType::Float64, true),
217
0
            Field::new(format_state_name(name, "m2"), DataType::Float64, true),
218
0
        ])
219
0
    }
220
221
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
222
0
        if acc_args.is_distinct {
223
0
            return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
224
0
        }
225
0
226
0
        Ok(Box::new(VarianceAccumulator::try_new(
227
0
            StatsType::Population,
228
0
        )?))
229
0
    }
230
231
0
    fn aliases(&self) -> &[String] {
232
0
        &self.aliases
233
0
    }
234
235
0
    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
236
0
        !acc_args.is_distinct
237
0
    }
238
239
0
    fn create_groups_accumulator(
240
0
        &self,
241
0
        _args: AccumulatorArgs,
242
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
243
0
        Ok(Box::new(VarianceGroupsAccumulator::new(
244
0
            StatsType::Population,
245
0
        )))
246
0
    }
247
0
    fn documentation(&self) -> Option<&Documentation> {
248
0
        Some(get_variance_population_doc())
249
0
    }
250
}
251
252
static VARIANCE_POPULATION_DOC: OnceLock<Documentation> = OnceLock::new();
253
254
0
fn get_variance_population_doc() -> &'static Documentation {
255
0
    VARIANCE_POPULATION_DOC.get_or_init(|| {
256
0
        Documentation::builder()
257
0
            .with_doc_section(DOC_SECTION_GENERAL)
258
0
            .with_description(
259
0
                "Returns the statistical population variance of a set of numbers.",
260
0
            )
261
0
            .with_syntax_example("var_pop(expression)")
262
0
            .with_standard_argument("expression", "Numeric")
263
0
            .build()
264
0
            .unwrap()
265
0
    })
266
0
}
267
268
/// An accumulator that computes the variance of a stream of values.
269
/// The algorithm is an online, numerically stable implementation based on this paper:
270
/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
271
/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
272
///
273
/// The algorithm has been analyzed here:
274
/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
275
/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
276
277
#[derive(Debug)]
278
pub struct VarianceAccumulator {
279
    /// Running sum of squared differences from the mean.
    m2: f64,
280
    /// Running mean of the values seen so far.
    mean: f64,
281
    /// Number of values seen so far.
    count: u64,
282
    /// Selects between the sample and population formulas when evaluating.
    stats_type: StatsType,
283
}
284
285
impl VarianceAccumulator {
286
    /// Creates a new `VarianceAccumulator`
287
0
    pub fn try_new(s_type: StatsType) -> Result<Self> {
288
0
        Ok(Self {
289
0
            m2: 0_f64,
290
0
            mean: 0_f64,
291
0
            count: 0_u64,
292
0
            stats_type: s_type,
293
0
        })
294
0
    }
295
296
0
    pub fn get_count(&self) -> u64 {
297
0
        self.count
298
0
    }
299
300
0
    pub fn get_mean(&self) -> f64 {
301
0
        self.mean
302
0
    }
303
304
0
    pub fn get_m2(&self) -> f64 {
305
0
        self.m2
306
0
    }
307
}
308
309
/// Combines two partial variance states into one.
///
/// Each state is a `(count, mean, m2)` triple, where `m2` is the sum of
/// squared differences from the mean. The first state is given by `count`,
/// `mean`, `m2` and the second by `count2`, `mean2`, `m22`.
///
/// Returns the `(count, mean, m2)` triple of the combined state.
///
/// If one of the states is empty (its count is zero) the other state is
/// returned unchanged. Without this guard, merging two empty states divides
/// by a zero total count and yields `NaN` for the mean and m2, which then
/// contaminates every later merge into the same state.
#[inline]
fn merge(
    count: u64,
    mean: f64,
    m2: f64,
    count2: u64,
    mean2: f64,
    m22: f64,
) -> (u64, f64, f64) {
    // An empty right-hand side contributes nothing (this also covers the
    // case where both sides are empty).
    if count2 == 0 {
        return (count, mean, m2);
    }
    // An empty left-hand side is simply replaced by the right-hand side.
    if count == 0 {
        return (count2, mean2, m22);
    }

    let new_count = count + count2;
    // Weighted average of the two means.
    let new_mean =
        mean * count as f64 / new_count as f64 + mean2 * count2 as f64 / new_count as f64;
    let delta = mean - mean2;
    // Sum of both m2 values plus a correction for the distance between the means.
    let new_m2 =
        m2 + m22 + delta * delta * count as f64 * count2 as f64 / new_count as f64;

    (new_count, new_mean, new_m2)
}
327
328
/// Folds a single `value` into a running variance state.
///
/// The state is a `(count, mean, m2)` triple, where `m2` is the sum of
/// squared differences from the mean. Returns the updated triple.
#[inline]
fn update(count: u64, mean: f64, m2: f64, value: f64) -> (u64, f64, f64) {
    let n = count + 1;
    // Distance of the value from the mean before the update ...
    let dist_before = value - mean;
    let next_mean = mean + dist_before / n as f64;
    // ... and from the mean after the update.
    let dist_after = value - next_mean;

    (n, next_mean, m2 + dist_before * dist_after)
}
338
339
impl Accumulator for VarianceAccumulator {
340
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
341
0
        Ok(vec![
342
0
            ScalarValue::from(self.count),
343
0
            ScalarValue::from(self.mean),
344
0
            ScalarValue::from(self.m2),
345
0
        ])
346
0
    }
347
348
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
349
0
        let values = &cast(&values[0], &DataType::Float64)?;
350
0
        let arr = downcast_value!(values, Float64Array).iter().flatten();
351
352
0
        for value in arr {
353
0
            (self.count, self.mean, self.m2) =
354
0
                update(self.count, self.mean, self.m2, value)
355
        }
356
357
0
        Ok(())
358
0
    }
359
360
0
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
361
0
        let values = &cast(&values[0], &DataType::Float64)?;
362
0
        let arr = downcast_value!(values, Float64Array).iter().flatten();
363
364
0
        for value in arr {
365
0
            let new_count = self.count - 1;
366
0
            let delta1 = self.mean - value;
367
0
            let new_mean = delta1 / new_count as f64 + self.mean;
368
0
            let delta2 = new_mean - value;
369
0
            let new_m2 = self.m2 - delta1 * delta2;
370
0
371
0
            self.count -= 1;
372
0
            self.mean = new_mean;
373
0
            self.m2 = new_m2;
374
0
        }
375
376
0
        Ok(())
377
0
    }
378
379
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
380
0
        let counts = downcast_value!(states[0], UInt64Array);
381
0
        let means = downcast_value!(states[1], Float64Array);
382
0
        let m2s = downcast_value!(states[2], Float64Array);
383
384
0
        for i in 0..counts.len() {
385
0
            let c = counts.value(i);
386
0
            if c == 0_u64 {
387
0
                continue;
388
0
            }
389
0
            (self.count, self.mean, self.m2) = merge(
390
0
                self.count,
391
0
                self.mean,
392
0
                self.m2,
393
0
                c,
394
0
                means.value(i),
395
0
                m2s.value(i),
396
0
            )
397
        }
398
0
        Ok(())
399
0
    }
400
401
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
402
0
        let count = match self.stats_type {
403
0
            StatsType::Population => self.count,
404
            StatsType::Sample => {
405
0
                if self.count > 0 {
406
0
                    self.count - 1
407
                } else {
408
0
                    self.count
409
                }
410
            }
411
        };
412
413
0
        Ok(ScalarValue::Float64(match self.count {
414
0
            0 => None,
415
            1 => {
416
0
                if let StatsType::Population = self.stats_type {
417
0
                    Some(0.0)
418
                } else {
419
0
                    None
420
                }
421
            }
422
0
            _ => Some(self.m2 / count as f64),
423
        }))
424
0
    }
425
426
0
    fn size(&self) -> usize {
427
0
        std::mem::size_of_val(self)
428
0
    }
429
430
0
    fn supports_retract_batch(&self) -> bool {
431
0
        true
432
0
    }
433
}
434
435
/// Variance accumulator that tracks the state of many groups at once.
///
/// The three vectors are indexed by group index and are resized together.
#[derive(Debug)]
436
pub struct VarianceGroupsAccumulator {
437
    /// Per-group sum of squared differences from the mean.
    m2s: Vec<f64>,
438
    /// Per-group running mean.
    means: Vec<f64>,
439
    /// Per-group number of accumulated values.
    counts: Vec<u64>,
440
    /// Selects between the sample and population formulas when evaluating.
    stats_type: StatsType,
441
}
442
443
impl VarianceGroupsAccumulator {
444
0
    pub fn new(s_type: StatsType) -> Self {
445
0
        Self {
446
0
            m2s: Vec::new(),
447
0
            means: Vec::new(),
448
0
            counts: Vec::new(),
449
0
            stats_type: s_type,
450
0
        }
451
0
    }
452
453
0
    fn resize(&mut self, total_num_groups: usize) {
454
0
        self.m2s.resize(total_num_groups, 0.0);
455
0
        self.means.resize(total_num_groups, 0.0);
456
0
        self.counts.resize(total_num_groups, 0);
457
0
    }
458
459
0
    fn merge<F>(
460
0
        group_indices: &[usize],
461
0
        counts: &UInt64Array,
462
0
        means: &Float64Array,
463
0
        m2s: &Float64Array,
464
0
        opt_filter: Option<&BooleanArray>,
465
0
        mut value_fn: F,
466
0
    ) where
467
0
        F: FnMut(usize, u64, f64, f64) + Send,
468
0
    {
469
0
        assert_eq!(counts.null_count(), 0);
470
0
        assert_eq!(means.null_count(), 0);
471
0
        assert_eq!(m2s.null_count(), 0);
472
473
0
        match opt_filter {
474
0
            None => {
475
0
                group_indices
476
0
                    .iter()
477
0
                    .zip(counts.values().iter())
478
0
                    .zip(means.values().iter())
479
0
                    .zip(m2s.values().iter())
480
0
                    .for_each(|(((&group_index, &count), &mean), &m2)| {
481
0
                        value_fn(group_index, count, mean, m2);
482
0
                    });
483
0
            }
484
0
            Some(filter) => {
485
0
                group_indices
486
0
                    .iter()
487
0
                    .zip(counts.values().iter())
488
0
                    .zip(means.values().iter())
489
0
                    .zip(m2s.values().iter())
490
0
                    .zip(filter.iter())
491
0
                    .for_each(
492
0
                        |((((&group_index, &count), &mean), &m2), filter_value)| {
493
0
                            if let Some(true) = filter_value {
494
0
                                value_fn(group_index, count, mean, m2);
495
0
                            }
496
0
                        },
497
0
                    );
498
0
            }
499
        }
500
0
    }
501
502
0
    pub fn variance(
503
0
        &mut self,
504
0
        emit_to: datafusion_expr::EmitTo,
505
0
    ) -> (Vec<f64>, NullBuffer) {
506
0
        let mut counts = emit_to.take_needed(&mut self.counts);
507
0
        // means are only needed for updating m2s and are not needed for the final result.
508
0
        // But we still need to take them to ensure the internal state is consistent.
509
0
        let _ = emit_to.take_needed(&mut self.means);
510
0
        let m2s = emit_to.take_needed(&mut self.m2s);
511
0
512
0
        if let StatsType::Sample = self.stats_type {
513
0
            counts.iter_mut().for_each(|count| {
514
0
                *count = count.saturating_sub(1);
515
0
            });
516
0
        }
517
0
        let nulls = NullBuffer::from_iter(counts.iter().map(|&count| count != 0));
518
0
        let variance = m2s
519
0
            .iter()
520
0
            .zip(counts)
521
0
            .map(|(m2, count)| m2 / count as f64)
522
0
            .collect();
523
0
        (variance, nulls)
524
0
    }
525
}
526
527
impl GroupsAccumulator for VarianceGroupsAccumulator {
528
0
    fn update_batch(
529
0
        &mut self,
530
0
        values: &[ArrayRef],
531
0
        group_indices: &[usize],
532
0
        opt_filter: Option<&arrow::array::BooleanArray>,
533
0
        total_num_groups: usize,
534
0
    ) -> Result<()> {
535
0
        assert_eq!(values.len(), 1, "single argument to update_batch");
536
0
        let values = &cast(&values[0], &DataType::Float64)?;
537
0
        let values = downcast_value!(values, Float64Array);
538
539
0
        self.resize(total_num_groups);
540
0
        accumulate(group_indices, values, opt_filter, |group_index, value| {
541
0
            let (new_count, new_mean, new_m2) = update(
542
0
                self.counts[group_index],
543
0
                self.means[group_index],
544
0
                self.m2s[group_index],
545
0
                value,
546
0
            );
547
0
            self.counts[group_index] = new_count;
548
0
            self.means[group_index] = new_mean;
549
0
            self.m2s[group_index] = new_m2;
550
0
        });
551
0
        Ok(())
552
0
    }
553
554
0
    fn merge_batch(
555
0
        &mut self,
556
0
        values: &[ArrayRef],
557
0
        group_indices: &[usize],
558
0
        opt_filter: Option<&arrow::array::BooleanArray>,
559
0
        total_num_groups: usize,
560
0
    ) -> Result<()> {
561
0
        assert_eq!(values.len(), 3, "two arguments to merge_batch");
562
        // first batch is counts, second is partial means, third is partial m2s
563
0
        let partial_counts = downcast_value!(values[0], UInt64Array);
564
0
        let partial_means = downcast_value!(values[1], Float64Array);
565
0
        let partial_m2s = downcast_value!(values[2], Float64Array);
566
567
0
        self.resize(total_num_groups);
568
0
        Self::merge(
569
0
            group_indices,
570
0
            partial_counts,
571
0
            partial_means,
572
0
            partial_m2s,
573
0
            opt_filter,
574
0
            |group_index, partial_count, partial_mean, partial_m2| {
575
0
                let (new_count, new_mean, new_m2) = merge(
576
0
                    self.counts[group_index],
577
0
                    self.means[group_index],
578
0
                    self.m2s[group_index],
579
0
                    partial_count,
580
0
                    partial_mean,
581
0
                    partial_m2,
582
0
                );
583
0
                self.counts[group_index] = new_count;
584
0
                self.means[group_index] = new_mean;
585
0
                self.m2s[group_index] = new_m2;
586
0
            },
587
0
        );
588
0
        Ok(())
589
0
    }
590
591
0
    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
592
0
        let (variances, nulls) = self.variance(emit_to);
593
0
        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
594
0
    }
595
596
0
    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
597
0
        let counts = emit_to.take_needed(&mut self.counts);
598
0
        let means = emit_to.take_needed(&mut self.means);
599
0
        let m2s = emit_to.take_needed(&mut self.m2s);
600
0
601
0
        Ok(vec![
602
0
            Arc::new(UInt64Array::new(counts.into(), None)),
603
0
            Arc::new(Float64Array::new(means.into(), None)),
604
0
            Arc::new(Float64Array::new(m2s.into(), None)),
605
0
        ])
606
0
    }
607
608
0
    fn size(&self) -> usize {
609
0
        self.m2s.capacity() * std::mem::size_of::<f64>()
610
0
            + self.means.capacity() * std::mem::size_of::<f64>()
611
0
            + self.counts.capacity() * std::mem::size_of::<u64>()
612
0
    }
613
}