Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/sum.rs
Line | Count | Source
  1 |   | // Licensed to the Apache Software Foundation (ASF) under one
  2 |   | // or more contributor license agreements.  See the NOTICE file
  3 |   | // distributed with this work for additional information
  4 |   | // regarding copyright ownership.  The ASF licenses this file
  5 |   | // to you under the Apache License, Version 2.0 (the
  6 |   | // "License"); you may not use this file except in compliance
  7 |   | // with the License.  You may obtain a copy of the License at
  8 |   | //
  9 |   | //   http://www.apache.org/licenses/LICENSE-2.0
 10 |   | //
 11 |   | // Unless required by applicable law or agreed to in writing,
 12 |   | // software distributed under the License is distributed on an
 13 |   | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 14 |   | // KIND, either express or implied.  See the License for the
 15 |   | // specific language governing permissions and limitations
 16 |   | // under the License.
 17 |   |
 18 |   | //! Defines `SUM` and `SUM DISTINCT` aggregate accumulators
 19 |   |
 20 |   | use ahash::RandomState;
 21 |   | use datafusion_expr::utils::AggregateOrderSensitivity;
 22 |   | use std::any::Any;
 23 |   | use std::collections::HashSet;
 24 |   |
 25 |   | use arrow::array::Array;
 26 |   | use arrow::array::ArrowNativeTypeOp;
 27 |   | use arrow::array::{ArrowNumericType, AsArray};
 28 |   | use arrow::datatypes::ArrowNativeType;
 29 |   | use arrow::datatypes::ArrowPrimitiveType;
 30 |   | use arrow::datatypes::{
 31 |   |     DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type,
 32 |   |     DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
 33 |   | };
 34 |   | use arrow::{array::ArrayRef, datatypes::Field};
 35 |   | use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
 36 |   | use datafusion_expr::function::AccumulatorArgs;
 37 |   | use datafusion_expr::function::StateFieldsArgs;
 38 |   | use datafusion_expr::utils::format_state_name;
 39 |   | use datafusion_expr::{
 40 |   |     Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility,
 41 |   | };
 42 |   | use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
 43 |   | use datafusion_functions_aggregate_common::utils::Hashable;
 44 |   |
 45 |   | make_udaf_expr_and_func!(
 46 |   |     Sum,
 47 |   |     sum,
 48 |   |     expression,
 49 |   |     "Returns the sum of a group of values.",
 50 |   |     sum_udaf
 51 |   | );
 52 |   |
 53 |   | /// Sum only supports a subset of numeric types, instead relying on type coercion
 54 |   | ///
 55 |   | /// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive)
 56 |   | ///
 57 |   | /// `args` is [AccumulatorArgs]
 58 |   | /// `helper` is a macro accepting (ArrowPrimitiveType, DataType)
 59 |   | macro_rules! downcast_sum {
 60 |   |     ($args:ident, $helper:ident) => {
 61 |   |         match $args.return_type {
 62 |   |             DataType::UInt64 => $helper!(UInt64Type, $args.return_type),
 63 |   |             DataType::Int64 => $helper!(Int64Type, $args.return_type),
 64 |   |             DataType::Float64 => $helper!(Float64Type, $args.return_type),
 65 |   |             DataType::Decimal128(_, _) => $helper!(Decimal128Type, $args.return_type),
 66 |   |             DataType::Decimal256(_, _) => $helper!(Decimal256Type, $args.return_type),
 67 |   |             _ => {
 68 |   |                 not_impl_err!(
 69 |   |                     "Sum not supported for {}: {}",
 70 |   |                     $args.name,
 71 |   |                     $args.return_type
 72 |   |                 )
 73 |   |             }
 74 |   |         }
 75 |   |     };
 76 |   | }
 77 |   |
 78 |   | #[derive(Debug)]
 79 |   | pub struct Sum {
 80 |   |     signature: Signature,
 81 |   | }
 82 |   |
 83 |   | impl Sum {
 84 | 1 |     pub fn new() -> Self {
 85 | 1 |         Self {
 86 | 1 |             signature: Signature::user_defined(Volatility::Immutable),
 87 | 1 |         }
 88 | 1 |     }
 89 |   | }
 90 |   |
 91 |   | impl Default for Sum {
 92 | 1 |     fn default() -> Self {
 93 | 1 |         Self::new()
 94 | 1 |     }
 95 |   | }
 96 |   |
 97 |   | impl AggregateUDFImpl for Sum {
 98 | 0 |     fn as_any(&self) -> &dyn Any {
 99 | 0 |         self
100 | 0 |     }
101 |   |
102 | 1 |     fn name(&self) -> &str {
103 | 1 |         "sum"
104 | 1 |     }
105 |   |
106 | 1 |     fn signature(&self) -> &Signature {
107 | 1 |         &self.signature
108 | 1 |     }
109 |   |
110 | 0 |     fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
111 | 0 |         if arg_types.len() != 1 {
112 | 0 |             return exec_err!("SUM expects exactly one argument");
113 | 0 |         }
114 |   |
115 |   |         // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
116 |   |         // smallint, int, bigint, real, double precision, decimal, or interval.
117 |   |
118 | 0 |         fn coerced_type(data_type: &DataType) -> Result<DataType> {
119 | 0 |             match data_type {
120 | 0 |                 DataType::Dictionary(_, v) => coerced_type(v),
121 |   |                 // in the spark, the result type is DECIMAL(min(38,precision+10), s)
122 |   |                 // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
123 |   |                 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
124 | 0 |                     Ok(data_type.clone())
125 |   |                 }
126 | 0 |                 dt if dt.is_signed_integer() => Ok(DataType::Int64),
127 | 0 |                 dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
128 | 0 |                 dt if dt.is_floating() => Ok(DataType::Float64),
129 | 0 |                 _ => exec_err!("Sum not supported for {}", data_type),
130 |   |             }
131 | 0 |         }
132 |   |
133 | 0 |         Ok(vec![coerced_type(&arg_types[0])?])
134 | 0 |     }
135 |   |
136 | 1 |     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
137 | 1 |         match &arg_types[0] {
138 | 0 |             DataType::Int64 => Ok(DataType::Int64),
139 | 1 |             DataType::UInt64 => Ok(DataType::UInt64),
140 | 0 |             DataType::Float64 => Ok(DataType::Float64),
141 | 0 |             DataType::Decimal128(precision, scale) => {
142 | 0 |                 // in the spark, the result type is DECIMAL(min(38,precision+10), s)
143 | 0 |                 // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
144 | 0 |                 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
145 | 0 |                 Ok(DataType::Decimal128(new_precision, *scale))
146 |   |             }
147 | 0 |             DataType::Decimal256(precision, scale) => {
148 | 0 |                 // in the spark, the result type is DECIMAL(min(38,precision+10), s)
149 | 0 |                 // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
150 | 0 |                 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
151 | 0 |                 Ok(DataType::Decimal256(new_precision, *scale))
152 |   |             }
153 | 0 |             other => {
154 | 0 |                 exec_err!("[return_type] SUM not supported for {}", other)
155 |   |             }
156 |   |         }
157 | 1 |     }
158 |   |
159 | 0 |     fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
160 | 0 |         if args.is_distinct {
161 |   |             macro_rules! helper {
162 |   |                 ($t:ty, $dt:expr) => {
163 |   |                     Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?))
164 |   |                 };
165 |   |             }
166 | 0 |             downcast_sum!(args, helper)
167 |   |         } else {
168 |   |             macro_rules! helper {
169 |   |                 ($t:ty, $dt:expr) => {
170 |   |                     Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
171 |   |                 };
172 |   |             }
173 | 0 |             downcast_sum!(args, helper)
174 |   |         }
175 | 0 |     }
176 |   |
177 | 2 |     fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
178 | 2 |         if args.is_distinct {
179 | 0 |             Ok(vec![Field::new_list(
180 | 0 |                 format_state_name(args.name, "sum distinct"),
181 | 0 |                 // See COMMENTS.md to understand why nullable is set to true
182 | 0 |                 Field::new("item", args.return_type.clone(), true),
183 | 0 |                 false,
184 | 0 |             )])
185 |   |         } else {
186 | 2 |             Ok(vec![Field::new(
187 | 2 |                 format_state_name(args.name, "sum"),
188 | 2 |                 args.return_type.clone(),
189 | 2 |                 true,
190 | 2 |             )])
191 |   |         }
192 | 2 |     }
193 |   |
194 | 0 |     fn aliases(&self) -> &[String] {
195 | 0 |         &[]
196 | 0 |     }
197 |   |
198 | 1 |     fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
199 | 1 |         !args.is_distinct
200 | 1 |     }
201 |   |
202 | 1 |     fn create_groups_accumulator(
203 | 1 |         &self,
204 | 1 |         args: AccumulatorArgs,
205 | 1 |     ) -> Result<Box<dyn GroupsAccumulator>> {
206 |   |         macro_rules! helper {
207 |   |             ($t:ty, $dt:expr) => {
208 |   |                 Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
209 |   |                     &$dt,
210 | 3 |                     |x, y| *x = x.add_wrapping(y),
211 |   |                 )))
212 |   |             };
213 |   |         }
214 | 1 |         downcast_sum!(args, helper)
215 | 1 |     }
216 |   |
217 | 0 |     fn create_sliding_accumulator(
218 | 0 |         &self,
219 | 0 |         args: AccumulatorArgs,
220 | 0 |     ) -> Result<Box<dyn Accumulator>> {
221 |   |         macro_rules! helper {
222 |   |             ($t:ty, $dt:expr) => {
223 |   |                 Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
224 |   |             };
225 |   |         }
226 | 0 |         downcast_sum!(args, helper)
227 | 0 |     }
228 |   |
229 | 0 |     fn reverse_expr(&self) -> ReversedUDAF {
230 | 0 |         ReversedUDAF::Identical
231 | 0 |     }
232 |   |
233 | 0 |     fn order_sensitivity(&self) -> AggregateOrderSensitivity {
234 | 0 |         AggregateOrderSensitivity::Insensitive
235 | 0 |     }
236 |   | }
237 |   |
238 |   | /// This accumulator computes SUM incrementally
239 |   | struct SumAccumulator<T: ArrowNumericType> {
240 |   |     sum: Option<T::Native>,
241 |   |     data_type: DataType,
242 |   | }
243 |   |
244 |   | impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
245 | 0 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 | 0 |         write!(f, "SumAccumulator({})", self.data_type)
247 | 0 |     }
248 |   | }
249 |   |
250 |   | impl<T: ArrowNumericType> SumAccumulator<T> {
251 | 0 |     fn new(data_type: DataType) -> Self {
252 | 0 |         Self {
253 | 0 |             sum: None,
254 | 0 |             data_type,
255 | 0 |         }
256 | 0 |     }
257 |   | }
258 |   |
259 |   | impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
260 | 0 |     fn state(&mut self) -> Result<Vec<ScalarValue>> {
261 | 0 |         Ok(vec![self.evaluate()?])
262 | 0 |     }
263 |   |
264 | 0 |     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
265 | 0 |         let values = values[0].as_primitive::<T>();
266 | 0 |         if let Some(x) = arrow::compute::sum(values) {
267 | 0 |             let v = self.sum.get_or_insert(T::Native::usize_as(0));
268 | 0 |             *v = v.add_wrapping(x);
269 | 0 |         }
270 | 0 |         Ok(())
271 | 0 |     }
272 |   |
273 | 0 |     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
274 | 0 |         self.update_batch(states)
275 | 0 |     }
276 |   |
277 | 0 |     fn evaluate(&mut self) -> Result<ScalarValue> {
278 | 0 |         ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
279 | 0 |     }
280 |   |
281 | 0 |     fn size(&self) -> usize {
282 | 0 |         std::mem::size_of_val(self)
283 | 0 |     }
284 |   | }
285 |   |
286 |   | /// This accumulator incrementally computes sums over a sliding window
287 |   | ///
288 |   | /// This is separate from [`SumAccumulator`] as requires additional state
289 |   | struct SlidingSumAccumulator<T: ArrowNumericType> {
290 |   |     sum: T::Native,
291 |   |     count: u64,
292 |   |     data_type: DataType,
293 |   | }
294 |   |
295 |   | impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
296 | 0 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 | 0 |         write!(f, "SlidingSumAccumulator({})", self.data_type)
298 | 0 |     }
299 |   | }
300 |   |
301 |   | impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
302 | 0 |     fn new(data_type: DataType) -> Self {
303 | 0 |         Self {
304 | 0 |             sum: T::Native::usize_as(0),
305 | 0 |             count: 0,
306 | 0 |             data_type,
307 | 0 |         }
308 | 0 |     }
309 |   | }
310 |   |
311 |   | impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
312 | 0 |     fn state(&mut self) -> Result<Vec<ScalarValue>> {
313 | 0 |         Ok(vec![self.evaluate()?, self.count.into()])
314 | 0 |     }
315 |   |
316 | 0 |     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
317 | 0 |         let values = values[0].as_primitive::<T>();
318 | 0 |         self.count += (values.len() - values.null_count()) as u64;
319 | 0 |         if let Some(x) = arrow::compute::sum(values) {
320 | 0 |             self.sum = self.sum.add_wrapping(x)
321 | 0 |         }
322 | 0 |         Ok(())
323 | 0 |     }
324 |   |
325 | 0 |     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
326 | 0 |         let values = states[0].as_primitive::<T>();
327 | 0 |         if let Some(x) = arrow::compute::sum(values) {
328 | 0 |             self.sum = self.sum.add_wrapping(x)
329 | 0 |         }
330 | 0 |         if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
331 | 0 |             self.count += x;
332 | 0 |         }
333 | 0 |         Ok(())
334 | 0 |     }
335 |   |
336 | 0 |     fn evaluate(&mut self) -> Result<ScalarValue> {
337 | 0 |         let v = (self.count != 0).then_some(self.sum);
338 | 0 |         ScalarValue::new_primitive::<T>(v, &self.data_type)
339 | 0 |     }
340 |   |
341 | 0 |     fn size(&self) -> usize {
342 | 0 |         std::mem::size_of_val(self)
343 | 0 |     }
344 |   |
345 | 0 |     fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
346 | 0 |         let values = values[0].as_primitive::<T>();
347 | 0 |         if let Some(x) = arrow::compute::sum(values) {
348 | 0 |             self.sum = self.sum.sub_wrapping(x)
349 | 0 |         }
350 | 0 |         self.count -= (values.len() - values.null_count()) as u64;
351 | 0 |         Ok(())
352 | 0 |     }
353 |   |
354 | 0 |     fn supports_retract_batch(&self) -> bool {
355 | 0 |         true
356 | 0 |     }
357 |   | }
358 |   |
359 |   | struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
360 |   |     values: HashSet<Hashable<T::Native>, RandomState>,
361 |   |     data_type: DataType,
362 |   | }
363 |   |
364 |   | impl<T: ArrowPrimitiveType> std::fmt::Debug for DistinctSumAccumulator<T> {
365 | 0 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366 | 0 |         write!(f, "DistinctSumAccumulator({})", self.data_type)
367 | 0 |     }
368 |   | }
369 |   |
370 |   | impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> {
371 | 0 |     pub fn try_new(data_type: &DataType) -> Result<Self> {
372 | 0 |         Ok(Self {
373 | 0 |             values: HashSet::default(),
374 | 0 |             data_type: data_type.clone(),
375 | 0 |         })
376 | 0 |     }
377 |   | }
378 |   |
379 |   | impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
380 | 0 |     fn state(&mut self) -> Result<Vec<ScalarValue>> {
381 |   |         // 1. Stores aggregate state in `ScalarValue::List`
382 |   |         // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
383 | 0 |         let state_out = {
384 | 0 |             let distinct_values = self
385 | 0 |                 .values
386 | 0 |                 .iter()
387 | 0 |                 .map(|value| {
388 | 0 |                     ScalarValue::new_primitive::<T>(Some(value.0), &self.data_type)
389 | 0 |                 })
390 | 0 |                 .collect::<Result<Vec<_>>>()?;
391 |   |
392 | 0 |             vec![ScalarValue::List(ScalarValue::new_list_nullable(
393 | 0 |                 &distinct_values,
394 | 0 |                 &self.data_type,
395 | 0 |             ))]
396 | 0 |         };
397 | 0 |         Ok(state_out)
398 | 0 |     }
399 |   |
400 | 0 |     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
401 | 0 |         if values.is_empty() {
402 | 0 |             return Ok(());
403 | 0 |         }
404 | 0 |
405 | 0 |         let array = values[0].as_primitive::<T>();
406 | 0 |         match array.nulls().filter(|x| x.null_count() > 0) {
407 | 0 |             Some(n) => {
408 | 0 |                 for idx in n.valid_indices() {
409 | 0 |                     self.values.insert(Hashable(array.value(idx)));
410 | 0 |                 }
411 |   |             }
412 | 0 |             None => array.values().iter().for_each(|x| {
413 | 0 |                 self.values.insert(Hashable(*x));
414 | 0 |             }),
415 |   |         }
416 | 0 |         Ok(())
417 | 0 |     }
418 |   |
419 | 0 |     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
420 | 0 |         for x in states[0].as_list::<i32>().iter().flatten() {
421 | 0 |             self.update_batch(&[x])?
422 |   |         }
423 | 0 |         Ok(())
424 | 0 |     }
425 |   |
426 | 0 |     fn evaluate(&mut self) -> Result<ScalarValue> {
427 | 0 |         let mut acc = T::Native::usize_as(0);
428 | 0 |         for distinct_value in self.values.iter() {
429 | 0 |             acc = acc.add_wrapping(distinct_value.0)
430 |   |         }
431 | 0 |         let v = (!self.values.is_empty()).then_some(acc);
432 | 0 |         ScalarValue::new_primitive::<T>(v, &self.data_type)
433 | 0 |     }
434 |   |
435 | 0 |     fn size(&self) -> usize {
436 | 0 |         std::mem::size_of_val(self)
437 | 0 |             + self.values.capacity() * std::mem::size_of::<T::Native>()
438 | 0 |     }
439 |   | }
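
The counts above show the grouped-aggregation path (create_groups_accumulator, state_fields, return_type) being exercised, while coerce_types, accumulator, SumAccumulator, SlidingSumAccumulator, and DistinctSumAccumulator all report zero hits. As a rough sketch, not taken from the report, queries along the following lines would be one way to reach those zero-hit accumulators through DataFusion's public SessionContext SQL API; the table "t", the file "data.csv", and the columns k and v are hypothetical names used only for illustration.

    // Hypothetical driver, assuming the `datafusion` and `tokio` crates and a CSV
    // file "data.csv" with integer columns `k` and `v` (all names are illustrative).
    use datafusion::error::Result;
    use datafusion::prelude::{CsvReadOptions, SessionContext};

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        ctx.register_csv("t", "data.csv", CsvReadOptions::new()).await?;

        // Grouped SUM goes through the covered PrimitiveGroupsAccumulator path.
        ctx.sql("SELECT k, sum(v) FROM t GROUP BY k").await?.show().await?;

        // SUM(DISTINCT ...) should route through accumulator() and
        // DistinctSumAccumulator (zero hits in the report above).
        ctx.sql("SELECT sum(DISTINCT v) FROM t").await?.show().await?;

        // A bounded window frame needs retraction, which is what
        // SlidingSumAccumulator::retract_batch exists for (also zero hits above).
        ctx.sql("SELECT sum(v) OVER (ORDER BY k ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) FROM t")
            .await?
            .show()
            .await?;

        Ok(())
    }

Which accumulator a query actually instantiates is a planner decision (for example, the sliding variant is only chosen when the window frame requires retraction), so treat these queries as a starting point for closing the gaps rather than a guarantee of coverage.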