/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/average.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines `Avg` & `Mean` aggregate & accumulators |
19 | | |
20 | | use arrow::array::{ |
21 | | self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, |
22 | | AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, |
23 | | }; |
24 | | |
25 | | use arrow::compute::sum; |
26 | | use arrow::datatypes::{ |
27 | | i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, |
28 | | Float64Type, UInt64Type, |
29 | | }; |
30 | | use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; |
31 | | use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; |
32 | | use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type}; |
33 | | use datafusion_expr::utils::format_state_name; |
34 | | use datafusion_expr::Volatility::Immutable; |
35 | | use datafusion_expr::{ |
36 | | Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, |
37 | | }; |
38 | | |
39 | | use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; |
40 | | use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ |
41 | | filtered_null_mask, set_nulls, |
42 | | }; |
43 | | |
44 | | use datafusion_functions_aggregate_common::utils::DecimalAverager; |
45 | | use log::debug; |
46 | | use std::any::Any; |
47 | | use std::fmt::Debug; |
48 | | use std::sync::Arc; |
49 | | |
50 | | make_udaf_expr_and_func!( |
51 | | Avg, |
52 | | avg, |
53 | | expression, |
54 | | "Returns the avg of a group of values.", |
55 | | avg_udaf |
56 | | ); |
57 | | |
58 | | #[derive(Debug)] |
59 | | pub struct Avg { |
60 | | signature: Signature, |
61 | | aliases: Vec<String>, |
62 | | } |
63 | | |
64 | | impl Avg { |
65 | 1 | pub fn new() -> Self { |
66 | 1 | Self { |
67 | 1 | signature: Signature::user_defined(Immutable), |
68 | 1 | aliases: vec![String::from("mean")], |
69 | 1 | } |
70 | 1 | } |
71 | | } |
72 | | |
73 | | impl Default for Avg { |
74 | 1 | fn default() -> Self { |
75 | 1 | Self::new() |
76 | 1 | } |
77 | | } |
78 | | |
79 | | impl AggregateUDFImpl for Avg { |
80 | 0 | fn as_any(&self) -> &dyn Any { |
81 | 0 | self |
82 | 0 | } |
83 | | |
84 | 14 | fn name(&self) -> &str { |
85 | 14 | "avg" |
86 | 14 | } |
87 | | |
88 | 7 | fn signature(&self) -> &Signature { |
89 | 7 | &self.signature |
90 | 7 | } |
91 | | |
92 | 7 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
93 | 7 | avg_return_type(self.name(), &arg_types[0]) |
94 | 7 | } |
95 | | |
96 | 1 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
97 | 1 | if acc_args.is_distinct { |
98 | 0 | return exec_err!("avg(DISTINCT) aggregations are not available"); |
99 | 1 | } |
100 | | use DataType::*; |
101 | | |
102 | 1 | let data_type = acc_args.exprs[0].data_type(acc_args.schema)?0 ; |
103 | | // instantiate specialized accumulator based for the type |
104 | 1 | match (&data_type, acc_args.return_type) { |
105 | 1 | (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()), |
106 | | ( |
107 | 0 | Decimal128(sum_precision, sum_scale), |
108 | 0 | Decimal128(target_precision, target_scale), |
109 | 0 | ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> { |
110 | 0 | sum: None, |
111 | 0 | count: 0, |
112 | 0 | sum_scale: *sum_scale, |
113 | 0 | sum_precision: *sum_precision, |
114 | 0 | target_precision: *target_precision, |
115 | 0 | target_scale: *target_scale, |
116 | 0 | })), |
117 | | |
118 | | ( |
119 | 0 | Decimal256(sum_precision, sum_scale), |
120 | 0 | Decimal256(target_precision, target_scale), |
121 | 0 | ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> { |
122 | 0 | sum: None, |
123 | 0 | count: 0, |
124 | 0 | sum_scale: *sum_scale, |
125 | 0 | sum_precision: *sum_precision, |
126 | 0 | target_precision: *target_precision, |
127 | 0 | target_scale: *target_scale, |
128 | 0 | })), |
129 | 0 | _ => exec_err!( |
130 | 0 | "AvgAccumulator for ({} --> {})", |
131 | 0 | &data_type, |
132 | 0 | acc_args.return_type |
133 | 0 | ), |
134 | | } |
135 | 1 | } |
136 | | |
137 | 25 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
138 | 25 | Ok(vec![ |
139 | 25 | Field::new( |
140 | 25 | format_state_name(args.name, "count"), |
141 | 25 | DataType::UInt64, |
142 | 25 | true, |
143 | 25 | ), |
144 | 25 | Field::new( |
145 | 25 | format_state_name(args.name, "sum"), |
146 | 25 | args.input_types[0].clone(), |
147 | 25 | true, |
148 | 25 | ), |
149 | 25 | ]) |
150 | 25 | } |
151 | | |
152 | 14 | fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { |
153 | 0 | matches!( |
154 | 14 | args.return_type, |
155 | | DataType::Float64 | DataType::Decimal128(_, _) |
156 | | ) |
157 | 14 | } |
158 | | |
159 | 14 | fn create_groups_accumulator( |
160 | 14 | &self, |
161 | 14 | args: AccumulatorArgs, |
162 | 14 | ) -> Result<Box<dyn GroupsAccumulator>> { |
163 | | use DataType::*; |
164 | | |
165 | 14 | let data_type = args.exprs[0].data_type(args.schema)?0 ; |
166 | | // instantiate specialized accumulator based for the type |
167 | 14 | match (&data_type, args.return_type) { |
168 | | (Float64, Float64) => { |
169 | 14 | Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new( |
170 | 14 | &data_type, |
171 | 14 | args.return_type, |
172 | 14 | |sum: f64, count: u64| Ok(sum / count as f64)12 , |
173 | 14 | ))) |
174 | | } |
175 | | ( |
176 | 0 | Decimal128(_sum_precision, sum_scale), |
177 | 0 | Decimal128(target_precision, target_scale), |
178 | | ) => { |
179 | 0 | let decimal_averager = DecimalAverager::<Decimal128Type>::try_new( |
180 | 0 | *sum_scale, |
181 | 0 | *target_precision, |
182 | 0 | *target_scale, |
183 | 0 | )?; |
184 | | |
185 | 0 | let avg_fn = |
186 | 0 | move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); |
187 | | |
188 | 0 | Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new( |
189 | 0 | &data_type, |
190 | 0 | args.return_type, |
191 | 0 | avg_fn, |
192 | 0 | ))) |
193 | | } |
194 | | |
195 | | ( |
196 | 0 | Decimal256(_sum_precision, sum_scale), |
197 | 0 | Decimal256(target_precision, target_scale), |
198 | | ) => { |
199 | 0 | let decimal_averager = DecimalAverager::<Decimal256Type>::try_new( |
200 | 0 | *sum_scale, |
201 | 0 | *target_precision, |
202 | 0 | *target_scale, |
203 | 0 | )?; |
204 | | |
205 | 0 | let avg_fn = move |sum: i256, count: u64| { |
206 | 0 | decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap()) |
207 | 0 | }; |
208 | | |
209 | 0 | Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new( |
210 | 0 | &data_type, |
211 | 0 | args.return_type, |
212 | 0 | avg_fn, |
213 | 0 | ))) |
214 | | } |
215 | | |
216 | 0 | _ => not_impl_err!( |
217 | 0 | "AvgGroupsAccumulator for ({} --> {})", |
218 | 0 | &data_type, |
219 | 0 | args.return_type |
220 | 0 | ), |
221 | | } |
222 | 14 | } |
223 | | |
224 | 0 | fn aliases(&self) -> &[String] { |
225 | 0 | &self.aliases |
226 | 0 | } |
227 | | |
228 | 0 | fn reverse_expr(&self) -> ReversedUDAF { |
229 | 0 | ReversedUDAF::Identical |
230 | 0 | } |
231 | | |
232 | 0 | fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
233 | 0 | if arg_types.len() != 1 { |
234 | 0 | return exec_err!("{} expects exactly one argument.", self.name()); |
235 | 0 | } |
236 | 0 | coerce_avg_type(self.name(), arg_types) |
237 | 0 | } |
238 | | } |
239 | | |
240 | | /// An accumulator to compute the average |
241 | | #[derive(Debug, Default)] |
242 | | pub struct AvgAccumulator { |
243 | | sum: Option<f64>, |
244 | | count: u64, |
245 | | } |
246 | | |
247 | | impl Accumulator for AvgAccumulator { |
248 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
249 | 0 | let values = values[0].as_primitive::<Float64Type>(); |
250 | 0 | self.count += (values.len() - values.null_count()) as u64; |
251 | 0 | if let Some(x) = sum(values) { |
252 | 0 | let v = self.sum.get_or_insert(0.); |
253 | 0 | *v += x; |
254 | 0 | } |
255 | 0 | Ok(()) |
256 | 0 | } |
257 | | |
258 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
259 | 0 | Ok(ScalarValue::Float64( |
260 | 0 | self.sum.map(|f| f / self.count as f64), |
261 | 0 | )) |
262 | 0 | } |
263 | | |
264 | 0 | fn size(&self) -> usize { |
265 | 0 | std::mem::size_of_val(self) |
266 | 0 | } |
267 | | |
268 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
269 | 0 | Ok(vec![ |
270 | 0 | ScalarValue::from(self.count), |
271 | 0 | ScalarValue::Float64(self.sum), |
272 | 0 | ]) |
273 | 0 | } |
274 | | |
275 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
276 | 0 | // counts are summed |
277 | 0 | self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default(); |
278 | | |
279 | | // sums are summed |
280 | 0 | if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) { |
281 | 0 | let v = self.sum.get_or_insert(0.); |
282 | 0 | *v += x; |
283 | 0 | } |
284 | 0 | Ok(()) |
285 | 0 | } |
286 | 0 | fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
287 | 0 | let values = values[0].as_primitive::<Float64Type>(); |
288 | 0 | self.count -= (values.len() - values.null_count()) as u64; |
289 | 0 | if let Some(x) = sum(values) { |
290 | 0 | self.sum = Some(self.sum.unwrap() - x); |
291 | 0 | } |
292 | 0 | Ok(()) |
293 | 0 | } |
294 | | |
295 | 0 | fn supports_retract_batch(&self) -> bool { |
296 | 0 | true |
297 | 0 | } |
298 | | } |
299 | | |
300 | | /// An accumulator to compute the average for decimals |
301 | | #[derive(Debug)] |
302 | | struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> { |
303 | | sum: Option<T::Native>, |
304 | | count: u64, |
305 | | sum_scale: i8, |
306 | | sum_precision: u8, |
307 | | target_precision: u8, |
308 | | target_scale: i8, |
309 | | } |
310 | | |
311 | | impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> { |
312 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
313 | 0 | let values = values[0].as_primitive::<T>(); |
314 | 0 | self.count += (values.len() - values.null_count()) as u64; |
315 | | |
316 | 0 | if let Some(x) = sum(values) { |
317 | 0 | let v = self.sum.get_or_insert(T::Native::default()); |
318 | 0 | self.sum = Some(v.add_wrapping(x)); |
319 | 0 | } |
320 | 0 | Ok(()) |
321 | 0 | } |
322 | | |
323 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
324 | 0 | let v = self |
325 | 0 | .sum |
326 | 0 | .map(|v| { |
327 | 0 | DecimalAverager::<T>::try_new( |
328 | 0 | self.sum_scale, |
329 | 0 | self.target_precision, |
330 | 0 | self.target_scale, |
331 | 0 | )? |
332 | 0 | .avg(v, T::Native::from_usize(self.count as usize).unwrap()) |
333 | 0 | }) |
334 | 0 | .transpose()?; |
335 | | |
336 | 0 | ScalarValue::new_primitive::<T>( |
337 | 0 | v, |
338 | 0 | &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale), |
339 | 0 | ) |
340 | 0 | } |
341 | | |
342 | 0 | fn size(&self) -> usize { |
343 | 0 | std::mem::size_of_val(self) |
344 | 0 | } |
345 | | |
346 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
347 | 0 | Ok(vec![ |
348 | 0 | ScalarValue::from(self.count), |
349 | 0 | ScalarValue::new_primitive::<T>( |
350 | 0 | self.sum, |
351 | 0 | &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale), |
352 | 0 | )?, |
353 | | ]) |
354 | 0 | } |
355 | | |
356 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
357 | 0 | // counts are summed |
358 | 0 | self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default(); |
359 | | |
360 | | // sums are summed |
361 | 0 | if let Some(x) = sum(states[1].as_primitive::<T>()) { |
362 | 0 | let v = self.sum.get_or_insert(T::Native::default()); |
363 | 0 | self.sum = Some(v.add_wrapping(x)); |
364 | 0 | } |
365 | 0 | Ok(()) |
366 | 0 | } |
367 | 0 | fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
368 | 0 | let values = values[0].as_primitive::<T>(); |
369 | 0 | self.count -= (values.len() - values.null_count()) as u64; |
370 | 0 | if let Some(x) = sum(values) { |
371 | 0 | self.sum = Some(self.sum.unwrap().sub_wrapping(x)); |
372 | 0 | } |
373 | 0 | Ok(()) |
374 | 0 | } |
375 | | |
376 | 0 | fn supports_retract_batch(&self) -> bool { |
377 | 0 | true |
378 | 0 | } |
379 | | } |
380 | | |
381 | | /// An accumulator to compute the average of `[PrimitiveArray<T>]`. |
382 | | /// Stores values as native types, and does overflow checking |
383 | | /// |
384 | | /// F: Function that calculates the average value from a sum of |
385 | | /// T::Native and a total count |
386 | | #[derive(Debug)] |
387 | | struct AvgGroupsAccumulator<T, F> |
388 | | where |
389 | | T: ArrowNumericType + Send, |
390 | | F: Fn(T::Native, u64) -> Result<T::Native> + Send, |
391 | | { |
392 | | /// The type of the internal sum |
393 | | sum_data_type: DataType, |
394 | | |
395 | | /// The type of the returned sum |
396 | | return_data_type: DataType, |
397 | | |
398 | | /// Count per group (use u64 to make UInt64Array) |
399 | | counts: Vec<u64>, |
400 | | |
401 | | /// Sums per group, stored as the native type |
402 | | sums: Vec<T::Native>, |
403 | | |
404 | | /// Track nulls in the input / filters |
405 | | null_state: NullState, |
406 | | |
407 | | /// Function that computes the final average (value / count) |
408 | | avg_fn: F, |
409 | | } |
410 | | |
411 | | impl<T, F> AvgGroupsAccumulator<T, F> |
412 | | where |
413 | | T: ArrowNumericType + Send, |
414 | | F: Fn(T::Native, u64) -> Result<T::Native> + Send, |
415 | | { |
416 | 14 | pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { |
417 | 14 | debug!( |
418 | 0 | "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}", |
419 | 0 | std::any::type_name::<T>() |
420 | | ); |
421 | | |
422 | 14 | Self { |
423 | 14 | return_data_type: return_data_type.clone(), |
424 | 14 | sum_data_type: sum_data_type.clone(), |
425 | 14 | counts: vec![], |
426 | 14 | sums: vec![], |
427 | 14 | null_state: NullState::new(), |
428 | 14 | avg_fn, |
429 | 14 | } |
430 | 14 | } |
431 | | } |
432 | | |
433 | | impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F> |
434 | | where |
435 | | T: ArrowNumericType + Send, |
436 | | F: Fn(T::Native, u64) -> Result<T::Native> + Send, |
437 | | { |
438 | 17 | fn update_batch( |
439 | 17 | &mut self, |
440 | 17 | values: &[ArrayRef], |
441 | 17 | group_indices: &[usize], |
442 | 17 | opt_filter: Option<&array::BooleanArray>, |
443 | 17 | total_num_groups: usize, |
444 | 17 | ) -> Result<()> { |
445 | 17 | assert_eq!(values.len(), 1, "single argument to update_batch"0 ); |
446 | 17 | let values = values[0].as_primitive::<T>(); |
447 | 17 | |
448 | 17 | // increment counts, update sums |
449 | 17 | self.counts.resize(total_num_groups, 0); |
450 | 17 | self.sums.resize(total_num_groups, T::default_value()); |
451 | 17 | self.null_state.accumulate( |
452 | 17 | group_indices, |
453 | 17 | values, |
454 | 17 | opt_filter, |
455 | 17 | total_num_groups, |
456 | 68 | |group_index, new_value| { |
457 | 68 | let sum = &mut self.sums[group_index]; |
458 | 68 | *sum = sum.add_wrapping(new_value); |
459 | 68 | |
460 | 68 | self.counts[group_index] += 1; |
461 | 68 | }, |
462 | 17 | ); |
463 | 17 | |
464 | 17 | Ok(()) |
465 | 17 | } |
466 | | |
467 | 8 | fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> { |
468 | 8 | let counts = emit_to.take_needed(&mut self.counts); |
469 | 8 | let sums = emit_to.take_needed(&mut self.sums); |
470 | 8 | let nulls = self.null_state.build(emit_to); |
471 | 8 | |
472 | 8 | assert_eq!(nulls.len(), sums.len()); |
473 | 8 | assert_eq!(counts.len(), sums.len()); |
474 | | |
475 | | // don't evaluate averages with null inputs to avoid errors on null values |
476 | | |
477 | 8 | let array: PrimitiveArray<T> = if nulls.null_count() > 0 { |
478 | 0 | let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len()) |
479 | 0 | .with_data_type(self.return_data_type.clone()); |
480 | 0 | let iter = sums.into_iter().zip(counts).zip(nulls.iter()); |
481 | | |
482 | 0 | for ((sum, count), is_valid) in iter { |
483 | 0 | if is_valid { |
484 | 0 | builder.append_value((self.avg_fn)(sum, count)?) |
485 | 0 | } else { |
486 | 0 | builder.append_null(); |
487 | 0 | } |
488 | | } |
489 | 0 | builder.finish() |
490 | | } else { |
491 | 8 | let averages: Vec<T::Native> = sums |
492 | 8 | .into_iter() |
493 | 8 | .zip(counts.into_iter()) |
494 | 12 | .map(|(sum, count)| (self.avg_fn)(sum, count)) |
495 | 8 | .collect::<Result<Vec<_>>>()?0 ; |
496 | 8 | PrimitiveArray::new(averages.into(), Some(nulls)) // no copy |
497 | 8 | .with_data_type(self.return_data_type.clone()) |
498 | | }; |
499 | | |
500 | 8 | Ok(Arc::new(array)) |
501 | 8 | } |
502 | | |
503 | | // return arrays for sums and counts |
504 | 20 | fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { |
505 | 20 | let nulls = self.null_state.build(emit_to); |
506 | 20 | let nulls = Some(nulls); |
507 | 20 | |
508 | 20 | let counts = emit_to.take_needed(&mut self.counts); |
509 | 20 | let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy |
510 | 20 | |
511 | 20 | let sums = emit_to.take_needed(&mut self.sums); |
512 | 20 | let sums = PrimitiveArray::<T>::new(sums.into(), nulls) // zero copy |
513 | 20 | .with_data_type(self.sum_data_type.clone()); |
514 | 20 | |
515 | 20 | Ok(vec![ |
516 | 20 | Arc::new(counts) as ArrayRef, |
517 | 20 | Arc::new(sums) as ArrayRef, |
518 | 20 | ]) |
519 | 20 | } |
520 | | |
521 | 14 | fn merge_batch( |
522 | 14 | &mut self, |
523 | 14 | values: &[ArrayRef], |
524 | 14 | group_indices: &[usize], |
525 | 14 | opt_filter: Option<&array::BooleanArray>, |
526 | 14 | total_num_groups: usize, |
527 | 14 | ) -> Result<()> { |
528 | 14 | assert_eq!(values.len(), 2, "two arguments to merge_batch"0 ); |
529 | | // first batch is counts, second is partial sums |
530 | 14 | let partial_counts = values[0].as_primitive::<UInt64Type>(); |
531 | 14 | let partial_sums = values[1].as_primitive::<T>(); |
532 | 14 | // update counts with partial counts |
533 | 14 | self.counts.resize(total_num_groups, 0); |
534 | 14 | self.null_state.accumulate( |
535 | 14 | group_indices, |
536 | 14 | partial_counts, |
537 | 14 | opt_filter, |
538 | 14 | total_num_groups, |
539 | 26 | |group_index, partial_count| { |
540 | 26 | self.counts[group_index] += partial_count; |
541 | 26 | }, |
542 | 14 | ); |
543 | 14 | |
544 | 14 | // update sums |
545 | 14 | self.sums.resize(total_num_groups, T::default_value()); |
546 | 14 | self.null_state.accumulate( |
547 | 14 | group_indices, |
548 | 14 | partial_sums, |
549 | 14 | opt_filter, |
550 | 14 | total_num_groups, |
551 | 26 | |group_index, new_value: <T as ArrowPrimitiveType>::Native| { |
552 | 26 | let sum = &mut self.sums[group_index]; |
553 | 26 | *sum = sum.add_wrapping(new_value); |
554 | 26 | }, |
555 | 14 | ); |
556 | 14 | |
557 | 14 | Ok(()) |
558 | 14 | } |
559 | | |
560 | 0 | fn convert_to_state( |
561 | 0 | &self, |
562 | 0 | values: &[ArrayRef], |
563 | 0 | opt_filter: Option<&BooleanArray>, |
564 | 0 | ) -> Result<Vec<ArrayRef>> { |
565 | 0 | let sums = values[0] |
566 | 0 | .as_primitive::<T>() |
567 | 0 | .clone() |
568 | 0 | .with_data_type(self.sum_data_type.clone()); |
569 | 0 | let counts = UInt64Array::from_value(1, sums.len()); |
570 | 0 |
|
571 | 0 | let nulls = filtered_null_mask(opt_filter, &sums); |
572 | 0 |
|
573 | 0 | // set nulls on the arrays |
574 | 0 | let counts = set_nulls(counts, nulls.clone()); |
575 | 0 | let sums = set_nulls(sums, nulls); |
576 | 0 |
|
577 | 0 | Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)]) |
578 | 0 | } |
579 | | |
580 | 10 | fn supports_convert_to_state(&self) -> bool { |
581 | 10 | true |
582 | 10 | } |
583 | | |
584 | 85 | fn size(&self) -> usize { |
585 | 85 | self.counts.capacity() * std::mem::size_of::<u64>() |
586 | 85 | + self.sums.capacity() * std::mem::size_of::<T>() |
587 | 85 | } |
588 | | } |