/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/covariance.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
//! [`CovarianceSample`] and [`CovariancePopulation`]: sample and population covariance aggregations.
19 | | |
20 | | use std::fmt::Debug; |
21 | | |
22 | | use arrow::{ |
23 | | array::{ArrayRef, Float64Array, UInt64Array}, |
24 | | compute::kernels::cast, |
25 | | datatypes::{DataType, Field}, |
26 | | }; |
27 | | |
28 | | use datafusion_common::{ |
29 | | downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, |
30 | | ScalarValue, |
31 | | }; |
32 | | use datafusion_expr::{ |
33 | | function::{AccumulatorArgs, StateFieldsArgs}, |
34 | | type_coercion::aggregates::NUMERICS, |
35 | | utils::format_state_name, |
36 | | Accumulator, AggregateUDFImpl, Signature, Volatility, |
37 | | }; |
38 | | use datafusion_functions_aggregate_common::stats::StatsType; |
39 | | |
40 | | make_udaf_expr_and_func!( |
41 | | CovarianceSample, |
42 | | covar_samp, |
43 | | y x, |
44 | | "Computes the sample covariance.", |
45 | | covar_samp_udaf |
46 | | ); |
47 | | |
48 | | make_udaf_expr_and_func!( |
49 | | CovariancePopulation, |
50 | | covar_pop, |
51 | | y x, |
52 | | "Computes the population covariance.", |
53 | | covar_pop_udaf |
54 | | ); |
55 | | |
56 | | pub struct CovarianceSample { |
57 | | signature: Signature, |
58 | | aliases: Vec<String>, |
59 | | } |
60 | | |
61 | | impl Debug for CovarianceSample { |
62 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
63 | 0 | f.debug_struct("CovarianceSample") |
64 | 0 | .field("name", &self.name()) |
65 | 0 | .field("signature", &self.signature) |
66 | 0 | .finish() |
67 | 0 | } |
68 | | } |
69 | | |
70 | | impl Default for CovarianceSample { |
71 | 0 | fn default() -> Self { |
72 | 0 | Self::new() |
73 | 0 | } |
74 | | } |
75 | | |
76 | | impl CovarianceSample { |
77 | 0 | pub fn new() -> Self { |
78 | 0 | Self { |
79 | 0 | aliases: vec![String::from("covar")], |
80 | 0 | signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), |
81 | 0 | } |
82 | 0 | } |
83 | | } |
84 | | |
85 | | impl AggregateUDFImpl for CovarianceSample { |
86 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
87 | 0 | self |
88 | 0 | } |
89 | | |
90 | 0 | fn name(&self) -> &str { |
91 | 0 | "covar_samp" |
92 | 0 | } |
93 | | |
94 | 0 | fn signature(&self) -> &Signature { |
95 | 0 | &self.signature |
96 | 0 | } |
97 | | |
98 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
99 | 0 | if !arg_types[0].is_numeric() { |
100 | 0 | return plan_err!("Covariance requires numeric input types"); |
101 | 0 | } |
102 | 0 |
|
103 | 0 | Ok(DataType::Float64) |
104 | 0 | } |
105 | | |
106 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
107 | 0 | let name = args.name; |
108 | 0 | Ok(vec![ |
109 | 0 | Field::new(format_state_name(name, "count"), DataType::UInt64, true), |
110 | 0 | Field::new(format_state_name(name, "mean1"), DataType::Float64, true), |
111 | 0 | Field::new(format_state_name(name, "mean2"), DataType::Float64, true), |
112 | 0 | Field::new( |
113 | 0 | format_state_name(name, "algo_const"), |
114 | 0 | DataType::Float64, |
115 | 0 | true, |
116 | 0 | ), |
117 | 0 | ]) |
118 | 0 | } |
119 | | |
120 | 0 | fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
121 | 0 | Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?)) |
122 | 0 | } |
123 | | |
124 | 0 | fn aliases(&self) -> &[String] { |
125 | 0 | &self.aliases |
126 | 0 | } |
127 | | } |
128 | | |
129 | | pub struct CovariancePopulation { |
130 | | signature: Signature, |
131 | | } |
132 | | |
133 | | impl Debug for CovariancePopulation { |
134 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
135 | 0 | f.debug_struct("CovariancePopulation") |
136 | 0 | .field("name", &self.name()) |
137 | 0 | .field("signature", &self.signature) |
138 | 0 | .finish() |
139 | 0 | } |
140 | | } |
141 | | |
142 | | impl Default for CovariancePopulation { |
143 | 0 | fn default() -> Self { |
144 | 0 | Self::new() |
145 | 0 | } |
146 | | } |
147 | | |
148 | | impl CovariancePopulation { |
149 | 0 | pub fn new() -> Self { |
150 | 0 | Self { |
151 | 0 | signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), |
152 | 0 | } |
153 | 0 | } |
154 | | } |
155 | | |
156 | | impl AggregateUDFImpl for CovariancePopulation { |
157 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
158 | 0 | self |
159 | 0 | } |
160 | | |
161 | 0 | fn name(&self) -> &str { |
162 | 0 | "covar_pop" |
163 | 0 | } |
164 | | |
165 | 0 | fn signature(&self) -> &Signature { |
166 | 0 | &self.signature |
167 | 0 | } |
168 | | |
169 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
170 | 0 | if !arg_types[0].is_numeric() { |
171 | 0 | return plan_err!("Covariance requires numeric input types"); |
172 | 0 | } |
173 | 0 |
|
174 | 0 | Ok(DataType::Float64) |
175 | 0 | } |
176 | | |
177 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
178 | 0 | let name = args.name; |
179 | 0 | Ok(vec![ |
180 | 0 | Field::new(format_state_name(name, "count"), DataType::UInt64, true), |
181 | 0 | Field::new(format_state_name(name, "mean1"), DataType::Float64, true), |
182 | 0 | Field::new(format_state_name(name, "mean2"), DataType::Float64, true), |
183 | 0 | Field::new( |
184 | 0 | format_state_name(name, "algo_const"), |
185 | 0 | DataType::Float64, |
186 | 0 | true, |
187 | 0 | ), |
188 | 0 | ]) |
189 | 0 | } |
190 | | |
191 | 0 | fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
192 | 0 | Ok(Box::new(CovarianceAccumulator::try_new( |
193 | 0 | StatsType::Population, |
194 | 0 | )?)) |
195 | 0 | } |
196 | | } |
197 | | |
198 | | /// An accumulator to compute covariance |
199 | | /// The algorithm used is an online implementation and numerically stable. It is derived from the following paper |
200 | | /// for calculating variance: |
201 | | /// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". |
202 | | /// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. |
203 | | /// |
204 | | /// The algorithm has been analyzed here: |
205 | | /// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". |
206 | | /// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. |
207 | | /// |
208 | | /// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online, |
209 | | /// parallelizable and numerically stable. |
210 | | |
211 | | #[derive(Debug)] |
212 | | pub struct CovarianceAccumulator { |
213 | | algo_const: f64, |
214 | | mean1: f64, |
215 | | mean2: f64, |
216 | | count: u64, |
217 | | stats_type: StatsType, |
218 | | } |
219 | | |
220 | | impl CovarianceAccumulator { |
221 | | /// Creates a new `CovarianceAccumulator` |
222 | 0 | pub fn try_new(s_type: StatsType) -> Result<Self> { |
223 | 0 | Ok(Self { |
224 | 0 | algo_const: 0_f64, |
225 | 0 | mean1: 0_f64, |
226 | 0 | mean2: 0_f64, |
227 | 0 | count: 0_u64, |
228 | 0 | stats_type: s_type, |
229 | 0 | }) |
230 | 0 | } |
231 | | |
232 | 0 | pub fn get_count(&self) -> u64 { |
233 | 0 | self.count |
234 | 0 | } |
235 | | |
236 | 0 | pub fn get_mean1(&self) -> f64 { |
237 | 0 | self.mean1 |
238 | 0 | } |
239 | | |
240 | 0 | pub fn get_mean2(&self) -> f64 { |
241 | 0 | self.mean2 |
242 | 0 | } |
243 | | |
244 | 0 | pub fn get_algo_const(&self) -> f64 { |
245 | 0 | self.algo_const |
246 | 0 | } |
247 | | } |
248 | | |
249 | | impl Accumulator for CovarianceAccumulator { |
250 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
251 | 0 | Ok(vec![ |
252 | 0 | ScalarValue::from(self.count), |
253 | 0 | ScalarValue::from(self.mean1), |
254 | 0 | ScalarValue::from(self.mean2), |
255 | 0 | ScalarValue::from(self.algo_const), |
256 | 0 | ]) |
257 | 0 | } |
258 | | |
259 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
260 | 0 | let values1 = &cast(&values[0], &DataType::Float64)?; |
261 | 0 | let values2 = &cast(&values[1], &DataType::Float64)?; |
262 | | |
263 | 0 | let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); |
264 | 0 | let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); |
265 | | |
266 | 0 | for i in 0..values1.len() { |
267 | 0 | let value1 = if values1.is_valid(i) { |
268 | 0 | arr1.next() |
269 | | } else { |
270 | 0 | None |
271 | | }; |
272 | 0 | let value2 = if values2.is_valid(i) { |
273 | 0 | arr2.next() |
274 | | } else { |
275 | 0 | None |
276 | | }; |
277 | | |
278 | 0 | if value1.is_none() || value2.is_none() { |
279 | 0 | continue; |
280 | 0 | } |
281 | | |
282 | 0 | let value1 = unwrap_or_internal_err!(value1); |
283 | 0 | let value2 = unwrap_or_internal_err!(value2); |
284 | 0 | let new_count = self.count + 1; |
285 | 0 | let delta1 = value1 - self.mean1; |
286 | 0 | let new_mean1 = delta1 / new_count as f64 + self.mean1; |
287 | 0 | let delta2 = value2 - self.mean2; |
288 | 0 | let new_mean2 = delta2 / new_count as f64 + self.mean2; |
289 | 0 | let new_c = delta1 * (value2 - new_mean2) + self.algo_const; |
290 | 0 |
|
291 | 0 | self.count += 1; |
292 | 0 | self.mean1 = new_mean1; |
293 | 0 | self.mean2 = new_mean2; |
294 | 0 | self.algo_const = new_c; |
295 | | } |
296 | | |
297 | 0 | Ok(()) |
298 | 0 | } |
299 | | |
300 | 0 | fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
301 | 0 | let values1 = &cast(&values[0], &DataType::Float64)?; |
302 | 0 | let values2 = &cast(&values[1], &DataType::Float64)?; |
303 | 0 | let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); |
304 | 0 | let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); |
305 | | |
306 | 0 | for i in 0..values1.len() { |
307 | 0 | let value1 = if values1.is_valid(i) { |
308 | 0 | arr1.next() |
309 | | } else { |
310 | 0 | None |
311 | | }; |
312 | 0 | let value2 = if values2.is_valid(i) { |
313 | 0 | arr2.next() |
314 | | } else { |
315 | 0 | None |
316 | | }; |
317 | | |
318 | 0 | if value1.is_none() || value2.is_none() { |
319 | 0 | continue; |
320 | 0 | } |
321 | | |
322 | 0 | let value1 = unwrap_or_internal_err!(value1); |
323 | 0 | let value2 = unwrap_or_internal_err!(value2); |
324 | | |
325 | 0 | let new_count = self.count - 1; |
326 | 0 | let delta1 = self.mean1 - value1; |
327 | 0 | let new_mean1 = delta1 / new_count as f64 + self.mean1; |
328 | 0 | let delta2 = self.mean2 - value2; |
329 | 0 | let new_mean2 = delta2 / new_count as f64 + self.mean2; |
330 | 0 | let new_c = self.algo_const - delta1 * (new_mean2 - value2); |
331 | 0 |
|
332 | 0 | self.count -= 1; |
333 | 0 | self.mean1 = new_mean1; |
334 | 0 | self.mean2 = new_mean2; |
335 | 0 | self.algo_const = new_c; |
336 | | } |
337 | | |
338 | 0 | Ok(()) |
339 | 0 | } |
340 | | |
341 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
342 | 0 | let counts = downcast_value!(states[0], UInt64Array); |
343 | 0 | let means1 = downcast_value!(states[1], Float64Array); |
344 | 0 | let means2 = downcast_value!(states[2], Float64Array); |
345 | 0 | let cs = downcast_value!(states[3], Float64Array); |
346 | | |
347 | 0 | for i in 0..counts.len() { |
348 | 0 | let c = counts.value(i); |
349 | 0 | if c == 0_u64 { |
350 | 0 | continue; |
351 | 0 | } |
352 | 0 | let new_count = self.count + c; |
353 | 0 | let new_mean1 = self.mean1 * self.count as f64 / new_count as f64 |
354 | 0 | + means1.value(i) * c as f64 / new_count as f64; |
355 | 0 | let new_mean2 = self.mean2 * self.count as f64 / new_count as f64 |
356 | 0 | + means2.value(i) * c as f64 / new_count as f64; |
357 | 0 | let delta1 = self.mean1 - means1.value(i); |
358 | 0 | let delta2 = self.mean2 - means2.value(i); |
359 | 0 | let new_c = self.algo_const |
360 | 0 | + cs.value(i) |
361 | 0 | + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64; |
362 | 0 |
|
363 | 0 | self.count = new_count; |
364 | 0 | self.mean1 = new_mean1; |
365 | 0 | self.mean2 = new_mean2; |
366 | 0 | self.algo_const = new_c; |
367 | | } |
368 | 0 | Ok(()) |
369 | 0 | } |
370 | | |
371 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
372 | 0 | let count = match self.stats_type { |
373 | 0 | StatsType::Population => self.count, |
374 | | StatsType::Sample => { |
375 | 0 | if self.count > 0 { |
376 | 0 | self.count - 1 |
377 | | } else { |
378 | 0 | self.count |
379 | | } |
380 | | } |
381 | | }; |
382 | | |
383 | 0 | if count == 0 { |
384 | 0 | Ok(ScalarValue::Float64(None)) |
385 | | } else { |
386 | 0 | Ok(ScalarValue::Float64(Some(self.algo_const / count as f64))) |
387 | | } |
388 | 0 | } |
389 | | |
390 | 0 | fn size(&self) -> usize { |
391 | 0 | std::mem::size_of_val(self) |
392 | 0 | } |
393 | | } |