/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/sum.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines `SUM` and `SUM DISTINCT` aggregate accumulators |
19 | | |
20 | | use ahash::RandomState; |
21 | | use datafusion_expr::utils::AggregateOrderSensitivity; |
22 | | use std::any::Any; |
23 | | use std::collections::HashSet; |
24 | | |
25 | | use arrow::array::Array; |
26 | | use arrow::array::ArrowNativeTypeOp; |
27 | | use arrow::array::{ArrowNumericType, AsArray}; |
28 | | use arrow::datatypes::ArrowNativeType; |
29 | | use arrow::datatypes::ArrowPrimitiveType; |
30 | | use arrow::datatypes::{ |
31 | | DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, |
32 | | DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, |
33 | | }; |
34 | | use arrow::{array::ArrayRef, datatypes::Field}; |
35 | | use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; |
36 | | use datafusion_expr::function::AccumulatorArgs; |
37 | | use datafusion_expr::function::StateFieldsArgs; |
38 | | use datafusion_expr::utils::format_state_name; |
39 | | use datafusion_expr::{ |
40 | | Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, |
41 | | }; |
42 | | use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; |
43 | | use datafusion_functions_aggregate_common::utils::Hashable; |
44 | | |
45 | | make_udaf_expr_and_func!( |
46 | | Sum, |
47 | | sum, |
48 | | expression, |
49 | | "Returns the sum of a group of values.", |
50 | | sum_udaf |
51 | | ); |
52 | | |
53 | | /// Sum only supports a subset of numeric types, instead relying on type coercion |
54 | | /// |
55 | | /// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive) |
56 | | /// |
57 | | /// `args` is [AccumulatorArgs] |
58 | | /// `helper` is a macro accepting (ArrowPrimitiveType, DataType) |
59 | | macro_rules! downcast_sum { |
60 | | ($args:ident, $helper:ident) => { |
61 | | match $args.return_type { |
62 | | DataType::UInt64 => $helper!(UInt64Type, $args.return_type), |
63 | | DataType::Int64 => $helper!(Int64Type, $args.return_type), |
64 | | DataType::Float64 => $helper!(Float64Type, $args.return_type), |
65 | | DataType::Decimal128(_, _) => $helper!(Decimal128Type, $args.return_type), |
66 | | DataType::Decimal256(_, _) => $helper!(Decimal256Type, $args.return_type), |
67 | | _ => { |
68 | | not_impl_err!( |
69 | | "Sum not supported for {}: {}", |
70 | | $args.name, |
71 | | $args.return_type |
72 | | ) |
73 | | } |
74 | | } |
75 | | }; |
76 | | } |
77 | | |
78 | | #[derive(Debug)] |
79 | | pub struct Sum { |
80 | | signature: Signature, |
81 | | } |
82 | | |
83 | | impl Sum { |
84 | 1 | pub fn new() -> Self { |
85 | 1 | Self { |
86 | 1 | signature: Signature::user_defined(Volatility::Immutable), |
87 | 1 | } |
88 | 1 | } |
89 | | } |
90 | | |
91 | | impl Default for Sum { |
92 | 1 | fn default() -> Self { |
93 | 1 | Self::new() |
94 | 1 | } |
95 | | } |
96 | | |
97 | | impl AggregateUDFImpl for Sum { |
98 | 0 | fn as_any(&self) -> &dyn Any { |
99 | 0 | self |
100 | 0 | } |
101 | | |
102 | 1 | fn name(&self) -> &str { |
103 | 1 | "sum" |
104 | 1 | } |
105 | | |
106 | 1 | fn signature(&self) -> &Signature { |
107 | 1 | &self.signature |
108 | 1 | } |
109 | | |
110 | 0 | fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
111 | 0 | if arg_types.len() != 1 { |
112 | 0 | return exec_err!("SUM expects exactly one argument"); |
113 | 0 | } |
114 | | |
115 | | // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc |
116 | | // smallint, int, bigint, real, double precision, decimal, or interval. |
117 | | |
118 | 0 | fn coerced_type(data_type: &DataType) -> Result<DataType> { |
119 | 0 | match data_type { |
120 | 0 | DataType::Dictionary(_, v) => coerced_type(v), |
121 | | // in the spark, the result type is DECIMAL(min(38,precision+10), s) |
122 | | // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 |
123 | | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { |
124 | 0 | Ok(data_type.clone()) |
125 | | } |
126 | 0 | dt if dt.is_signed_integer() => Ok(DataType::Int64), |
127 | 0 | dt if dt.is_unsigned_integer() => Ok(DataType::UInt64), |
128 | 0 | dt if dt.is_floating() => Ok(DataType::Float64), |
129 | 0 | _ => exec_err!("Sum not supported for {}", data_type), |
130 | | } |
131 | 0 | } |
132 | | |
133 | 0 | Ok(vec![coerced_type(&arg_types[0])?]) |
134 | 0 | } |
135 | | |
136 | 1 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
137 | 1 | match &arg_types[0] { |
138 | 0 | DataType::Int64 => Ok(DataType::Int64), |
139 | 1 | DataType::UInt64 => Ok(DataType::UInt64), |
140 | 0 | DataType::Float64 => Ok(DataType::Float64), |
141 | 0 | DataType::Decimal128(precision, scale) => { |
142 | 0 | // in the spark, the result type is DECIMAL(min(38,precision+10), s) |
143 | 0 | // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 |
144 | 0 | let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); |
145 | 0 | Ok(DataType::Decimal128(new_precision, *scale)) |
146 | | } |
147 | 0 | DataType::Decimal256(precision, scale) => { |
148 | 0 | // in the spark, the result type is DECIMAL(min(38,precision+10), s) |
149 | 0 | // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 |
150 | 0 | let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); |
151 | 0 | Ok(DataType::Decimal256(new_precision, *scale)) |
152 | | } |
153 | 0 | other => { |
154 | 0 | exec_err!("[return_type] SUM not supported for {}", other) |
155 | | } |
156 | | } |
157 | 1 | } |
158 | | |
159 | 0 | fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
160 | 0 | if args.is_distinct { |
161 | | macro_rules! helper { |
162 | | ($t:ty, $dt:expr) => { |
163 | | Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?)) |
164 | | }; |
165 | | } |
166 | 0 | downcast_sum!(args, helper) |
167 | | } else { |
168 | | macro_rules! helper { |
169 | | ($t:ty, $dt:expr) => { |
170 | | Ok(Box::new(SumAccumulator::<$t>::new($dt.clone()))) |
171 | | }; |
172 | | } |
173 | 0 | downcast_sum!(args, helper) |
174 | | } |
175 | 0 | } |
176 | | |
177 | 2 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
178 | 2 | if args.is_distinct { |
179 | 0 | Ok(vec![Field::new_list( |
180 | 0 | format_state_name(args.name, "sum distinct"), |
181 | 0 | // See COMMENTS.md to understand why nullable is set to true |
182 | 0 | Field::new("item", args.return_type.clone(), true), |
183 | 0 | false, |
184 | 0 | )]) |
185 | | } else { |
186 | 2 | Ok(vec![Field::new( |
187 | 2 | format_state_name(args.name, "sum"), |
188 | 2 | args.return_type.clone(), |
189 | 2 | true, |
190 | 2 | )]) |
191 | | } |
192 | 2 | } |
193 | | |
194 | 0 | fn aliases(&self) -> &[String] { |
195 | 0 | &[] |
196 | 0 | } |
197 | | |
198 | 1 | fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { |
199 | 1 | !args.is_distinct |
200 | 1 | } |
201 | | |
202 | 1 | fn create_groups_accumulator( |
203 | 1 | &self, |
204 | 1 | args: AccumulatorArgs, |
205 | 1 | ) -> Result<Box<dyn GroupsAccumulator>> { |
206 | | macro_rules! helper { |
207 | | ($t:ty, $dt:expr) => { |
208 | | Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( |
209 | | &$dt, |
210 | 3 | |x, y| *x = x.add_wrapping(y), |
211 | | ))) |
212 | | }; |
213 | | } |
214 | 1 | downcast_sum!(args, helper) |
215 | 1 | } |
216 | | |
217 | 0 | fn create_sliding_accumulator( |
218 | 0 | &self, |
219 | 0 | args: AccumulatorArgs, |
220 | 0 | ) -> Result<Box<dyn Accumulator>> { |
221 | | macro_rules! helper { |
222 | | ($t:ty, $dt:expr) => { |
223 | | Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) |
224 | | }; |
225 | | } |
226 | 0 | downcast_sum!(args, helper) |
227 | 0 | } |
228 | | |
229 | 0 | fn reverse_expr(&self) -> ReversedUDAF { |
230 | 0 | ReversedUDAF::Identical |
231 | 0 | } |
232 | | |
233 | 0 | fn order_sensitivity(&self) -> AggregateOrderSensitivity { |
234 | 0 | AggregateOrderSensitivity::Insensitive |
235 | 0 | } |
236 | | } |
237 | | |
238 | | /// This accumulator computes SUM incrementally |
239 | | struct SumAccumulator<T: ArrowNumericType> { |
240 | | sum: Option<T::Native>, |
241 | | data_type: DataType, |
242 | | } |
243 | | |
244 | | impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> { |
245 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
246 | 0 | write!(f, "SumAccumulator({})", self.data_type) |
247 | 0 | } |
248 | | } |
249 | | |
250 | | impl<T: ArrowNumericType> SumAccumulator<T> { |
251 | 0 | fn new(data_type: DataType) -> Self { |
252 | 0 | Self { |
253 | 0 | sum: None, |
254 | 0 | data_type, |
255 | 0 | } |
256 | 0 | } |
257 | | } |
258 | | |
259 | | impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> { |
260 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
261 | 0 | Ok(vec![self.evaluate()?]) |
262 | 0 | } |
263 | | |
264 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
265 | 0 | let values = values[0].as_primitive::<T>(); |
266 | 0 | if let Some(x) = arrow::compute::sum(values) { |
267 | 0 | let v = self.sum.get_or_insert(T::Native::usize_as(0)); |
268 | 0 | *v = v.add_wrapping(x); |
269 | 0 | } |
270 | 0 | Ok(()) |
271 | 0 | } |
272 | | |
273 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
274 | 0 | self.update_batch(states) |
275 | 0 | } |
276 | | |
277 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
278 | 0 | ScalarValue::new_primitive::<T>(self.sum, &self.data_type) |
279 | 0 | } |
280 | | |
281 | 0 | fn size(&self) -> usize { |
282 | 0 | std::mem::size_of_val(self) |
283 | 0 | } |
284 | | } |
285 | | |
286 | | /// This accumulator incrementally computes sums over a sliding window |
287 | | /// |
288 | | /// This is separate from [`SumAccumulator`] as requires additional state |
289 | | struct SlidingSumAccumulator<T: ArrowNumericType> { |
290 | | sum: T::Native, |
291 | | count: u64, |
292 | | data_type: DataType, |
293 | | } |
294 | | |
295 | | impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> { |
296 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
297 | 0 | write!(f, "SlidingSumAccumulator({})", self.data_type) |
298 | 0 | } |
299 | | } |
300 | | |
301 | | impl<T: ArrowNumericType> SlidingSumAccumulator<T> { |
302 | 0 | fn new(data_type: DataType) -> Self { |
303 | 0 | Self { |
304 | 0 | sum: T::Native::usize_as(0), |
305 | 0 | count: 0, |
306 | 0 | data_type, |
307 | 0 | } |
308 | 0 | } |
309 | | } |
310 | | |
311 | | impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> { |
312 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
313 | 0 | Ok(vec![self.evaluate()?, self.count.into()]) |
314 | 0 | } |
315 | | |
316 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
317 | 0 | let values = values[0].as_primitive::<T>(); |
318 | 0 | self.count += (values.len() - values.null_count()) as u64; |
319 | 0 | if let Some(x) = arrow::compute::sum(values) { |
320 | 0 | self.sum = self.sum.add_wrapping(x) |
321 | 0 | } |
322 | 0 | Ok(()) |
323 | 0 | } |
324 | | |
325 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
326 | 0 | let values = states[0].as_primitive::<T>(); |
327 | 0 | if let Some(x) = arrow::compute::sum(values) { |
328 | 0 | self.sum = self.sum.add_wrapping(x) |
329 | 0 | } |
330 | 0 | if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) { |
331 | 0 | self.count += x; |
332 | 0 | } |
333 | 0 | Ok(()) |
334 | 0 | } |
335 | | |
336 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
337 | 0 | let v = (self.count != 0).then_some(self.sum); |
338 | 0 | ScalarValue::new_primitive::<T>(v, &self.data_type) |
339 | 0 | } |
340 | | |
341 | 0 | fn size(&self) -> usize { |
342 | 0 | std::mem::size_of_val(self) |
343 | 0 | } |
344 | | |
345 | 0 | fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
346 | 0 | let values = values[0].as_primitive::<T>(); |
347 | 0 | if let Some(x) = arrow::compute::sum(values) { |
348 | 0 | self.sum = self.sum.sub_wrapping(x) |
349 | 0 | } |
350 | 0 | self.count -= (values.len() - values.null_count()) as u64; |
351 | 0 | Ok(()) |
352 | 0 | } |
353 | | |
354 | 0 | fn supports_retract_batch(&self) -> bool { |
355 | 0 | true |
356 | 0 | } |
357 | | } |
358 | | |
359 | | struct DistinctSumAccumulator<T: ArrowPrimitiveType> { |
360 | | values: HashSet<Hashable<T::Native>, RandomState>, |
361 | | data_type: DataType, |
362 | | } |
363 | | |
364 | | impl<T: ArrowPrimitiveType> std::fmt::Debug for DistinctSumAccumulator<T> { |
365 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
366 | 0 | write!(f, "DistinctSumAccumulator({})", self.data_type) |
367 | 0 | } |
368 | | } |
369 | | |
370 | | impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> { |
371 | 0 | pub fn try_new(data_type: &DataType) -> Result<Self> { |
372 | 0 | Ok(Self { |
373 | 0 | values: HashSet::default(), |
374 | 0 | data_type: data_type.clone(), |
375 | 0 | }) |
376 | 0 | } |
377 | | } |
378 | | |
379 | | impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> { |
380 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
381 | | // 1. Stores aggregate state in `ScalarValue::List` |
382 | | // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set |
383 | 0 | let state_out = { |
384 | 0 | let distinct_values = self |
385 | 0 | .values |
386 | 0 | .iter() |
387 | 0 | .map(|value| { |
388 | 0 | ScalarValue::new_primitive::<T>(Some(value.0), &self.data_type) |
389 | 0 | }) |
390 | 0 | .collect::<Result<Vec<_>>>()?; |
391 | | |
392 | 0 | vec![ScalarValue::List(ScalarValue::new_list_nullable( |
393 | 0 | &distinct_values, |
394 | 0 | &self.data_type, |
395 | 0 | ))] |
396 | 0 | }; |
397 | 0 | Ok(state_out) |
398 | 0 | } |
399 | | |
400 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
401 | 0 | if values.is_empty() { |
402 | 0 | return Ok(()); |
403 | 0 | } |
404 | 0 |
|
405 | 0 | let array = values[0].as_primitive::<T>(); |
406 | 0 | match array.nulls().filter(|x| x.null_count() > 0) { |
407 | 0 | Some(n) => { |
408 | 0 | for idx in n.valid_indices() { |
409 | 0 | self.values.insert(Hashable(array.value(idx))); |
410 | 0 | } |
411 | | } |
412 | 0 | None => array.values().iter().for_each(|x| { |
413 | 0 | self.values.insert(Hashable(*x)); |
414 | 0 | }), |
415 | | } |
416 | 0 | Ok(()) |
417 | 0 | } |
418 | | |
419 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
420 | 0 | for x in states[0].as_list::<i32>().iter().flatten() { |
421 | 0 | self.update_batch(&[x])? |
422 | | } |
423 | 0 | Ok(()) |
424 | 0 | } |
425 | | |
426 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
427 | 0 | let mut acc = T::Native::usize_as(0); |
428 | 0 | for distinct_value in self.values.iter() { |
429 | 0 | acc = acc.add_wrapping(distinct_value.0) |
430 | | } |
431 | 0 | let v = (!self.values.is_empty()).then_some(acc); |
432 | 0 | ScalarValue::new_primitive::<T>(v, &self.data_type) |
433 | 0 | } |
434 | | |
435 | 0 | fn size(&self) -> usize { |
436 | 0 | std::mem::size_of_val(self) |
437 | 0 | + self.values.capacity() * std::mem::size_of::<T::Native>() |
438 | 0 | } |
439 | | } |