Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/regr.rs
Every instrumented line in this file reports an execution count of 0 (the file is entirely uncovered).
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines physical expressions that can be evaluated at runtime during query execution

use std::any::Any;
use std::fmt::Debug;

use arrow::array::Float64Array;
use arrow::{
    array::{ArrayRef, UInt64Array},
    compute::cast,
    datatypes::DataType,
    datatypes::Field,
};
use datafusion_common::{downcast_value, plan_err, unwrap_or_internal_err, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};

macro_rules! make_regr_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => {
        make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN);
        create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN)));
    }
}

make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);
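
// Illustrative note (added commentary, not in the original file): each invocation
// above is expected to expand into an expression-builder function and a singleton
// accessor for the UDAF, so callers could build a regression expression roughly
// like the sketch below. The column names are hypothetical and the exact
// signatures depend on the `make_udaf_expr!` / `create_func!` definitions elsewhere
// in the crate.
//
//     use datafusion_expr::col;
//
//     let slope = regr_slope(col("y"), col("x")); // Expr invoking the regr_slope UDAF
//     let udaf = regr_slope_udaf();               // shared handle to the registered UDAF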

pub struct Regr {
    signature: Signature,
    regr_type: RegrType,
    func_name: &'static str,
}

impl Debug for Regr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("regr")
            .field("name", &self.name())
            .field("signature", &self.signature)
            .finish()
    }
}

impl Regr {
    pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
        Self {
            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
            regr_type,
            func_name,
        }
    }
}

/*
#[derive(Debug)]
pub struct Regr {
    name: String,
    regr_type: RegrType,
    expr_y: Arc<dyn PhysicalExpr>,
    expr_x: Arc<dyn PhysicalExpr>,
}

impl Regr {
    pub fn get_regr_type(&self) -> RegrType {
        self.regr_type.clone()
    }
}
*/

#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
pub enum RegrType {
    /// Variant for `regr_slope` aggregate expression
    /// Returns the slope of the linear regression line for non-null pairs in aggregate columns.
    /// Given input columns Y and X: `regr_slope(Y, X)` returns the slope (k in Y = k*X + b) using minimal
    /// RSS (Residual Sum of Squares) fitting.
    Slope,
    /// Variant for `regr_intercept` aggregate expression
    /// Returns the intercept of the linear regression line for non-null pairs in aggregate columns.
    /// Given input columns Y and X: `regr_intercept(Y, X)` returns the intercept (b in Y = k*X + b) using minimal
    /// RSS fitting.
    Intercept,
    /// Variant for `regr_count` aggregate expression
    /// Returns the number of input rows for which both expressions are not null.
    /// Given input columns Y and X: `regr_count(Y, X)` returns the count of non-null pairs.
    Count,
    /// Variant for `regr_r2` aggregate expression
    /// Returns the coefficient of determination (R-squared value) of the linear regression line for non-null pairs in aggregate columns.
    /// The R-squared value represents the proportion of variance in Y that is predictable from X.
    R2,
    /// Variant for `regr_avgx` aggregate expression
    /// Returns the average of the independent variable for non-null pairs in aggregate columns.
    /// Given input column X: `regr_avgx(Y, X)` returns the average of X values.
    AvgX,
    /// Variant for `regr_avgy` aggregate expression
    /// Returns the average of the dependent variable for non-null pairs in aggregate columns.
    /// Given input column Y: `regr_avgy(Y, X)` returns the average of Y values.
    AvgY,
    /// Variant for `regr_sxx` aggregate expression
    /// Returns the sum of squares of the independent variable for non-null pairs in aggregate columns.
    /// Given input column X: `regr_sxx(Y, X)` returns the sum of squares of deviations of X from its mean.
    SXX,
    /// Variant for `regr_syy` aggregate expression
    /// Returns the sum of squares of the dependent variable for non-null pairs in aggregate columns.
    /// Given input column Y: `regr_syy(Y, X)` returns the sum of squares of deviations of Y from its mean.
    SYY,
    /// Variant for `regr_sxy` aggregate expression
    /// Returns the sum of products of pairs of numbers for non-null pairs in aggregate columns.
    /// Given input columns Y and X: `regr_sxy(Y, X)` returns the sum of products of the deviations of Y and X from their respective means.
    SXY,
}
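
// Worked example (added for clarity, not in the original source): for the
// non-null pairs (x, y) = (1, 2), (2, 4), (3, 6), where y = 2 * x exactly:
//   regr_count     = 3
//   regr_avgx      = 2,  regr_avgy = 4
//   regr_sxx       = (1-2)^2 + (2-2)^2 + (3-2)^2 = 2
//   regr_syy       = (2-4)^2 + (4-4)^2 + (6-4)^2 = 8
//   regr_sxy       = (1-2)(2-4) + (2-2)(4-4) + (3-2)(6-4) = 4
//   regr_slope     = sxy / sxx = 2
//   regr_intercept = avgy - slope * avgx = 0
//   regr_r2        = sxy^2 / (sxx * syy) = 16 / 16 = 1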

impl AggregateUDFImpl for Regr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.func_name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        if !arg_types[0].is_numeric() {
            return plan_err!("Regression aggregate functions require numeric input types");
        }

        if matches!(self.regr_type, RegrType::Count) {
            Ok(DataType::UInt64)
        } else {
            Ok(DataType::Float64)
        }
    }

    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        Ok(vec![
            Field::new(
                format_state_name(args.name, "count"),
                DataType::UInt64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "mean_x"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "mean_y"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "m2_x"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "m2_y"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "algo_const"),
                DataType::Float64,
                true,
            ),
        ])
    }
}
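
// Note (added commentary, not in the original source): the six state fields
// declared above are serialized in the same order that `RegrAccumulator::state`
// returns them below (count, mean_x, mean_y, m2_x, m2_y, algo_const), and
// `merge_batch` reads them back by the same positional indices.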

/*
impl PartialEq<dyn Any> for Regr {
    fn eq(&self, other: &dyn Any) -> bool {
        down_cast_any_ref(other)
            .downcast_ref::<Self>()
            .map(|x| {
                self.name == x.name
                    && self.expr_y.eq(&x.expr_y)
                    && self.expr_x.eq(&x.expr_x)
            })
            .unwrap_or(false)
    }
}
*/

/// `RegrAccumulator` is used to compute linear regression aggregate functions
/// by maintaining statistics needed to compute them in an online fashion.
///
/// This struct uses Welford's online algorithm for calculating variance and covariance:
/// <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm>
///
/// Given the statistics, the following aggregate functions can be calculated:
///
/// - `regr_slope(y, x)`: Slope of the linear regression line, calculated as:
///   cov_pop(x, y) / var_pop(x).
///   It represents the expected change in Y for a one-unit change in X.
///
/// - `regr_intercept(y, x)`: Intercept of the linear regression line, calculated as:
///   mean_y - (regr_slope(y, x) * mean_x).
///   It represents the expected value of Y when X is 0.
///
/// - `regr_count(y, x)`: Count of the input rows where both x and y are non-null.
///
/// - `regr_r2(y, x)`: R-squared value (coefficient of determination), calculated as:
///   (cov_pop(x, y) ^ 2) / (var_pop(x) * var_pop(y)).
///   It provides a measure of how well the model's predictions match the observed data.
///
/// - `regr_avgx(y, x)`: Average of the independent variable X, calculated as: mean_x.
///
/// - `regr_avgy(y, x)`: Average of the dependent variable Y, calculated as: mean_y.
///
/// - `regr_sxx(y, x)`: Sum of squares of the independent variable X, calculated as:
///   m2_x.
///
/// - `regr_syy(y, x)`: Sum of squares of the dependent variable Y, calculated as:
///   m2_y.
///
/// - `regr_sxy(y, x)`: Sum of products of paired values, calculated as:
///   algo_const.
///
/// Here's how the statistics maintained in this struct are calculated:
/// - `cov_pop(x, y)`: algo_const / count.
/// - `var_pop(x)`: m2_x / count.
/// - `var_pop(y)`: m2_y / count.
#[derive(Debug)]
pub struct RegrAccumulator {
    count: u64,
    mean_x: f64,
    mean_y: f64,
    m2_x: f64,
    m2_y: f64,
    algo_const: f64,
    regr_type: RegrType,
}

impl RegrAccumulator {
    /// Creates a new `RegrAccumulator`
    pub fn try_new(regr_type: &RegrType) -> Result<Self> {
        Ok(Self {
            count: 0_u64,
            mean_x: 0_f64,
            mean_y: 0_f64,
            m2_x: 0_f64,
            m2_y: 0_f64,
            algo_const: 0_f64,
            regr_type: regr_type.clone(),
        })
    }
}
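
// Worked example of the Welford updates (added for clarity, not in the original
// source): feeding the pairs (x, y) = (1, 2) and (2, 4) through `update_batch`
// leaves the accumulator at
//   count = 2, mean_x = 1.5, mean_y = 3.0, m2_x = 0.5, m2_y = 2.0, algo_const = 1.0
// so cov_pop(x, y) = 1.0 / 2 = 0.5, var_pop(x) = 0.5 / 2 = 0.25, and
// regr_slope = 0.5 / 0.25 = 2.0, matching y = 2 * x.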

impl Accumulator for RegrAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean_x),
            ScalarValue::from(self.mean_y),
            ScalarValue::from(self.m2_x),
            ScalarValue::from(self.m2_y),
            ScalarValue::from(self.algo_const),
        ])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // regr_slope(Y, X) calculates k in y = k*x + b
        let values_y = &cast(&values[0], &DataType::Float64)?;
        let values_x = &cast(&values[1], &DataType::Float64)?;

        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();

        for i in 0..values_y.len() {
            // skip the row if either x or y is NULL
            let value_y = if values_y.is_valid(i) {
                arr_y.next()
            } else {
                None
            };
            let value_x = if values_x.is_valid(i) {
                arr_x.next()
            } else {
                None
            };
            if value_y.is_none() || value_x.is_none() {
                continue;
            }

            // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)]
            let value_y = unwrap_or_internal_err!(value_y);
            let value_x = unwrap_or_internal_err!(value_x);

            self.count += 1;
            let delta_x = value_x - self.mean_x;
            let delta_y = value_y - self.mean_y;
            self.mean_x += delta_x / self.count as f64;
            self.mean_y += delta_y / self.count as f64;
            let delta_x_2 = value_x - self.mean_x;
            let delta_y_2 = value_y - self.mean_y;
            self.m2_x += delta_x * delta_x_2;
            self.m2_y += delta_y * delta_y_2;
            self.algo_const += delta_x * (value_y - self.mean_y);
        }

        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }

    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values_y = &cast(&values[0], &DataType::Float64)?;
        let values_x = &cast(&values[1], &DataType::Float64)?;

        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();

        for i in 0..values_y.len() {
            // skip the row if either x or y is NULL
            let value_y = if values_y.is_valid(i) {
                arr_y.next()
            } else {
                None
            };
            let value_x = if values_x.is_valid(i) {
                arr_x.next()
            } else {
                None
            };
            if value_y.is_none() || value_x.is_none() {
                continue;
            }

            // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)]
            let value_y = unwrap_or_internal_err!(value_y);
            let value_x = unwrap_or_internal_err!(value_x);

            if self.count > 1 {
                self.count -= 1;
                let delta_x = value_x - self.mean_x;
                let delta_y = value_y - self.mean_y;
                self.mean_x -= delta_x / self.count as f64;
                self.mean_y -= delta_y / self.count as f64;
                let delta_x_2 = value_x - self.mean_x;
                let delta_y_2 = value_y - self.mean_y;
                self.m2_x -= delta_x * delta_x_2;
                self.m2_y -= delta_y * delta_y_2;
                self.algo_const -= delta_x * (value_y - self.mean_y);
            } else {
                self.count = 0;
                self.mean_x = 0.0;
                self.m2_x = 0.0;
                self.m2_y = 0.0;
                self.mean_y = 0.0;
                self.algo_const = 0.0;
            }
        }

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let count_arr = downcast_value!(states[0], UInt64Array);
        let mean_x_arr = downcast_value!(states[1], Float64Array);
        let mean_y_arr = downcast_value!(states[2], Float64Array);
        let m2_x_arr = downcast_value!(states[3], Float64Array);
        let m2_y_arr = downcast_value!(states[4], Float64Array);
        let algo_const_arr = downcast_value!(states[5], Float64Array);

        for i in 0..count_arr.len() {
            let count_b = count_arr.value(i);
            if count_b == 0_u64 {
                continue;
            }
            let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
                self.count,
                self.mean_x,
                self.mean_y,
                self.m2_x,
                self.m2_y,
                self.algo_const,
            );
            let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
                count_b,
                mean_x_arr.value(i),
                mean_y_arr.value(i),
                m2_x_arr.value(i),
                m2_y_arr.value(i),
                algo_const_arr.value(i),
            );

            // Assuming two different batches of input have calculated the states:
            // batch A of Y, X -> {count_a, mean_x_a, mean_y_a, m2_x_a, algo_const_a}
            // batch B of Y, X -> {count_b, mean_x_b, mean_y_b, m2_x_b, algo_const_b}
            // The merged states from A and B are {count_ab, mean_x_ab, mean_y_ab, m2_x_ab,
            // algo_const_ab}
            //
            // Reference for the algorithm to merge states:
            // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
            let count_ab = count_a + count_b;
            let (count_a, count_b) = (count_a as f64, count_b as f64);
            let d_x = mean_x_b - mean_x_a;
            let d_y = mean_y_b - mean_y_a;
            let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
            let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
            let m2_x_ab =
                m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
            let m2_y_ab =
                m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
            let algo_const_ab = algo_const_a
                + algo_const_b
                + d_x * d_y * count_a * count_b / count_ab as f64;

            self.count = count_ab;
            self.mean_x = mean_x_ab;
            self.mean_y = mean_y_ab;
            self.m2_x = m2_x_ab;
            self.m2_y = m2_y_ab;
            self.algo_const = algo_const_ab;
        }
        Ok(())
    }
455
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
456
0
        let cov_pop_x_y = self.algo_const / self.count as f64;
457
0
        let var_pop_x = self.m2_x / self.count as f64;
458
0
        let var_pop_y = self.m2_y / self.count as f64;
459
0
460
0
        let nullif_or_stat = |cond: bool, stat: f64| {
461
0
            if cond {
462
0
                Ok(ScalarValue::Float64(None))
463
            } else {
464
0
                Ok(ScalarValue::Float64(Some(stat)))
465
            }
466
0
        };
467
468
0
        match self.regr_type {
469
            RegrType::Slope => {
470
                // Only 0/1 point or slope is infinite
471
0
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
472
0
                nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
473
            }
474
            RegrType::Intercept => {
475
0
                let slope = cov_pop_x_y / var_pop_x;
476
                // Only 0/1 point or slope is infinite
477
0
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
478
0
                nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
479
            }
480
0
            RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
481
            RegrType::R2 => {
482
                // Only 0/1 point or all x(or y) is the same
483
0
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
484
0
                nullif_or_stat(
485
0
                    nullif_cond,
486
0
                    (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
487
0
                )
488
            }
489
0
            RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
490
0
            RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
491
0
            RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
492
0
            RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
493
0
            RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
494
        }
495
0
    }
496
497
0
    fn size(&self) -> usize {
498
0
        std::mem::size_of_val(self)
499
0
    }
500
}
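
// The sketch below is added for illustration and is not part of the original
// file or its coverage data. It exercises `RegrAccumulator` directly on a tiny
// in-memory batch, assuming the crate's usual test setup; the module name and
// column values are made up for the example.
#[cfg(test)]
mod regr_example_tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn regr_slope_on_exact_line() -> Result<()> {
        // y = 2 * x exactly, so the fitted slope should be 2.0.
        let y: ArrayRef = Arc::new(Float64Array::from(vec![2.0, 4.0, 6.0]));
        let x: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));

        let mut acc = RegrAccumulator::try_new(&RegrType::Slope)?;
        acc.update_batch(&[y, x])?;

        match acc.evaluate()? {
            ScalarValue::Float64(Some(slope)) => assert!((slope - 2.0).abs() < 1e-12),
            other => panic!("expected a non-null Float64 slope, got {other:?}"),
        }
        Ok(())
    }
}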