/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/variance.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Variance aggregations: [`VarianceSample`] computes the sample variance |
19 | | //! and [`VariancePopulation`] computes the population variance. |
20 | | |
21 | | use arrow::{ |
22 | | array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, |
23 | | buffer::NullBuffer, |
24 | | compute::kernels::cast, |
25 | | datatypes::{DataType, Field}, |
26 | | }; |
27 | | use std::sync::OnceLock; |
28 | | use std::{fmt::Debug, sync::Arc}; |
29 | | |
30 | | use datafusion_common::{ |
31 | | downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, |
32 | | }; |
33 | | use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; |
34 | | use datafusion_expr::{ |
35 | | function::{AccumulatorArgs, StateFieldsArgs}, |
36 | | utils::format_state_name, |
37 | | Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, |
38 | | Volatility, |
39 | | }; |
40 | | use datafusion_functions_aggregate_common::{ |
41 | | aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, |
42 | | }; |
43 | | |
// Invokes `make_udaf_expr_and_func!` for `VarianceSample`, passing `var_sample`,
// the argument name `expression`, a description string and `var_samp_udaf`.
// The macro is defined outside this file; presumably it generates the
// `var_sample` expression helper and the `var_samp_udaf` accessor —
// TODO confirm against the macro definition.
make_udaf_expr_and_func!(
    VarianceSample,
    var_sample,
    expression,
    "Computes the sample variance.",
    var_samp_udaf
);
51 | | |
// Invokes `make_udaf_expr_and_func!` for `VariancePopulation`, passing `var_pop`,
// the argument name `expression`, a description string and `var_pop_udaf`.
// The macro is defined outside this file; presumably it generates the
// `var_pop` expression helper and the `var_pop_udaf` accessor —
// TODO confirm against the macro definition.
make_udaf_expr_and_func!(
    VariancePopulation,
    var_pop,
    expression,
    "Computes the population variance.",
    var_pop_udaf
);
59 | | |
/// Aggregate function definition for the sample variance.
///
/// Its [`AggregateUDFImpl`] implementation reports the name `var` and builds
/// accumulators configured with [`StatsType::Sample`].
pub struct VarianceSample {
    /// Argument signature handed out by `signature()`.
    signature: Signature,
    /// Additional names handed out by `aliases()`.
    aliases: Vec<String>,
}
64 | | |
65 | | impl Debug for VarianceSample { |
66 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
67 | 0 | f.debug_struct("VarianceSample") |
68 | 0 | .field("name", &self.name()) |
69 | 0 | .field("signature", &self.signature) |
70 | 0 | .finish() |
71 | 0 | } |
72 | | } |
73 | | |
74 | | impl Default for VarianceSample { |
75 | 0 | fn default() -> Self { |
76 | 0 | Self::new() |
77 | 0 | } |
78 | | } |
79 | | |
80 | | impl VarianceSample { |
81 | 0 | pub fn new() -> Self { |
82 | 0 | Self { |
83 | 0 | aliases: vec![String::from("var_sample"), String::from("var_samp")], |
84 | 0 | signature: Signature::coercible( |
85 | 0 | vec![DataType::Float64], |
86 | 0 | Volatility::Immutable, |
87 | 0 | ), |
88 | 0 | } |
89 | 0 | } |
90 | | } |
91 | | |
92 | | impl AggregateUDFImpl for VarianceSample { |
93 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
94 | 0 | self |
95 | 0 | } |
96 | | |
97 | 0 | fn name(&self) -> &str { |
98 | 0 | "var" |
99 | 0 | } |
100 | | |
101 | 0 | fn signature(&self) -> &Signature { |
102 | 0 | &self.signature |
103 | 0 | } |
104 | | |
105 | 0 | fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
106 | 0 | Ok(DataType::Float64) |
107 | 0 | } |
108 | | |
109 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
110 | 0 | let name = args.name; |
111 | 0 | Ok(vec![ |
112 | 0 | Field::new(format_state_name(name, "count"), DataType::UInt64, true), |
113 | 0 | Field::new(format_state_name(name, "mean"), DataType::Float64, true), |
114 | 0 | Field::new(format_state_name(name, "m2"), DataType::Float64, true), |
115 | 0 | ]) |
116 | 0 | } |
117 | | |
118 | 0 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
119 | 0 | if acc_args.is_distinct { |
120 | 0 | return not_impl_err!("VAR(DISTINCT) aggregations are not available"); |
121 | 0 | } |
122 | 0 |
|
123 | 0 | Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) |
124 | 0 | } |
125 | | |
126 | 0 | fn aliases(&self) -> &[String] { |
127 | 0 | &self.aliases |
128 | 0 | } |
129 | | |
130 | 0 | fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { |
131 | 0 | !acc_args.is_distinct |
132 | 0 | } |
133 | | |
134 | 0 | fn create_groups_accumulator( |
135 | 0 | &self, |
136 | 0 | _args: AccumulatorArgs, |
137 | 0 | ) -> Result<Box<dyn GroupsAccumulator>> { |
138 | 0 | Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample))) |
139 | 0 | } |
140 | | |
141 | 0 | fn documentation(&self) -> Option<&Documentation> { |
142 | 0 | Some(get_variance_sample_doc()) |
143 | 0 | } |
144 | | } |
145 | | |
146 | | static VARIANCE_SAMPLE_DOC: OnceLock<Documentation> = OnceLock::new(); |
147 | | |
148 | 0 | fn get_variance_sample_doc() -> &'static Documentation { |
149 | 0 | VARIANCE_SAMPLE_DOC.get_or_init(|| { |
150 | 0 | Documentation::builder() |
151 | 0 | .with_doc_section(DOC_SECTION_GENERAL) |
152 | 0 | .with_description( |
153 | 0 | "Returns the statistical sample variance of a set of numbers.", |
154 | 0 | ) |
155 | 0 | .with_syntax_example("var(expression)") |
156 | 0 | .with_standard_argument("expression", "Numeric") |
157 | 0 | .build() |
158 | 0 | .unwrap() |
159 | 0 | }) |
160 | 0 | } |
161 | | |
/// Aggregate function definition for the population variance.
///
/// Its [`AggregateUDFImpl`] implementation reports the name `var_pop` and
/// builds accumulators configured with [`StatsType::Population`].
pub struct VariancePopulation {
    /// Argument signature handed out by `signature()`.
    signature: Signature,
    /// Additional names handed out by `aliases()`.
    aliases: Vec<String>,
}
166 | | |
167 | | impl Debug for VariancePopulation { |
168 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
169 | 0 | f.debug_struct("VariancePopulation") |
170 | 0 | .field("name", &self.name()) |
171 | 0 | .field("signature", &self.signature) |
172 | 0 | .finish() |
173 | 0 | } |
174 | | } |
175 | | |
176 | | impl Default for VariancePopulation { |
177 | 0 | fn default() -> Self { |
178 | 0 | Self::new() |
179 | 0 | } |
180 | | } |
181 | | |
182 | | impl VariancePopulation { |
183 | 0 | pub fn new() -> Self { |
184 | 0 | Self { |
185 | 0 | aliases: vec![String::from("var_population")], |
186 | 0 | signature: Signature::numeric(1, Volatility::Immutable), |
187 | 0 | } |
188 | 0 | } |
189 | | } |
190 | | |
191 | | impl AggregateUDFImpl for VariancePopulation { |
192 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
193 | 0 | self |
194 | 0 | } |
195 | | |
196 | 0 | fn name(&self) -> &str { |
197 | 0 | "var_pop" |
198 | 0 | } |
199 | | |
200 | 0 | fn signature(&self) -> &Signature { |
201 | 0 | &self.signature |
202 | 0 | } |
203 | | |
204 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
205 | 0 | if !arg_types[0].is_numeric() { |
206 | 0 | return plan_err!("Variance requires numeric input types"); |
207 | 0 | } |
208 | 0 |
|
209 | 0 | Ok(DataType::Float64) |
210 | 0 | } |
211 | | |
212 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
213 | 0 | let name = args.name; |
214 | 0 | Ok(vec![ |
215 | 0 | Field::new(format_state_name(name, "count"), DataType::UInt64, true), |
216 | 0 | Field::new(format_state_name(name, "mean"), DataType::Float64, true), |
217 | 0 | Field::new(format_state_name(name, "m2"), DataType::Float64, true), |
218 | 0 | ]) |
219 | 0 | } |
220 | | |
221 | 0 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
222 | 0 | if acc_args.is_distinct { |
223 | 0 | return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available"); |
224 | 0 | } |
225 | 0 |
|
226 | 0 | Ok(Box::new(VarianceAccumulator::try_new( |
227 | 0 | StatsType::Population, |
228 | 0 | )?)) |
229 | 0 | } |
230 | | |
231 | 0 | fn aliases(&self) -> &[String] { |
232 | 0 | &self.aliases |
233 | 0 | } |
234 | | |
235 | 0 | fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { |
236 | 0 | !acc_args.is_distinct |
237 | 0 | } |
238 | | |
239 | 0 | fn create_groups_accumulator( |
240 | 0 | &self, |
241 | 0 | _args: AccumulatorArgs, |
242 | 0 | ) -> Result<Box<dyn GroupsAccumulator>> { |
243 | 0 | Ok(Box::new(VarianceGroupsAccumulator::new( |
244 | 0 | StatsType::Population, |
245 | 0 | ))) |
246 | 0 | } |
247 | 0 | fn documentation(&self) -> Option<&Documentation> { |
248 | 0 | Some(get_variance_population_doc()) |
249 | 0 | } |
250 | | } |
251 | | |
252 | | static VARIANCE_POPULATION_DOC: OnceLock<Documentation> = OnceLock::new(); |
253 | | |
254 | 0 | fn get_variance_population_doc() -> &'static Documentation { |
255 | 0 | VARIANCE_POPULATION_DOC.get_or_init(|| { |
256 | 0 | Documentation::builder() |
257 | 0 | .with_doc_section(DOC_SECTION_GENERAL) |
258 | 0 | .with_description( |
259 | 0 | "Returns the statistical population variance of a set of numbers.", |
260 | 0 | ) |
261 | 0 | .with_syntax_example("var_pop(expression)") |
262 | 0 | .with_standard_argument("expression", "Numeric") |
263 | 0 | .build() |
264 | 0 | .unwrap() |
265 | 0 | }) |
266 | 0 | } |
267 | | |
/// An accumulator to compute variance.
///
/// The algorithm used is an online implementation and numerically stable.
/// It is based on this paper:
/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of
/// squares and products". Technometrics. 4 (3): 419–420. doi:10.2307/1266577.
/// JSTOR 1266577.
///
/// The algorithm has been analyzed here:
/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing
/// Sample Means and Variances". Journal of the American Statistical
/// Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
#[derive(Debug)]
pub struct VarianceAccumulator {
    /// Running aggregate that `evaluate` divides by the (possibly adjusted)
    /// count to obtain the variance.
    m2: f64,
    /// Running mean of the accumulated values.
    mean: f64,
    /// Number of non-null values accumulated so far.
    count: u64,
    /// Whether the sample or the population variance is produced.
    stats_type: StatsType,
}
284 | | |
285 | | impl VarianceAccumulator { |
286 | | /// Creates a new `VarianceAccumulator` |
287 | 0 | pub fn try_new(s_type: StatsType) -> Result<Self> { |
288 | 0 | Ok(Self { |
289 | 0 | m2: 0_f64, |
290 | 0 | mean: 0_f64, |
291 | 0 | count: 0_u64, |
292 | 0 | stats_type: s_type, |
293 | 0 | }) |
294 | 0 | } |
295 | | |
296 | 0 | pub fn get_count(&self) -> u64 { |
297 | 0 | self.count |
298 | 0 | } |
299 | | |
300 | 0 | pub fn get_mean(&self) -> f64 { |
301 | 0 | self.mean |
302 | 0 | } |
303 | | |
304 | 0 | pub fn get_m2(&self) -> f64 { |
305 | 0 | self.m2 |
306 | 0 | } |
307 | | } |
308 | | |
/// Combines two partial variance states `(count, mean, m2)` into one.
///
/// An empty operand (count of zero) contributes nothing, so the other state
/// is returned untouched. This keeps the result free of the `0 / 0` NaN that
/// the general formula produces when both states are empty, and of the
/// `inf * 0` NaN it produces when the squared mean difference overflows while
/// one side is empty.
///
/// Returns the merged `(count, mean, m2)`.
#[inline]
fn merge(
    count: u64,
    mean: f64,
    m2: f64,
    count2: u64,
    mean2: f64,
    m22: f64,
) -> (u64, f64, f64) {
    // Nothing to fold in: keep the left-hand state exactly as it is.
    if count2 == 0 {
        return (count, mean, m2);
    }
    // Left-hand side is empty: the right-hand state is already the answer.
    if count == 0 {
        return (count2, mean2, m22);
    }

    let new_count = count + count2;
    // Weighted average of the two means.
    let new_mean =
        mean * count as f64 / new_count as f64 + mean2 * count2 as f64 / new_count as f64;
    let delta = mean - mean2;
    // Sum of both m2 values plus a correction for the distance between means.
    let new_m2 =
        m2 + m22 + delta * delta * count as f64 * count2 as f64 / new_count as f64;

    (new_count, new_mean, new_m2)
}
327 | | |
/// Folds one more `value` into a running `(count, mean, m2)` state and returns
/// the updated triple.
///
/// The mean is shifted by the new value's share of the deviation, and `m2`
/// grows by the product of the deviations from the old and the new mean.
#[inline]
fn update(count: u64, mean: f64, m2: f64, value: f64) -> (u64, f64, f64) {
    let n = count + 1;
    // Deviation from the mean before this value was included.
    let dev_before = value - mean;
    let shifted_mean = dev_before / n as f64 + mean;
    // Deviation from the mean after this value was included.
    let dev_after = value - shifted_mean;

    (n, shifted_mean, m2 + dev_before * dev_after)
}
338 | | |
339 | | impl Accumulator for VarianceAccumulator { |
340 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
341 | 0 | Ok(vec![ |
342 | 0 | ScalarValue::from(self.count), |
343 | 0 | ScalarValue::from(self.mean), |
344 | 0 | ScalarValue::from(self.m2), |
345 | 0 | ]) |
346 | 0 | } |
347 | | |
348 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
349 | 0 | let values = &cast(&values[0], &DataType::Float64)?; |
350 | 0 | let arr = downcast_value!(values, Float64Array).iter().flatten(); |
351 | | |
352 | 0 | for value in arr { |
353 | 0 | (self.count, self.mean, self.m2) = |
354 | 0 | update(self.count, self.mean, self.m2, value) |
355 | | } |
356 | | |
357 | 0 | Ok(()) |
358 | 0 | } |
359 | | |
360 | 0 | fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
361 | 0 | let values = &cast(&values[0], &DataType::Float64)?; |
362 | 0 | let arr = downcast_value!(values, Float64Array).iter().flatten(); |
363 | | |
364 | 0 | for value in arr { |
365 | 0 | let new_count = self.count - 1; |
366 | 0 | let delta1 = self.mean - value; |
367 | 0 | let new_mean = delta1 / new_count as f64 + self.mean; |
368 | 0 | let delta2 = new_mean - value; |
369 | 0 | let new_m2 = self.m2 - delta1 * delta2; |
370 | 0 |
|
371 | 0 | self.count -= 1; |
372 | 0 | self.mean = new_mean; |
373 | 0 | self.m2 = new_m2; |
374 | 0 | } |
375 | | |
376 | 0 | Ok(()) |
377 | 0 | } |
378 | | |
379 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
380 | 0 | let counts = downcast_value!(states[0], UInt64Array); |
381 | 0 | let means = downcast_value!(states[1], Float64Array); |
382 | 0 | let m2s = downcast_value!(states[2], Float64Array); |
383 | | |
384 | 0 | for i in 0..counts.len() { |
385 | 0 | let c = counts.value(i); |
386 | 0 | if c == 0_u64 { |
387 | 0 | continue; |
388 | 0 | } |
389 | 0 | (self.count, self.mean, self.m2) = merge( |
390 | 0 | self.count, |
391 | 0 | self.mean, |
392 | 0 | self.m2, |
393 | 0 | c, |
394 | 0 | means.value(i), |
395 | 0 | m2s.value(i), |
396 | 0 | ) |
397 | | } |
398 | 0 | Ok(()) |
399 | 0 | } |
400 | | |
401 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
402 | 0 | let count = match self.stats_type { |
403 | 0 | StatsType::Population => self.count, |
404 | | StatsType::Sample => { |
405 | 0 | if self.count > 0 { |
406 | 0 | self.count - 1 |
407 | | } else { |
408 | 0 | self.count |
409 | | } |
410 | | } |
411 | | }; |
412 | | |
413 | 0 | Ok(ScalarValue::Float64(match self.count { |
414 | 0 | 0 => None, |
415 | | 1 => { |
416 | 0 | if let StatsType::Population = self.stats_type { |
417 | 0 | Some(0.0) |
418 | | } else { |
419 | 0 | None |
420 | | } |
421 | | } |
422 | 0 | _ => Some(self.m2 / count as f64), |
423 | | })) |
424 | 0 | } |
425 | | |
426 | 0 | fn size(&self) -> usize { |
427 | 0 | std::mem::size_of_val(self) |
428 | 0 | } |
429 | | |
430 | 0 | fn supports_retract_batch(&self) -> bool { |
431 | 0 | true |
432 | 0 | } |
433 | | } |
434 | | |
/// Variance accumulator that tracks many groups at once.
///
/// The three vectors are resized together and indexed by group index, so
/// entry `i` of each holds the state of group `i`.
#[derive(Debug)]
pub struct VarianceGroupsAccumulator {
    /// Per-group running `m2` aggregate, divided by the count on evaluation.
    m2s: Vec<f64>,
    /// Per-group running mean.
    means: Vec<f64>,
    /// Per-group number of accumulated values.
    counts: Vec<u64>,
    /// Whether the sample or the population variance is produced.
    stats_type: StatsType,
}
442 | | |
443 | | impl VarianceGroupsAccumulator { |
444 | 0 | pub fn new(s_type: StatsType) -> Self { |
445 | 0 | Self { |
446 | 0 | m2s: Vec::new(), |
447 | 0 | means: Vec::new(), |
448 | 0 | counts: Vec::new(), |
449 | 0 | stats_type: s_type, |
450 | 0 | } |
451 | 0 | } |
452 | | |
453 | 0 | fn resize(&mut self, total_num_groups: usize) { |
454 | 0 | self.m2s.resize(total_num_groups, 0.0); |
455 | 0 | self.means.resize(total_num_groups, 0.0); |
456 | 0 | self.counts.resize(total_num_groups, 0); |
457 | 0 | } |
458 | | |
459 | 0 | fn merge<F>( |
460 | 0 | group_indices: &[usize], |
461 | 0 | counts: &UInt64Array, |
462 | 0 | means: &Float64Array, |
463 | 0 | m2s: &Float64Array, |
464 | 0 | opt_filter: Option<&BooleanArray>, |
465 | 0 | mut value_fn: F, |
466 | 0 | ) where |
467 | 0 | F: FnMut(usize, u64, f64, f64) + Send, |
468 | 0 | { |
469 | 0 | assert_eq!(counts.null_count(), 0); |
470 | 0 | assert_eq!(means.null_count(), 0); |
471 | 0 | assert_eq!(m2s.null_count(), 0); |
472 | | |
473 | 0 | match opt_filter { |
474 | 0 | None => { |
475 | 0 | group_indices |
476 | 0 | .iter() |
477 | 0 | .zip(counts.values().iter()) |
478 | 0 | .zip(means.values().iter()) |
479 | 0 | .zip(m2s.values().iter()) |
480 | 0 | .for_each(|(((&group_index, &count), &mean), &m2)| { |
481 | 0 | value_fn(group_index, count, mean, m2); |
482 | 0 | }); |
483 | 0 | } |
484 | 0 | Some(filter) => { |
485 | 0 | group_indices |
486 | 0 | .iter() |
487 | 0 | .zip(counts.values().iter()) |
488 | 0 | .zip(means.values().iter()) |
489 | 0 | .zip(m2s.values().iter()) |
490 | 0 | .zip(filter.iter()) |
491 | 0 | .for_each( |
492 | 0 | |((((&group_index, &count), &mean), &m2), filter_value)| { |
493 | 0 | if let Some(true) = filter_value { |
494 | 0 | value_fn(group_index, count, mean, m2); |
495 | 0 | } |
496 | 0 | }, |
497 | 0 | ); |
498 | 0 | } |
499 | | } |
500 | 0 | } |
501 | | |
502 | 0 | pub fn variance( |
503 | 0 | &mut self, |
504 | 0 | emit_to: datafusion_expr::EmitTo, |
505 | 0 | ) -> (Vec<f64>, NullBuffer) { |
506 | 0 | let mut counts = emit_to.take_needed(&mut self.counts); |
507 | 0 | // means are only needed for updating m2s and are not needed for the final result. |
508 | 0 | // But we still need to take them to ensure the internal state is consistent. |
509 | 0 | let _ = emit_to.take_needed(&mut self.means); |
510 | 0 | let m2s = emit_to.take_needed(&mut self.m2s); |
511 | 0 |
|
512 | 0 | if let StatsType::Sample = self.stats_type { |
513 | 0 | counts.iter_mut().for_each(|count| { |
514 | 0 | *count = count.saturating_sub(1); |
515 | 0 | }); |
516 | 0 | } |
517 | 0 | let nulls = NullBuffer::from_iter(counts.iter().map(|&count| count != 0)); |
518 | 0 | let variance = m2s |
519 | 0 | .iter() |
520 | 0 | .zip(counts) |
521 | 0 | .map(|(m2, count)| m2 / count as f64) |
522 | 0 | .collect(); |
523 | 0 | (variance, nulls) |
524 | 0 | } |
525 | | } |
526 | | |
527 | | impl GroupsAccumulator for VarianceGroupsAccumulator { |
528 | 0 | fn update_batch( |
529 | 0 | &mut self, |
530 | 0 | values: &[ArrayRef], |
531 | 0 | group_indices: &[usize], |
532 | 0 | opt_filter: Option<&arrow::array::BooleanArray>, |
533 | 0 | total_num_groups: usize, |
534 | 0 | ) -> Result<()> { |
535 | 0 | assert_eq!(values.len(), 1, "single argument to update_batch"); |
536 | 0 | let values = &cast(&values[0], &DataType::Float64)?; |
537 | 0 | let values = downcast_value!(values, Float64Array); |
538 | | |
539 | 0 | self.resize(total_num_groups); |
540 | 0 | accumulate(group_indices, values, opt_filter, |group_index, value| { |
541 | 0 | let (new_count, new_mean, new_m2) = update( |
542 | 0 | self.counts[group_index], |
543 | 0 | self.means[group_index], |
544 | 0 | self.m2s[group_index], |
545 | 0 | value, |
546 | 0 | ); |
547 | 0 | self.counts[group_index] = new_count; |
548 | 0 | self.means[group_index] = new_mean; |
549 | 0 | self.m2s[group_index] = new_m2; |
550 | 0 | }); |
551 | 0 | Ok(()) |
552 | 0 | } |
553 | | |
554 | 0 | fn merge_batch( |
555 | 0 | &mut self, |
556 | 0 | values: &[ArrayRef], |
557 | 0 | group_indices: &[usize], |
558 | 0 | opt_filter: Option<&arrow::array::BooleanArray>, |
559 | 0 | total_num_groups: usize, |
560 | 0 | ) -> Result<()> { |
561 | 0 | assert_eq!(values.len(), 3, "two arguments to merge_batch"); |
562 | | // first batch is counts, second is partial means, third is partial m2s |
563 | 0 | let partial_counts = downcast_value!(values[0], UInt64Array); |
564 | 0 | let partial_means = downcast_value!(values[1], Float64Array); |
565 | 0 | let partial_m2s = downcast_value!(values[2], Float64Array); |
566 | | |
567 | 0 | self.resize(total_num_groups); |
568 | 0 | Self::merge( |
569 | 0 | group_indices, |
570 | 0 | partial_counts, |
571 | 0 | partial_means, |
572 | 0 | partial_m2s, |
573 | 0 | opt_filter, |
574 | 0 | |group_index, partial_count, partial_mean, partial_m2| { |
575 | 0 | let (new_count, new_mean, new_m2) = merge( |
576 | 0 | self.counts[group_index], |
577 | 0 | self.means[group_index], |
578 | 0 | self.m2s[group_index], |
579 | 0 | partial_count, |
580 | 0 | partial_mean, |
581 | 0 | partial_m2, |
582 | 0 | ); |
583 | 0 | self.counts[group_index] = new_count; |
584 | 0 | self.means[group_index] = new_mean; |
585 | 0 | self.m2s[group_index] = new_m2; |
586 | 0 | }, |
587 | 0 | ); |
588 | 0 | Ok(()) |
589 | 0 | } |
590 | | |
591 | 0 | fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> { |
592 | 0 | let (variances, nulls) = self.variance(emit_to); |
593 | 0 | Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls)))) |
594 | 0 | } |
595 | | |
596 | 0 | fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> { |
597 | 0 | let counts = emit_to.take_needed(&mut self.counts); |
598 | 0 | let means = emit_to.take_needed(&mut self.means); |
599 | 0 | let m2s = emit_to.take_needed(&mut self.m2s); |
600 | 0 |
|
601 | 0 | Ok(vec![ |
602 | 0 | Arc::new(UInt64Array::new(counts.into(), None)), |
603 | 0 | Arc::new(Float64Array::new(means.into(), None)), |
604 | 0 | Arc::new(Float64Array::new(m2s.into(), None)), |
605 | 0 | ]) |
606 | 0 | } |
607 | | |
608 | 0 | fn size(&self) -> usize { |
609 | 0 | self.m2s.capacity() * std::mem::size_of::<f64>() |
610 | 0 | + self.means.capacity() * std::mem::size_of::<f64>() |
611 | 0 | + self.counts.capacity() * std::mem::size_of::<u64>() |
612 | 0 | } |
613 | | } |