Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/stddev.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
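//! Defines physical expressions that can be evaluated at runtime during query execution:
//! the `stddev` / `stddev_samp` (sample) and `stddev_pop` (population) standard deviation
//! aggregate functions.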
19
20
use std::any::Any;
21
use std::fmt::{Debug, Formatter};
22
use std::sync::Arc;
23
24
use arrow::array::Float64Array;
25
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
26
27
use datafusion_common::{internal_err, not_impl_err, Result};
28
use datafusion_common::{plan_err, ScalarValue};
29
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
30
use datafusion_expr::utils::format_state_name;
31
use datafusion_expr::{
32
    Accumulator, AggregateUDFImpl, GroupsAccumulator, Signature, Volatility,
33
};
34
use datafusion_functions_aggregate_common::stats::StatsType;
35
36
use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};
37
38
// Declares the `stddev(expression)` expression helper and the `stddev_udaf`
// accessor for [`Stddev`]; the exact expansion lives in `make_udaf_expr_and_func!`.
make_udaf_expr_and_func!(
    Stddev,
    stddev,
    expression,
    "Compute the standard deviation of a set of numbers",
    stddev_udaf
);
45
46
/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression
47
pub struct Stddev {
48
    signature: Signature,
49
    alias: Vec<String>,
50
}
51
52
impl Debug for Stddev {
53
0
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54
0
        f.debug_struct("Stddev")
55
0
            .field("name", &self.name())
56
0
            .field("signature", &self.signature)
57
0
            .finish()
58
0
    }
59
}
60
61
impl Default for Stddev {
62
0
    fn default() -> Self {
63
0
        Self::new()
64
0
    }
65
}
66
67
impl Stddev {
68
    /// Create a new STDDEV aggregate function
69
0
    pub fn new() -> Self {
70
0
        Self {
71
0
            signature: Signature::coercible(
72
0
                vec![DataType::Float64],
73
0
                Volatility::Immutable,
74
0
            ),
75
0
            alias: vec!["stddev_samp".to_string()],
76
0
        }
77
0
    }
78
}
79
80
impl AggregateUDFImpl for Stddev {
81
    /// Return a reference to Any that can be used for downcasting
82
0
    fn as_any(&self) -> &dyn Any {
83
0
        self
84
0
    }
85
86
0
    fn name(&self) -> &str {
87
0
        "stddev"
88
0
    }
89
90
0
    fn signature(&self) -> &Signature {
91
0
        &self.signature
92
0
    }
93
94
0
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
95
0
        Ok(DataType::Float64)
96
0
    }
97
98
0
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
99
0
        Ok(vec![
100
0
            Field::new(
101
0
                format_state_name(args.name, "count"),
102
0
                DataType::UInt64,
103
0
                true,
104
0
            ),
105
0
            Field::new(
106
0
                format_state_name(args.name, "mean"),
107
0
                DataType::Float64,
108
0
                true,
109
0
            ),
110
0
            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
111
0
        ])
112
0
    }
113
114
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
115
0
        if acc_args.is_distinct {
116
0
            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
117
0
        }
118
0
        Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
119
0
    }
120
121
0
    fn aliases(&self) -> &[String] {
122
0
        &self.alias
123
0
    }
124
125
0
    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
126
0
        !acc_args.is_distinct
127
0
    }
128
129
0
    fn create_groups_accumulator(
130
0
        &self,
131
0
        _args: AccumulatorArgs,
132
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
133
0
        Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
134
0
    }
135
}
136
137
// Declares the `stddev_pop(expression)` expression helper and the
// `stddev_pop_udaf` accessor for [`StddevPop`]; see `make_udaf_expr_and_func!`.
make_udaf_expr_and_func!(
    StddevPop,
    stddev_pop,
    expression,
    "Compute the population standard deviation of a set of numbers",
    stddev_pop_udaf
);
144
145
/// STDDEV_POP population aggregate expression
146
pub struct StddevPop {
147
    signature: Signature,
148
}
149
150
impl Debug for StddevPop {
151
0
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
152
0
        f.debug_struct("StddevPop")
153
0
            .field("name", &self.name())
154
0
            .field("signature", &self.signature)
155
0
            .finish()
156
0
    }
157
}
158
159
impl Default for StddevPop {
160
0
    fn default() -> Self {
161
0
        Self::new()
162
0
    }
163
}
164
165
impl StddevPop {
166
    /// Create a new STDDEV_POP aggregate function
167
0
    pub fn new() -> Self {
168
0
        Self {
169
0
            signature: Signature::numeric(1, Volatility::Immutable),
170
0
        }
171
0
    }
172
}
173
174
impl AggregateUDFImpl for StddevPop {
175
    /// Return a reference to Any that can be used for downcasting
176
0
    fn as_any(&self) -> &dyn Any {
177
0
        self
178
0
    }
179
180
0
    fn name(&self) -> &str {
181
0
        "stddev_pop"
182
0
    }
183
184
0
    fn signature(&self) -> &Signature {
185
0
        &self.signature
186
0
    }
187
188
0
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
189
0
        Ok(vec![
190
0
            Field::new(
191
0
                format_state_name(args.name, "count"),
192
0
                DataType::UInt64,
193
0
                true,
194
0
            ),
195
0
            Field::new(
196
0
                format_state_name(args.name, "mean"),
197
0
                DataType::Float64,
198
0
                true,
199
0
            ),
200
0
            Field::new(format_state_name(args.name, "m2"), DataType::Float64, true),
201
0
        ])
202
0
    }
203
204
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
205
0
        if acc_args.is_distinct {
206
0
            return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available");
207
0
        }
208
0
        Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
209
0
    }
210
211
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
212
0
        if !arg_types[0].is_numeric() {
213
0
            return plan_err!("StddevPop requires numeric input types");
214
0
        }
215
0
216
0
        Ok(DataType::Float64)
217
0
    }
218
219
0
    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
220
0
        !acc_args.is_distinct
221
0
    }
222
223
0
    fn create_groups_accumulator(
224
0
        &self,
225
0
        _args: AccumulatorArgs,
226
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
227
0
        Ok(Box::new(StddevGroupsAccumulator::new(
228
0
            StatsType::Population,
229
0
        )))
230
0
    }
231
}
232
233
/// An accumulator to compute the average
234
#[derive(Debug)]
235
pub struct StddevAccumulator {
236
    variance: VarianceAccumulator,
237
}
238
239
impl StddevAccumulator {
240
    /// Creates a new `StddevAccumulator`
241
0
    pub fn try_new(s_type: StatsType) -> Result<Self> {
242
0
        Ok(Self {
243
0
            variance: VarianceAccumulator::try_new(s_type)?,
244
        })
245
0
    }
246
247
0
    pub fn get_m2(&self) -> f64 {
248
0
        self.variance.get_m2()
249
0
    }
250
}
251
252
impl Accumulator for StddevAccumulator {
253
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
254
0
        Ok(vec![
255
0
            ScalarValue::from(self.variance.get_count()),
256
0
            ScalarValue::from(self.variance.get_mean()),
257
0
            ScalarValue::from(self.variance.get_m2()),
258
0
        ])
259
0
    }
260
261
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
262
0
        self.variance.update_batch(values)
263
0
    }
264
265
0
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
266
0
        self.variance.retract_batch(values)
267
0
    }
268
269
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
270
0
        self.variance.merge_batch(states)
271
0
    }
272
273
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
274
0
        let variance = self.variance.evaluate()?;
275
0
        match variance {
276
0
            ScalarValue::Float64(e) => {
277
0
                if e.is_none() {
278
0
                    Ok(ScalarValue::Float64(None))
279
                } else {
280
0
                    Ok(ScalarValue::Float64(e.map(|f| f.sqrt())))
281
                }
282
            }
283
0
            _ => internal_err!("Variance should be f64"),
284
        }
285
0
    }
286
287
0
    fn size(&self) -> usize {
288
0
        std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance)
289
0
            + self.variance.size()
290
0
    }
291
292
0
    fn supports_retract_batch(&self) -> bool {
293
0
        self.variance.supports_retract_batch()
294
0
    }
295
}
296
297
#[derive(Debug)]
298
pub struct StddevGroupsAccumulator {
299
    variance: VarianceGroupsAccumulator,
300
}
301
302
impl StddevGroupsAccumulator {
303
0
    pub fn new(s_type: StatsType) -> Self {
304
0
        Self {
305
0
            variance: VarianceGroupsAccumulator::new(s_type),
306
0
        }
307
0
    }
308
}
309
310
impl GroupsAccumulator for StddevGroupsAccumulator {
311
0
    fn update_batch(
312
0
        &mut self,
313
0
        values: &[ArrayRef],
314
0
        group_indices: &[usize],
315
0
        opt_filter: Option<&arrow::array::BooleanArray>,
316
0
        total_num_groups: usize,
317
0
    ) -> Result<()> {
318
0
        self.variance
319
0
            .update_batch(values, group_indices, opt_filter, total_num_groups)
320
0
    }
321
322
0
    fn merge_batch(
323
0
        &mut self,
324
0
        values: &[ArrayRef],
325
0
        group_indices: &[usize],
326
0
        opt_filter: Option<&arrow::array::BooleanArray>,
327
0
        total_num_groups: usize,
328
0
    ) -> Result<()> {
329
0
        self.variance
330
0
            .merge_batch(values, group_indices, opt_filter, total_num_groups)
331
0
    }
332
333
0
    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
334
0
        let (mut variances, nulls) = self.variance.variance(emit_to);
335
0
        variances.iter_mut().for_each(|v| *v = v.sqrt());
336
0
        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
337
0
    }
338
339
0
    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
340
0
        self.variance.state(emit_to)
341
0
    }
342
343
0
    fn size(&self) -> usize {
344
0
        self.variance.size()
345
0
    }
346
}
347
348
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::*, datatypes::*};
    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays;
    use datafusion_physical_expr::expressions::col;
    use std::sync::Arc;

    #[test]
    fn stddev_f64_merge_1() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
        let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_pop_udaf();
        let agg2 = stddev_pop_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    #[test]
    fn stddev_f64_merge_2() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let b = Arc::new(Float64Array::from(vec![None]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_pop_udaf();
        let agg2 = stddev_pop_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2));

        Ok(())
    }

    /// Covers the sample (`stddev` / `stddev_samp`) path, which the
    /// population-only tests above never exercise.
    #[test]
    fn stddev_f64_sample() -> Result<()> {
        let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
        let b = Arc::new(Float64Array::from(vec![None]));

        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
        let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?;

        let agg1 = stddev_udaf();
        let agg2 = stddev_udaf();

        let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?;
        // Sum of squared deviations of 1..=5 is 10.
        // Sample variance is therefore 10 / (5 - 1) = 2.5, exact in binary.
        assert_eq!(actual, ScalarValue::from(2.5_f64.sqrt()));

        Ok(())
    }

    /// Feeds `batch1` into `agg1` and `batch2` into `agg2`.
    /// Then merges `agg2`'s state into `agg1` and returns `agg1`'s final value.
    fn merge(
        batch1: &RecordBatch,
        batch2: &RecordBatch,
        agg1: Arc<AggregateUDF>,
        agg2: Arc<AggregateUDF>,
        schema: &Schema,
    ) -> Result<ScalarValue> {
        let args1 = AccumulatorArgs {
            return_type: &DataType::Float64,
            schema,
            ignore_nulls: false,
            ordering_req: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[col("a", schema)?],
        };

        let args2 = AccumulatorArgs {
            return_type: &DataType::Float64,
            schema,
            ignore_nulls: false,
            ordering_req: &[],
            name: "a",
            is_distinct: false,
            is_reversed: false,
            exprs: &[col("a", schema)?],
        };

        let mut accum1 = agg1.accumulator(args1)?;
        let mut accum2 = agg2.accumulator(args2)?;

        let value1 = vec![col("a", schema)?
            .evaluate(batch1)
            .and_then(|v| v.into_array(batch1.num_rows()))?];
        let value2 = vec![col("a", schema)?
            .evaluate(batch2)
            .and_then(|v| v.into_array(batch2.num_rows()))?];

        accum1.update_batch(&value1)?;
        accum2.update_batch(&value2)?;
        let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?;
        accum1.merge_batch(&state2)?;
        let result = accum1.evaluate()?;
        Ok(result)
    }
}