/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/regr.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines physical expressions that can be evaluated at runtime during query execution |
19 | | |
20 | | use std::any::Any; |
21 | | use std::fmt::Debug; |
22 | | |
23 | | use arrow::array::Float64Array; |
24 | | use arrow::{ |
25 | | array::{ArrayRef, UInt64Array}, |
26 | | compute::cast, |
27 | | datatypes::DataType, |
28 | | datatypes::Field, |
29 | | }; |
30 | | use datafusion_common::{downcast_value, plan_err, unwrap_or_internal_err, ScalarValue}; |
31 | | use datafusion_common::{DataFusionError, Result}; |
32 | | use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; |
33 | | use datafusion_expr::type_coercion::aggregates::NUMERICS; |
34 | | use datafusion_expr::utils::format_state_name; |
35 | | use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; |
36 | | |
37 | | macro_rules! make_regr_udaf_expr_and_func { |
38 | | ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { |
39 | | make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN); |
40 | | create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN))); |
41 | | } |
42 | | } |
43 | | |
44 | | make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope); |
45 | | make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept); |
46 | | make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count); |
47 | | make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2); |
48 | | make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX); |
49 | | make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY); |
50 | | make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); |
51 | | make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); |
52 | | make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); |
53 | | |
54 | | pub struct Regr { |
55 | | signature: Signature, |
56 | | regr_type: RegrType, |
57 | | func_name: &'static str, |
58 | | } |
59 | | |
60 | | impl Debug for Regr { |
61 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
62 | 0 | f.debug_struct("regr") |
63 | 0 | .field("name", &self.name()) |
64 | 0 | .field("signature", &self.signature) |
65 | 0 | .finish() |
66 | 0 | } |
67 | | } |
68 | | |
69 | | impl Regr { |
70 | 0 | pub fn new(regr_type: RegrType, func_name: &'static str) -> Self { |
71 | 0 | Self { |
72 | 0 | signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), |
73 | 0 | regr_type, |
74 | 0 | func_name, |
75 | 0 | } |
76 | 0 | } |
77 | | } |
78 | | |
79 | | /* |
80 | | #[derive(Debug)] |
81 | | pub struct Regr { |
82 | | name: String, |
83 | | regr_type: RegrType, |
84 | | expr_y: Arc<dyn PhysicalExpr>, |
85 | | expr_x: Arc<dyn PhysicalExpr>, |
86 | | } |
87 | | |
88 | | impl Regr { |
89 | | pub fn get_regr_type(&self) -> RegrType { |
90 | | self.regr_type.clone() |
91 | | } |
92 | | } |
93 | | */ |
94 | | |
/// Selects which linear-regression statistic a `regr_*` aggregate computes.
///
/// Every function takes its arguments as `(Y, X)` and considers only rows
/// where both values are non-null.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// `SXX`, `SYY` and `SXY` mirror the SQL function names `regr_sxx`,
// `regr_syy` and `regr_sxy`, so the all-caps spelling is intentional.
#[allow(clippy::upper_case_acronyms)]
pub enum RegrType {
    /// Variant for the `regr_slope` aggregate expression.
    /// `regr_slope(Y, X)` returns the slope (k in Y = k*X + b) of the
    /// least-squares (minimal Residual Sum of Squares) fit through the
    /// non-null pairs.
    Slope,
    /// Variant for the `regr_intercept` aggregate expression.
    /// `regr_intercept(Y, X)` returns the intercept (b in Y = k*X + b) of the
    /// least-squares fit through the non-null pairs.
    Intercept,
    /// Variant for the `regr_count` aggregate expression.
    /// `regr_count(Y, X)` returns the number of input rows for which both
    /// expressions are non-null.
    Count,
    /// Variant for the `regr_r2` aggregate expression.
    /// `regr_r2(Y, X)` returns the coefficient of determination (R-squared)
    /// of the regression line, i.e. the proportion of the variance in Y that
    /// is predictable from X.
    R2,
    /// Variant for the `regr_avgx` aggregate expression.
    /// `regr_avgx(Y, X)` returns the average of the independent variable X
    /// over the non-null pairs.
    AvgX,
    /// Variant for the `regr_avgy` aggregate expression.
    /// `regr_avgy(Y, X)` returns the average of the dependent variable Y
    /// over the non-null pairs.
    AvgY,
    /// Variant for the `regr_sxx` aggregate expression.
    /// `regr_sxx(Y, X)` returns the sum of squared deviations of X from its
    /// mean over the non-null pairs.
    SXX,
    /// Variant for the `regr_syy` aggregate expression.
    /// `regr_syy(Y, X)` returns the sum of squared deviations of Y from its
    /// mean over the non-null pairs.
    SYY,
    /// Variant for the `regr_sxy` aggregate expression.
    /// `regr_sxy(Y, X)` returns the sum of products of the deviations of Y
    /// and X from their respective means over the non-null pairs.
    SXY,
}
137 | | |
138 | | impl AggregateUDFImpl for Regr { |
139 | 0 | fn as_any(&self) -> &dyn Any { |
140 | 0 | self |
141 | 0 | } |
142 | | |
143 | 0 | fn name(&self) -> &str { |
144 | 0 | self.func_name |
145 | 0 | } |
146 | | |
147 | 0 | fn signature(&self) -> &Signature { |
148 | 0 | &self.signature |
149 | 0 | } |
150 | | |
151 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
152 | 0 | if !arg_types[0].is_numeric() { |
153 | 0 | return plan_err!("Covariance requires numeric input types"); |
154 | 0 | } |
155 | | |
156 | 0 | if matches!(self.regr_type, RegrType::Count) { |
157 | 0 | Ok(DataType::UInt64) |
158 | | } else { |
159 | 0 | Ok(DataType::Float64) |
160 | | } |
161 | 0 | } |
162 | | |
163 | 0 | fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
164 | 0 | Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) |
165 | 0 | } |
166 | | |
167 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
168 | 0 | Ok(vec![ |
169 | 0 | Field::new( |
170 | 0 | format_state_name(args.name, "count"), |
171 | 0 | DataType::UInt64, |
172 | 0 | true, |
173 | 0 | ), |
174 | 0 | Field::new( |
175 | 0 | format_state_name(args.name, "mean_x"), |
176 | 0 | DataType::Float64, |
177 | 0 | true, |
178 | 0 | ), |
179 | 0 | Field::new( |
180 | 0 | format_state_name(args.name, "mean_y"), |
181 | 0 | DataType::Float64, |
182 | 0 | true, |
183 | 0 | ), |
184 | 0 | Field::new( |
185 | 0 | format_state_name(args.name, "m2_x"), |
186 | 0 | DataType::Float64, |
187 | 0 | true, |
188 | 0 | ), |
189 | 0 | Field::new( |
190 | 0 | format_state_name(args.name, "m2_y"), |
191 | 0 | DataType::Float64, |
192 | 0 | true, |
193 | 0 | ), |
194 | 0 | Field::new( |
195 | 0 | format_state_name(args.name, "algo_const"), |
196 | 0 | DataType::Float64, |
197 | 0 | true, |
198 | 0 | ), |
199 | 0 | ]) |
200 | 0 | } |
201 | | } |
202 | | |
203 | | /* |
204 | | impl PartialEq<dyn Any> for Regr { |
205 | | fn eq(&self, other: &dyn Any) -> bool { |
206 | | down_cast_any_ref(other) |
207 | | .downcast_ref::<Self>() |
208 | | .map(|x| { |
209 | | self.name == x.name |
210 | | && self.expr_y.eq(&x.expr_y) |
211 | | && self.expr_x.eq(&x.expr_x) |
212 | | }) |
213 | | .unwrap_or(false) |
214 | | } |
215 | | } |
216 | | */ |
217 | | |
218 | | /// `RegrAccumulator` is used to compute linear regression aggregate functions |
219 | | /// by maintaining statistics needed to compute them in an online fashion. |
220 | | /// |
221 | | /// This struct uses Welford's online algorithm for calculating variance and covariance: |
222 | | /// <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm> |
223 | | /// |
224 | | /// Given the statistics, the following aggregate functions can be calculated: |
225 | | /// |
226 | | /// - `regr_slope(y, x)`: Slope of the linear regression line, calculated as: |
227 | | /// cov_pop(x, y) / var_pop(x). |
228 | | /// It represents the expected change in Y for a one-unit change in X. |
229 | | /// |
230 | | /// - `regr_intercept(y, x)`: Intercept of the linear regression line, calculated as: |
231 | | /// mean_y - (regr_slope(y, x) * mean_x). |
232 | | /// It represents the expected value of Y when X is 0. |
233 | | /// |
234 | | /// - `regr_count(y, x)`: Count of the non-null(both x and y) input rows. |
235 | | /// |
236 | | /// - `regr_r2(y, x)`: R-squared value (coefficient of determination), calculated as: |
237 | | /// (cov_pop(x, y) ^ 2) / (var_pop(x) * var_pop(y)). |
238 | | /// It provides a measure of how well the model's predictions match the observed data. |
239 | | /// |
240 | | /// - `regr_avgx(y, x)`: Average of the independent variable X, calculated as: mean_x. |
241 | | /// |
242 | | /// - `regr_avgy(y, x)`: Average of the dependent variable Y, calculated as: mean_y. |
243 | | /// |
244 | | /// - `regr_sxx(y, x)`: Sum of squares of the independent variable X, calculated as: |
245 | | /// m2_x. |
246 | | /// |
247 | | /// - `regr_syy(y, x)`: Sum of squares of the dependent variable Y, calculated as: |
248 | | /// m2_y. |
249 | | /// |
250 | | /// - `regr_sxy(y, x)`: Sum of products of paired values, calculated as: |
251 | | /// algo_const. |
252 | | /// |
253 | | /// Here's how the statistics maintained in this struct are calculated: |
254 | | /// - `cov_pop(x, y)`: algo_const / count. |
255 | | /// - `var_pop(x)`: m2_x / count. |
256 | | /// - `var_pop(y)`: m2_y / count. |
257 | | #[derive(Debug)] |
258 | | pub struct RegrAccumulator { |
259 | | count: u64, |
260 | | mean_x: f64, |
261 | | mean_y: f64, |
262 | | m2_x: f64, |
263 | | m2_y: f64, |
264 | | algo_const: f64, |
265 | | regr_type: RegrType, |
266 | | } |
267 | | |
268 | | impl RegrAccumulator { |
269 | | /// Creates a new `RegrAccumulator` |
270 | 0 | pub fn try_new(regr_type: &RegrType) -> Result<Self> { |
271 | 0 | Ok(Self { |
272 | 0 | count: 0_u64, |
273 | 0 | mean_x: 0_f64, |
274 | 0 | mean_y: 0_f64, |
275 | 0 | m2_x: 0_f64, |
276 | 0 | m2_y: 0_f64, |
277 | 0 | algo_const: 0_f64, |
278 | 0 | regr_type: regr_type.clone(), |
279 | 0 | }) |
280 | 0 | } |
281 | | } |
282 | | |
283 | | impl Accumulator for RegrAccumulator { |
284 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
285 | 0 | Ok(vec![ |
286 | 0 | ScalarValue::from(self.count), |
287 | 0 | ScalarValue::from(self.mean_x), |
288 | 0 | ScalarValue::from(self.mean_y), |
289 | 0 | ScalarValue::from(self.m2_x), |
290 | 0 | ScalarValue::from(self.m2_y), |
291 | 0 | ScalarValue::from(self.algo_const), |
292 | 0 | ]) |
293 | 0 | } |
294 | | |
295 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
296 | | // regr_slope(Y, X) calculates k in y = k*x + b |
297 | 0 | let values_y = &cast(&values[0], &DataType::Float64)?; |
298 | 0 | let values_x = &cast(&values[1], &DataType::Float64)?; |
299 | | |
300 | 0 | let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); |
301 | 0 | let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); |
302 | | |
303 | 0 | for i in 0..values_y.len() { |
304 | | // skip either x or y is NULL |
305 | 0 | let value_y = if values_y.is_valid(i) { |
306 | 0 | arr_y.next() |
307 | | } else { |
308 | 0 | None |
309 | | }; |
310 | 0 | let value_x = if values_x.is_valid(i) { |
311 | 0 | arr_x.next() |
312 | | } else { |
313 | 0 | None |
314 | | }; |
315 | 0 | if value_y.is_none() || value_x.is_none() { |
316 | 0 | continue; |
317 | 0 | } |
318 | | |
319 | | // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] |
320 | 0 | let value_y = unwrap_or_internal_err!(value_y); |
321 | 0 | let value_x = unwrap_or_internal_err!(value_x); |
322 | | |
323 | 0 | self.count += 1; |
324 | 0 | let delta_x = value_x - self.mean_x; |
325 | 0 | let delta_y = value_y - self.mean_y; |
326 | 0 | self.mean_x += delta_x / self.count as f64; |
327 | 0 | self.mean_y += delta_y / self.count as f64; |
328 | 0 | let delta_x_2 = value_x - self.mean_x; |
329 | 0 | let delta_y_2 = value_y - self.mean_y; |
330 | 0 | self.m2_x += delta_x * delta_x_2; |
331 | 0 | self.m2_y += delta_y * delta_y_2; |
332 | 0 | self.algo_const += delta_x * (value_y - self.mean_y); |
333 | | } |
334 | | |
335 | 0 | Ok(()) |
336 | 0 | } |
337 | | |
338 | 0 | fn supports_retract_batch(&self) -> bool { |
339 | 0 | true |
340 | 0 | } |
341 | | |
342 | 0 | fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
343 | 0 | let values_y = &cast(&values[0], &DataType::Float64)?; |
344 | 0 | let values_x = &cast(&values[1], &DataType::Float64)?; |
345 | | |
346 | 0 | let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); |
347 | 0 | let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); |
348 | | |
349 | 0 | for i in 0..values_y.len() { |
350 | | // skip either x or y is NULL |
351 | 0 | let value_y = if values_y.is_valid(i) { |
352 | 0 | arr_y.next() |
353 | | } else { |
354 | 0 | None |
355 | | }; |
356 | 0 | let value_x = if values_x.is_valid(i) { |
357 | 0 | arr_x.next() |
358 | | } else { |
359 | 0 | None |
360 | | }; |
361 | 0 | if value_y.is_none() || value_x.is_none() { |
362 | 0 | continue; |
363 | 0 | } |
364 | | |
365 | | // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] |
366 | 0 | let value_y = unwrap_or_internal_err!(value_y); |
367 | 0 | let value_x = unwrap_or_internal_err!(value_x); |
368 | | |
369 | 0 | if self.count > 1 { |
370 | 0 | self.count -= 1; |
371 | 0 | let delta_x = value_x - self.mean_x; |
372 | 0 | let delta_y = value_y - self.mean_y; |
373 | 0 | self.mean_x -= delta_x / self.count as f64; |
374 | 0 | self.mean_y -= delta_y / self.count as f64; |
375 | 0 | let delta_x_2 = value_x - self.mean_x; |
376 | 0 | let delta_y_2 = value_y - self.mean_y; |
377 | 0 | self.m2_x -= delta_x * delta_x_2; |
378 | 0 | self.m2_y -= delta_y * delta_y_2; |
379 | 0 | self.algo_const -= delta_x * (value_y - self.mean_y); |
380 | 0 | } else { |
381 | 0 | self.count = 0; |
382 | 0 | self.mean_x = 0.0; |
383 | 0 | self.m2_x = 0.0; |
384 | 0 | self.m2_y = 0.0; |
385 | 0 | self.mean_y = 0.0; |
386 | 0 | self.algo_const = 0.0; |
387 | 0 | } |
388 | | } |
389 | | |
390 | 0 | Ok(()) |
391 | 0 | } |
392 | | |
393 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
394 | 0 | let count_arr = downcast_value!(states[0], UInt64Array); |
395 | 0 | let mean_x_arr = downcast_value!(states[1], Float64Array); |
396 | 0 | let mean_y_arr = downcast_value!(states[2], Float64Array); |
397 | 0 | let m2_x_arr = downcast_value!(states[3], Float64Array); |
398 | 0 | let m2_y_arr = downcast_value!(states[4], Float64Array); |
399 | 0 | let algo_const_arr = downcast_value!(states[5], Float64Array); |
400 | | |
401 | 0 | for i in 0..count_arr.len() { |
402 | 0 | let count_b = count_arr.value(i); |
403 | 0 | if count_b == 0_u64 { |
404 | 0 | continue; |
405 | 0 | } |
406 | 0 | let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = ( |
407 | 0 | self.count, |
408 | 0 | self.mean_x, |
409 | 0 | self.mean_y, |
410 | 0 | self.m2_x, |
411 | 0 | self.m2_y, |
412 | 0 | self.algo_const, |
413 | 0 | ); |
414 | 0 | let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = ( |
415 | 0 | count_b, |
416 | 0 | mean_x_arr.value(i), |
417 | 0 | mean_y_arr.value(i), |
418 | 0 | m2_x_arr.value(i), |
419 | 0 | m2_y_arr.value(i), |
420 | 0 | algo_const_arr.value(i), |
421 | 0 | ); |
422 | 0 |
|
423 | 0 | // Assuming two different batches of input have calculated the states: |
424 | 0 | // batch A of Y, X -> {count_a, mean_x_a, mean_y_a, m2_x_a, algo_const_a} |
425 | 0 | // batch B of Y, X -> {count_b, mean_x_b, mean_y_b, m2_x_b, algo_const_b} |
426 | 0 | // The merged states from A and B are {count_ab, mean_x_ab, mean_y_ab, m2_x_ab, |
427 | 0 | // algo_const_ab} |
428 | 0 | // |
429 | 0 | // Reference for the algorithm to merge states: |
430 | 0 | // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm |
431 | 0 | let count_ab = count_a + count_b; |
432 | 0 | let (count_a, count_b) = (count_a as f64, count_b as f64); |
433 | 0 | let d_x = mean_x_b - mean_x_a; |
434 | 0 | let d_y = mean_y_b - mean_y_a; |
435 | 0 | let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64; |
436 | 0 | let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64; |
437 | 0 | let m2_x_ab = |
438 | 0 | m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64; |
439 | 0 | let m2_y_ab = |
440 | 0 | m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64; |
441 | 0 | let algo_const_ab = algo_const_a |
442 | 0 | + algo_const_b |
443 | 0 | + d_x * d_y * count_a * count_b / count_ab as f64; |
444 | 0 |
|
445 | 0 | self.count = count_ab; |
446 | 0 | self.mean_x = mean_x_ab; |
447 | 0 | self.mean_y = mean_y_ab; |
448 | 0 | self.m2_x = m2_x_ab; |
449 | 0 | self.m2_y = m2_y_ab; |
450 | 0 | self.algo_const = algo_const_ab; |
451 | | } |
452 | 0 | Ok(()) |
453 | 0 | } |
454 | | |
455 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
456 | 0 | let cov_pop_x_y = self.algo_const / self.count as f64; |
457 | 0 | let var_pop_x = self.m2_x / self.count as f64; |
458 | 0 | let var_pop_y = self.m2_y / self.count as f64; |
459 | 0 |
|
460 | 0 | let nullif_or_stat = |cond: bool, stat: f64| { |
461 | 0 | if cond { |
462 | 0 | Ok(ScalarValue::Float64(None)) |
463 | | } else { |
464 | 0 | Ok(ScalarValue::Float64(Some(stat))) |
465 | | } |
466 | 0 | }; |
467 | | |
468 | 0 | match self.regr_type { |
469 | | RegrType::Slope => { |
470 | | // Only 0/1 point or slope is infinite |
471 | 0 | let nullif_cond = self.count <= 1 || var_pop_x == 0.0; |
472 | 0 | nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x) |
473 | | } |
474 | | RegrType::Intercept => { |
475 | 0 | let slope = cov_pop_x_y / var_pop_x; |
476 | | // Only 0/1 point or slope is infinite |
477 | 0 | let nullif_cond = self.count <= 1 || var_pop_x == 0.0; |
478 | 0 | nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x) |
479 | | } |
480 | 0 | RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))), |
481 | | RegrType::R2 => { |
482 | | // Only 0/1 point or all x(or y) is the same |
483 | 0 | let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0; |
484 | 0 | nullif_or_stat( |
485 | 0 | nullif_cond, |
486 | 0 | (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y), |
487 | 0 | ) |
488 | | } |
489 | 0 | RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x), |
490 | 0 | RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y), |
491 | 0 | RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x), |
492 | 0 | RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y), |
493 | 0 | RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const), |
494 | | } |
495 | 0 | } |
496 | | |
497 | 0 | fn size(&self) -> usize { |
498 | 0 | std::mem::size_of_val(self) |
499 | 0 | } |
500 | | } |