Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/test/function_stub.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Aggregate function stubs for test in expr / optimizer.
19
//!
20
//! These are used to avoid a dependence on `datafusion-functions-aggregate` which live in a different crate
21
22
use std::any::Any;
23
24
use arrow::datatypes::{
25
    DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
26
};
27
28
use datafusion_common::{exec_err, not_impl_err, Result};
29
30
use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS};
31
use crate::Volatility::Immutable;
32
use crate::{
33
    expr::AggregateFunction,
34
    function::{AccumulatorArgs, StateFieldsArgs},
35
    utils::AggregateOrderSensitivity,
36
    Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature,
37
    Volatility,
38
};
39
40
macro_rules! create_func {
41
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
42
        paste::paste! {
43
            /// Singleton instance of [$UDAF], ensures the UDAF is only created once
44
            /// named STATIC_$(UDAF). For example `STATIC_FirstValue`
45
            #[allow(non_upper_case_globals)]
46
            static [< STATIC_ $UDAF >]: std::sync::OnceLock<std::sync::Arc<crate::AggregateUDF>> =
47
                std::sync::OnceLock::new();
48
49
            #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")]
50
0
            pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<crate::AggregateUDF> {
51
0
                [< STATIC_ $UDAF >]
52
0
                    .get_or_init(|| {
53
0
                        std::sync::Arc::new(crate::AggregateUDF::from(<$UDAF>::default()))
54
0
                    })
55
0
                    .clone()
56
0
            }
57
        }
58
    }
59
}
60
61
create_func!(Sum, sum_udaf);
62
63
0
pub fn sum(expr: Expr) -> Expr {
64
0
    Expr::AggregateFunction(AggregateFunction::new_udf(
65
0
        sum_udaf(),
66
0
        vec![expr],
67
0
        false,
68
0
        None,
69
0
        None,
70
0
        None,
71
0
    ))
72
0
}
73
74
create_func!(Count, count_udaf);
75
76
0
pub fn count(expr: Expr) -> Expr {
77
0
    Expr::AggregateFunction(AggregateFunction::new_udf(
78
0
        count_udaf(),
79
0
        vec![expr],
80
0
        false,
81
0
        None,
82
0
        None,
83
0
        None,
84
0
    ))
85
0
}
86
87
create_func!(Avg, avg_udaf);
88
89
0
pub fn avg(expr: Expr) -> Expr {
90
0
    Expr::AggregateFunction(AggregateFunction::new_udf(
91
0
        avg_udaf(),
92
0
        vec![expr],
93
0
        false,
94
0
        None,
95
0
        None,
96
0
        None,
97
0
    ))
98
0
}
99
100
/// Stub `sum` used for optimizer testing
101
#[derive(Debug)]
102
pub struct Sum {
103
    signature: Signature,
104
}
105
106
impl Sum {
107
0
    pub fn new() -> Self {
108
0
        Self {
109
0
            signature: Signature::user_defined(Volatility::Immutable),
110
0
        }
111
0
    }
112
}
113
114
impl Default for Sum {
115
0
    fn default() -> Self {
116
0
        Self::new()
117
0
    }
118
}
119
120
impl AggregateUDFImpl for Sum {
121
0
    fn as_any(&self) -> &dyn Any {
122
0
        self
123
0
    }
124
125
0
    fn name(&self) -> &str {
126
0
        "sum"
127
0
    }
128
129
0
    fn signature(&self) -> &Signature {
130
0
        &self.signature
131
0
    }
132
133
0
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
134
0
        if arg_types.len() != 1 {
135
0
            return exec_err!("SUM expects exactly one argument");
136
0
        }
137
138
        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
139
        // smallint, int, bigint, real, double precision, decimal, or interval.
140
141
0
        fn coerced_type(data_type: &DataType) -> Result<DataType> {
142
0
            match data_type {
143
0
                DataType::Dictionary(_, v) => coerced_type(v),
144
                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
145
                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
146
                DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
147
0
                    Ok(data_type.clone())
148
                }
149
0
                dt if dt.is_signed_integer() => Ok(DataType::Int64),
150
0
                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
151
0
                dt if dt.is_floating() => Ok(DataType::Float64),
152
0
                _ => exec_err!("Sum not supported for {}", data_type),
153
            }
154
0
        }
155
156
0
        Ok(vec![coerced_type(&arg_types[0])?])
157
0
    }
158
159
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
160
0
        match &arg_types[0] {
161
0
            DataType::Int64 => Ok(DataType::Int64),
162
0
            DataType::UInt64 => Ok(DataType::UInt64),
163
0
            DataType::Float64 => Ok(DataType::Float64),
164
0
            DataType::Decimal128(precision, scale) => {
165
0
                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
166
0
                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
167
0
                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
168
0
                Ok(DataType::Decimal128(new_precision, *scale))
169
            }
170
0
            DataType::Decimal256(precision, scale) => {
171
0
                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
172
0
                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
173
0
                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
174
0
                Ok(DataType::Decimal256(new_precision, *scale))
175
            }
176
0
            other => {
177
0
                exec_err!("[return_type] SUM not supported for {}", other)
178
            }
179
        }
180
0
    }
181
182
0
    fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
183
0
        unreachable!("stub should not have accumulate()")
184
    }
185
186
0
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
187
0
        unreachable!("stub should not have state_fields()")
188
    }
189
190
0
    fn aliases(&self) -> &[String] {
191
0
        &[]
192
0
    }
193
194
0
    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
195
0
        false
196
0
    }
197
198
0
    fn create_groups_accumulator(
199
0
        &self,
200
0
        _args: AccumulatorArgs,
201
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
202
0
        unreachable!("stub should not have accumulate()")
203
    }
204
205
0
    fn reverse_expr(&self) -> ReversedUDAF {
206
0
        ReversedUDAF::Identical
207
0
    }
208
209
0
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
210
0
        AggregateOrderSensitivity::Insensitive
211
0
    }
212
}
213
214
/// Testing stub implementation of COUNT aggregate
215
pub struct Count {
216
    signature: Signature,
217
    aliases: Vec<String>,
218
}
219
220
impl std::fmt::Debug for Count {
221
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
222
0
        f.debug_struct("Count")
223
0
            .field("name", &self.name())
224
0
            .field("signature", &self.signature)
225
0
            .finish()
226
0
    }
227
}
228
229
impl Default for Count {
230
0
    fn default() -> Self {
231
0
        Self::new()
232
0
    }
233
}
234
235
impl Count {
236
0
    pub fn new() -> Self {
237
0
        Self {
238
0
            aliases: vec!["count".to_string()],
239
0
            signature: Signature::variadic_any(Volatility::Immutable),
240
0
        }
241
0
    }
242
}
243
244
impl AggregateUDFImpl for Count {
245
0
    fn as_any(&self) -> &dyn std::any::Any {
246
0
        self
247
0
    }
248
249
0
    fn name(&self) -> &str {
250
0
        "COUNT"
251
0
    }
252
253
0
    fn signature(&self) -> &Signature {
254
0
        &self.signature
255
0
    }
256
257
0
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
258
0
        Ok(DataType::Int64)
259
0
    }
260
261
0
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
262
0
        not_impl_err!("no impl for stub")
263
0
    }
264
265
0
    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
266
0
        not_impl_err!("no impl for stub")
267
0
    }
268
269
0
    fn aliases(&self) -> &[String] {
270
0
        &self.aliases
271
0
    }
272
273
0
    fn create_groups_accumulator(
274
0
        &self,
275
0
        _args: AccumulatorArgs,
276
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
277
0
        not_impl_err!("no impl for stub")
278
0
    }
279
280
0
    fn reverse_expr(&self) -> ReversedUDAF {
281
0
        ReversedUDAF::Identical
282
0
    }
283
}
284
285
create_func!(Min, min_udaf);
286
287
0
pub fn min(expr: Expr) -> Expr {
288
0
    Expr::AggregateFunction(AggregateFunction::new_udf(
289
0
        min_udaf(),
290
0
        vec![expr],
291
0
        false,
292
0
        None,
293
0
        None,
294
0
        None,
295
0
    ))
296
0
}
297
298
/// Testing stub implementation of Min aggregate
299
pub struct Min {
300
    signature: Signature,
301
}
302
303
impl std::fmt::Debug for Min {
304
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
305
0
        f.debug_struct("Min")
306
0
            .field("name", &self.name())
307
0
            .field("signature", &self.signature)
308
0
            .finish()
309
0
    }
310
}
311
312
impl Default for Min {
313
0
    fn default() -> Self {
314
0
        Self::new()
315
0
    }
316
}
317
318
impl Min {
319
0
    pub fn new() -> Self {
320
0
        Self {
321
0
            signature: Signature::variadic_any(Volatility::Immutable),
322
0
        }
323
0
    }
324
}
325
326
impl AggregateUDFImpl for Min {
327
0
    fn as_any(&self) -> &dyn std::any::Any {
328
0
        self
329
0
    }
330
331
0
    fn name(&self) -> &str {
332
0
        "min"
333
0
    }
334
335
0
    fn signature(&self) -> &Signature {
336
0
        &self.signature
337
0
    }
338
339
0
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
340
0
        Ok(DataType::Int64)
341
0
    }
342
343
0
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
344
0
        not_impl_err!("no impl for stub")
345
0
    }
346
347
0
    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
348
0
        not_impl_err!("no impl for stub")
349
0
    }
350
351
0
    fn aliases(&self) -> &[String] {
352
0
        &[]
353
0
    }
354
355
0
    fn create_groups_accumulator(
356
0
        &self,
357
0
        _args: AccumulatorArgs,
358
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
359
0
        not_impl_err!("no impl for stub")
360
0
    }
361
362
0
    fn reverse_expr(&self) -> ReversedUDAF {
363
0
        ReversedUDAF::Identical
364
0
    }
365
0
    fn is_descending(&self) -> Option<bool> {
366
0
        Some(false)
367
0
    }
368
}
369
370
create_func!(Max, max_udaf);
371
372
0
pub fn max(expr: Expr) -> Expr {
373
0
    Expr::AggregateFunction(AggregateFunction::new_udf(
374
0
        max_udaf(),
375
0
        vec![expr],
376
0
        false,
377
0
        None,
378
0
        None,
379
0
        None,
380
0
    ))
381
0
}
382
383
/// Testing stub implementation of MAX aggregate
384
pub struct Max {
385
    signature: Signature,
386
}
387
388
impl std::fmt::Debug for Max {
389
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
390
0
        f.debug_struct("Max")
391
0
            .field("name", &self.name())
392
0
            .field("signature", &self.signature)
393
0
            .finish()
394
0
    }
395
}
396
397
impl Default for Max {
398
0
    fn default() -> Self {
399
0
        Self::new()
400
0
    }
401
}
402
403
impl Max {
404
0
    pub fn new() -> Self {
405
0
        Self {
406
0
            signature: Signature::variadic_any(Volatility::Immutable),
407
0
        }
408
0
    }
409
}
410
411
impl AggregateUDFImpl for Max {
412
0
    fn as_any(&self) -> &dyn std::any::Any {
413
0
        self
414
0
    }
415
416
0
    fn name(&self) -> &str {
417
0
        "max"
418
0
    }
419
420
0
    fn signature(&self) -> &Signature {
421
0
        &self.signature
422
0
    }
423
424
0
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
425
0
        Ok(DataType::Int64)
426
0
    }
427
428
0
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
429
0
        not_impl_err!("no impl for stub")
430
0
    }
431
432
0
    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
433
0
        not_impl_err!("no impl for stub")
434
0
    }
435
436
0
    fn aliases(&self) -> &[String] {
437
0
        &[]
438
0
    }
439
440
0
    fn create_groups_accumulator(
441
0
        &self,
442
0
        _args: AccumulatorArgs,
443
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
444
0
        not_impl_err!("no impl for stub")
445
0
    }
446
447
0
    fn reverse_expr(&self) -> ReversedUDAF {
448
0
        ReversedUDAF::Identical
449
0
    }
450
0
    fn is_descending(&self) -> Option<bool> {
451
0
        Some(true)
452
0
    }
453
}
454
455
/// Testing stub implementation of avg aggregate
456
#[derive(Debug)]
457
pub struct Avg {
458
    signature: Signature,
459
    aliases: Vec<String>,
460
}
461
462
impl Avg {
463
0
    pub fn new() -> Self {
464
0
        Self {
465
0
            aliases: vec![String::from("mean")],
466
0
            signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable),
467
0
        }
468
0
    }
469
}
470
471
impl Default for Avg {
472
0
    fn default() -> Self {
473
0
        Self::new()
474
0
    }
475
}
476
477
impl AggregateUDFImpl for Avg {
478
0
    fn as_any(&self) -> &dyn Any {
479
0
        self
480
0
    }
481
482
0
    fn name(&self) -> &str {
483
0
        "avg"
484
0
    }
485
486
0
    fn signature(&self) -> &Signature {
487
0
        &self.signature
488
0
    }
489
490
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
491
0
        avg_return_type(self.name(), &arg_types[0])
492
0
    }
493
494
0
    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
495
0
        not_impl_err!("no impl for stub")
496
0
    }
497
498
0
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
499
0
        not_impl_err!("no impl for stub")
500
0
    }
501
0
    fn aliases(&self) -> &[String] {
502
0
        &self.aliases
503
0
    }
504
505
0
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
506
0
        coerce_avg_type(self.name(), arg_types)
507
0
    }
508
}