Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/udaf.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! [`AggregateUDF`]: User Defined Aggregate Functions
19
20
use std::any::Any;
21
use std::cmp::Ordering;
22
use std::fmt::{self, Debug, Formatter};
23
use std::hash::{DefaultHasher, Hash, Hasher};
24
use std::sync::Arc;
25
use std::vec;
26
27
use arrow::datatypes::{DataType, Field};
28
29
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
30
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
31
32
use crate::expr::AggregateFunction;
33
use crate::function::{
34
    AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
35
};
36
use crate::groups_accumulator::GroupsAccumulator;
37
use crate::utils::format_state_name;
38
use crate::utils::AggregateOrderSensitivity;
39
use crate::{Accumulator, Expr};
40
use crate::{Documentation, Signature};
41
42
/// Logical representation of a user-defined [aggregate function] (UDAF).
43
///
44
/// An aggregate function combines the values from multiple input rows
45
/// into a single output "aggregate" (summary) row. It is different
46
/// from a scalar function because it is stateful across batches. User
47
/// defined aggregate functions can be used as normal SQL aggregate
48
/// functions (`GROUP BY` clause) as well as window functions (`OVER`
49
/// clause).
50
///
51
/// `AggregateUDF` provides DataFusion the information needed to plan and call
52
/// aggregate functions, including name, type information, and a factory
53
/// function to create an [`Accumulator`] instance, to perform the actual
54
/// aggregation.
55
///
56
/// For more information, please see [the examples]:
57
///
58
/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
59
///
60
/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
61
///    access (examples in [`advanced_udaf.rs`]).
62
///
63
/// # API Note
64
/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
65
/// compatibility with the older API.
66
///
67
/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
68
/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
69
/// [`Accumulator`]: crate::Accumulator
70
/// [`create_udaf`]: crate::expr_fn::create_udaf
71
/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
72
/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
73
#[derive(Debug, Clone, PartialOrd)]
74
pub struct AggregateUDF {
75
    inner: Arc<dyn AggregateUDFImpl>,
76
}
77
78
impl PartialEq for AggregateUDF {
79
0
    fn eq(&self, other: &Self) -> bool {
80
0
        self.inner.equals(other.inner.as_ref())
81
0
    }
82
}
83
84
impl Eq for AggregateUDF {}
85
86
impl Hash for AggregateUDF {
87
0
    fn hash<H: Hasher>(&self, state: &mut H) {
88
0
        self.inner.hash_value().hash(state)
89
0
    }
90
}
91
92
impl fmt::Display for AggregateUDF {
93
1
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
94
1
        write!(f, "{}", self.name())
95
1
    }
96
}
97
98
/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
99
pub struct StatisticsArgs<'a> {
100
    /// The statistics of the aggregate input
101
    pub statistics: &'a Statistics,
102
    /// The resolved return type of the aggregate function
103
    pub return_type: &'a DataType,
104
    /// Whether the aggregate function is distinct.
105
    ///
106
    /// ```sql
107
    /// SELECT COUNT(DISTINCT column1) FROM t;
108
    /// ```
109
    pub is_distinct: bool,
110
    /// The physical expression of arguments the aggregate function takes.
111
    pub exprs: &'a [Arc<dyn PhysicalExpr>],
112
}
113
114
impl AggregateUDF {
115
    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
116
    ///
117
    /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
118
7
    pub fn new_from_impl<F>(fun: F) -> AggregateUDF
119
7
    where
120
7
        F: AggregateUDFImpl + 'static,
121
7
    {
122
7
        Self {
123
7
            inner: Arc::new(fun),
124
7
        }
125
7
    }
126
127
    /// Return the underlying [`AggregateUDFImpl`] trait object for this function
128
0
    pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
129
0
        &self.inner
130
0
    }
131
132
    /// Adds additional names that can be used to invoke this function, in
133
    /// addition to `name`
134
    ///
135
    /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
136
0
    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
137
0
        Self::new_from_impl(AliasedAggregateUDFImpl::new(
138
0
            Arc::clone(&self.inner),
139
0
            aliases,
140
0
        ))
141
0
    }
142
143
    /// creates an [`Expr`] that calls the aggregate function.
144
    ///
145
    /// This utility allows using the UDAF without requiring access to
146
    /// the registry, such as with the DataFrame API.
147
0
    pub fn call(&self, args: Vec<Expr>) -> Expr {
148
0
        Expr::AggregateFunction(AggregateFunction::new_udf(
149
0
            Arc::new(self.clone()),
150
0
            args,
151
0
            false,
152
0
            None,
153
0
            None,
154
0
            None,
155
0
        ))
156
0
    }
157
158
    /// Returns this function's name
159
    ///
160
    /// See [`AggregateUDFImpl::name`] for more details.
161
49
    pub fn name(&self) -> &str {
162
49
        self.inner.name()
163
49
    }
164
165
36
    pub fn is_nullable(&self) -> bool {
166
36
        self.inner.is_nullable()
167
36
    }
168
169
    /// Returns the aliases for this function.
170
0
    pub fn aliases(&self) -> &[String] {
171
0
        self.inner.aliases()
172
0
    }
173
174
    /// Returns this function's signature (what input types are accepted)
175
    ///
176
    /// See [`AggregateUDFImpl::signature`] for more details.
177
36
    pub fn signature(&self) -> &Signature {
178
36
        self.inner.signature()
179
36
    }
180
181
    /// Return the type of the function given its input types
182
    ///
183
    /// See [`AggregateUDFImpl::return_type`] for more details.
184
36
    pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
185
36
        self.inner.return_type(args)
186
36
    }
187
188
    /// Return an accumulator the given aggregate, given its return datatype
189
132
    pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
190
132
        self.inner.accumulator(acc_args)
191
132
    }
192
193
    /// Return the fields used to store the intermediate state for this aggregator, given
194
    /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
195
    /// for more details.
196
    ///
197
    /// This is used to support multi-phase aggregations
198
112
    pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
199
112
        self.inner.state_fields(args)
200
112
    }
201
202
    /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
203
70
    pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
204
70
        self.inner.groups_accumulator_supported(args)
205
70
    }
206
207
    /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
208
30
    pub fn create_groups_accumulator(
209
30
        &self,
210
30
        args: AccumulatorArgs,
211
30
    ) -> Result<Box<dyn GroupsAccumulator>> {
212
30
        self.inner.create_groups_accumulator(args)
213
30
    }
214
215
3
    pub fn create_sliding_accumulator(
216
3
        &self,
217
3
        args: AccumulatorArgs,
218
3
    ) -> Result<Box<dyn Accumulator>> {
219
3
        self.inner.create_sliding_accumulator(args)
220
3
    }
221
222
0
    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
223
0
        self.inner.coerce_types(arg_types)
224
0
    }
225
226
    /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
227
0
    pub fn with_beneficial_ordering(
228
0
        self,
229
0
        beneficial_ordering: bool,
230
0
    ) -> Result<Option<AggregateUDF>> {
231
0
        self.inner
232
0
            .with_beneficial_ordering(beneficial_ordering)
233
0
            .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
234
0
    }
235
236
    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
237
    /// for possible options.
238
70
    pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
239
70
        self.inner.order_sensitivity()
240
70
    }
241
242
    /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
243
    /// generate same result with this `AggregateUDF` when iterated in reverse
244
    /// order, and `None` if there is no such `AggregateUDF`).
245
3
    pub fn reverse_udf(&self) -> ReversedUDAF {
246
3
        self.inner.reverse_expr()
247
3
    }
248
249
    /// Do the function rewrite
250
    ///
251
    /// See [`AggregateUDFImpl::simplify`] for more details.
252
0
    pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
253
0
        self.inner.simplify()
254
0
    }
255
256
    /// Returns true if the function is max, false if the function is min
257
    /// None in all other cases, used in certain optimizations for
258
    /// or aggregate
259
0
    pub fn is_descending(&self) -> Option<bool> {
260
0
        self.inner.is_descending()
261
0
    }
262
263
    /// Return the value of this aggregate function if it can be determined
264
    /// entirely from statistics and arguments.
265
    ///
266
    /// See [`AggregateUDFImpl::value_from_stats`] for more details.
267
0
    pub fn value_from_stats(
268
0
        &self,
269
0
        statistics_args: &StatisticsArgs,
270
0
    ) -> Option<ScalarValue> {
271
0
        self.inner.value_from_stats(statistics_args)
272
0
    }
273
274
    /// See [`AggregateUDFImpl::default_value`] for more details.
275
0
    pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
276
0
        self.inner.default_value(data_type)
277
0
    }
278
279
    /// Returns the documentation for this Aggregate UDF.
280
    ///
281
    /// Documentation can be accessed programmatically as well as
282
    /// generating publicly facing documentation.
283
0
    pub fn documentation(&self) -> Option<&Documentation> {
284
0
        self.inner.documentation()
285
0
    }
286
}
287
288
impl<F> From<F> for AggregateUDF
289
where
290
    F: AggregateUDFImpl + Send + Sync + 'static,
291
{
292
7
    fn from(fun: F) -> Self {
293
7
        Self::new_from_impl(fun)
294
7
    }
295
}
296
297
/// Trait for implementing [`AggregateUDF`].
298
///
299
/// This trait exposes the full API for implementing user defined aggregate functions and
300
/// can be used to implement any function.
301
///
302
/// See [`advanced_udaf.rs`] for a full example with complete implementation and
303
/// [`AggregateUDF`] for other available options.
304
///
305
/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
306
///
307
/// # Basic Example
308
/// ```
309
/// # use std::any::Any;
310
/// # use std::sync::OnceLock;
311
/// # use arrow::datatypes::DataType;
312
/// # use datafusion_common::{DataFusionError, plan_err, Result};
313
/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation};
314
/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
315
/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE;
316
/// # use arrow::datatypes::Schema;
317
/// # use arrow::datatypes::Field;
318
///
319
/// #[derive(Debug, Clone)]
320
/// struct GeoMeanUdf {
321
///   signature: Signature,
322
/// }
323
///
324
/// impl GeoMeanUdf {
325
///   fn new() -> Self {
326
///     Self {
327
///       signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
328
///      }
329
///   }
330
/// }
331
///
332
/// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
333
///
334
/// fn get_doc() -> &'static Documentation {
335
///     DOCUMENTATION.get_or_init(|| {
336
///         Documentation::builder()
337
///             .with_doc_section(DOC_SECTION_AGGREGATE)
338
///             .with_description("calculates a geometric mean")
339
///             .with_syntax_example("geo_mean(2.0)")
340
///             .with_argument("arg1", "The Float64 number for the geometric mean")
341
///             .build()
342
///             .unwrap()
343
///     })
344
/// }
345
///    
346
/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
347
/// impl AggregateUDFImpl for GeoMeanUdf {
348
///    fn as_any(&self) -> &dyn Any { self }
349
///    fn name(&self) -> &str { "geo_mean" }
350
///    fn signature(&self) -> &Signature { &self.signature }
351
///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
352
///      if !matches!(args.get(0), Some(&DataType::Float64)) {
353
///        return plan_err!("geo_mean only accepts Float64 arguments");
354
///      }
355
///      Ok(DataType::Float64)
356
///    }
357
///    // This is the accumulator factory; DataFusion uses it to create new accumulators.
358
///    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
359
///    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
360
///        Ok(vec![
361
///             Field::new("value", args.return_type.clone(), true),
362
///             Field::new("ordering", DataType::UInt32, true)
363
///        ])
364
///    }
365
///    fn documentation(&self) -> Option<&Documentation> {
366
///        Some(get_doc())  
367
///    }
368
/// }
369
///
370
/// // Create a new AggregateUDF from the implementation
371
/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
372
///
373
/// // Call the function `geo_mean(col)`
374
/// let expr = geometric_mean.call(vec![col("a")]);
375
/// ```
376
pub trait AggregateUDFImpl: Debug + Send + Sync {
377
    // Note: When adding any methods (with default implementations), remember to add them also
378
    // into the AliasedAggregateUDFImpl below!
379
380
    /// Returns this object as an [`Any`] trait object
381
    fn as_any(&self) -> &dyn Any;
382
383
    /// Returns this function's name
384
    fn name(&self) -> &str;
385
386
    /// Returns the function's [`Signature`] for information about what input
387
    /// types are accepted and the function's Volatility.
388
    fn signature(&self) -> &Signature;
389
390
    /// What [`DataType`] will be returned by this function, given the types of
391
    /// the arguments
392
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
393
394
    /// Whether the aggregate function is nullable.
395
    ///
396
    /// Nullable means that that the function could return `null` for any inputs.
397
    /// For example, aggregate functions like `COUNT` always return a non null value
398
    /// but others like `MIN` will return `NULL` if there is nullable input.
399
    /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null`
400
26
    fn is_nullable(&self) -> bool {
401
26
        true
402
26
    }
403
404
    /// Return a new [`Accumulator`] that aggregates values for a specific
405
    /// group during query execution.
406
    ///
407
    /// acc_args: [`AccumulatorArgs`] contains information about how the
408
    /// aggregate function was called.
409
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
410
411
    /// Return the fields used to store the intermediate state of this accumulator.
412
    ///
413
    /// See [`Accumulator::state`] for background information.
414
    ///
415
    /// args:  [`StateFieldsArgs`] contains arguments passed to the
416
    /// aggregate function's accumulator.
417
    ///
418
    /// # Notes:
419
    ///
420
    /// The default implementation returns a single state field named `name`
421
    /// with the same type as `value_type`. This is suitable for aggregates such
422
    /// as `SUM` or `MIN` where partial state can be combined by applying the
423
    /// same aggregate.
424
    ///
425
    /// For aggregates such as `AVG` where the partial state is more complex
426
    /// (e.g. a COUNT and a SUM), this method is used to define the additional
427
    /// fields.
428
    ///
429
    /// The name of the fields must be unique within the query and thus should
430
    /// be derived from `name`. See [`format_state_name`] for a utility function
431
    /// to generate a unique name.
432
0
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
433
0
        let fields = vec![Field::new(
434
0
            format_state_name(args.name, "value"),
435
0
            args.return_type.clone(),
436
0
            true,
437
0
        )];
438
0
439
0
        Ok(fields
440
0
            .into_iter()
441
0
            .chain(args.ordering_fields.to_vec())
442
0
            .collect())
443
0
    }
444
445
    /// If the aggregate expression has a specialized
446
    /// [`GroupsAccumulator`] implementation. If this returns true,
447
    /// `[Self::create_groups_accumulator]` will be called.
448
    ///
449
    /// # Notes
450
    ///
451
    /// Even if this function returns true, DataFusion will still use
452
    /// [`Self::accumulator`] for certain queries, such as when this aggregate is
453
    /// used as a window function or when there no GROUP BY columns in the
454
    /// query.
455
40
    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
456
40
        false
457
40
    }
458
459
    /// Return a specialized [`GroupsAccumulator`] that manages state
460
    /// for all groups.
461
    ///
462
    /// For maximum performance, a [`GroupsAccumulator`] should be
463
    /// implemented in addition to [`Accumulator`].
464
0
    fn create_groups_accumulator(
465
0
        &self,
466
0
        _args: AccumulatorArgs,
467
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
468
0
        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
469
0
    }
470
471
    /// Returns any aliases (alternate names) for this function.
472
    ///
473
    /// Note: `aliases` should only include names other than [`Self::name`].
474
    /// Defaults to `[]` (no aliases)
475
0
    fn aliases(&self) -> &[String] {
476
0
        &[]
477
0
    }
478
479
    /// Sliding accumulator is an alternative accumulator that can be used for
480
    /// window functions. It has retract method to revert the previous update.
481
    ///
482
    /// See [retract_batch] for more details.
483
    ///
484
    /// [retract_batch]: datafusion_expr_common::accumulator::Accumulator::retract_batch
485
3
    fn create_sliding_accumulator(
486
3
        &self,
487
3
        args: AccumulatorArgs,
488
3
    ) -> Result<Box<dyn Accumulator>> {
489
3
        self.accumulator(args)
490
3
    }
491
492
    /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is
493
    /// satisfied by its input. If this is not the case, UDFs with order
494
    /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
495
    /// the correct result with possibly more work internally.
496
    ///
497
    /// # Returns
498
    ///
499
    /// Returns `Ok(Some(updated_udf))` if the process completes successfully.
500
    /// If the expression can benefit from existing input ordering, but does
501
    /// not implement the method, returns an error. Order insensitive and hard
502
    /// requirement aggregators return `Ok(None)`.
503
0
    fn with_beneficial_ordering(
504
0
        self: Arc<Self>,
505
0
        _beneficial_ordering: bool,
506
0
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
507
0
        if self.order_sensitivity().is_beneficial() {
508
0
            return exec_err!(
509
0
                "Should implement with satisfied for aggregator :{:?}",
510
0
                self.name()
511
0
            );
512
0
        }
513
0
        Ok(None)
514
0
    }
515
516
    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
517
    /// for possible options.
518
18
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
519
18
        // We have hard ordering requirements by default, meaning that order
520
18
        // sensitive UDFs need their input orderings to satisfy their ordering
521
18
        // requirements to generate correct results.
522
18
        AggregateOrderSensitivity::HardRequirement
523
18
    }
524
525
    /// Optionally apply per-UDaF simplification / rewrite rules.
526
    ///
527
    /// This can be used to apply function specific simplification rules during
528
    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
529
    /// implementation does nothing.
530
    ///
531
    /// Note that DataFusion handles simplifying arguments and  "constant
532
    /// folding" (replacing a function call with constant arguments such as
533
    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
534
    /// optimizations manually for specific UDFs.
535
    ///
536
    /// # Returns
537
    ///
538
    /// [None] if simplify is not defined or,
539
    ///
540
    /// Or, a closure with two arguments:
541
    /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
542
    /// * 'info': [crate::simplify::SimplifyInfo]
543
    ///
544
    /// closure returns simplified [Expr] or an error.
545
    ///
546
0
    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
547
0
        None
548
0
    }
549
550
    /// Returns the reverse expression of the aggregate function.
551
0
    fn reverse_expr(&self) -> ReversedUDAF {
552
0
        ReversedUDAF::NotSupported
553
0
    }
554
555
    /// Coerce arguments of a function call to types that the function can evaluate.
556
    ///
557
    /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
558
    /// UDAFs should return one of the other variants of `TypeSignature` which handle common
559
    /// cases
560
    ///
561
    /// See the [type coercion module](crate::type_coercion)
562
    /// documentation for more details on type coercion
563
    ///
564
    /// For example, if your function requires a floating point arguments, but the user calls
565
    /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
566
    /// to ensure the argument was cast to `1::double`
567
    ///
568
    /// # Parameters
569
    /// * `arg_types`: The argument types of the arguments  this function with
570
    ///
571
    /// # Return value
572
    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
573
    /// arguments to these specific types.
574
0
    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
575
0
        not_impl_err!("Function {} does not implement coerce_types", self.name())
576
0
    }
577
578
    /// Return true if this aggregate UDF is equal to the other.
579
    ///
580
    /// Allows customizing the equality of aggregate UDFs.
581
    /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
582
    ///
583
    /// - reflexive: `a.equals(a)`;
584
    /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
585
    /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
586
    ///
587
    /// By default, compares [`Self::name`] and [`Self::signature`].
588
0
    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
589
0
        self.name() == other.name() && self.signature() == other.signature()
590
0
    }
591
592
    /// Returns a hash value for this aggregate UDF.
593
    ///
594
    /// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`],
595
    /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
596
    ///
597
    /// By default, hashes [`Self::name`] and [`Self::signature`].
598
0
    fn hash_value(&self) -> u64 {
599
0
        let hasher = &mut DefaultHasher::new();
600
0
        self.name().hash(hasher);
601
0
        self.signature().hash(hasher);
602
0
        hasher.finish()
603
0
    }
604
605
    /// If this function is max, return true
606
    /// if the function is min, return false
607
    /// otherwise return None (the default)
608
    ///
609
    ///
610
    /// Note: this is used to use special aggregate implementations in certain conditions
611
0
    fn is_descending(&self) -> Option<bool> {
612
0
        None
613
0
    }
614
615
    /// Return the value of this aggregate function if it can be determined
616
    /// entirely from statistics and arguments.
617
    ///
618
    /// Using a [`ScalarValue`] rather than a runtime computation can significantly
619
    /// improving query performance.
620
    ///
621
    /// For example, if the minimum value of column `x` is known to be `42` from
622
    /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`
623
0
    fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
624
0
        None
625
0
    }
626
627
    /// Returns default value of the function given the input is all `null`.
628
    ///
629
    /// Most of the aggregate function return Null if input is Null,
630
    /// while `count` returns 0 if input is Null
631
0
    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
632
0
        ScalarValue::try_from(data_type)
633
0
    }
634
635
    /// Returns the documentation for this Aggregate UDF.
636
    ///
637
    /// Documentation can be accessed programmatically as well as
638
    /// generating publicly facing documentation.
639
0
    fn documentation(&self) -> Option<&Documentation> {
640
0
        None
641
0
    }
642
}
643
644
impl PartialEq for dyn AggregateUDFImpl {
645
0
    fn eq(&self, other: &Self) -> bool {
646
0
        self.equals(other)
647
0
    }
648
}
649
650
// manual implementation of `PartialOrd`
651
// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl
652
// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5
653
impl PartialOrd for dyn AggregateUDFImpl {
654
0
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
655
0
        match self.name().partial_cmp(other.name()) {
656
0
            Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
657
0
            cmp => cmp,
658
        }
659
0
    }
660
}
661
662
pub enum ReversedUDAF {
663
    /// The expression is the same as the original expression, like SUM, COUNT
664
    Identical,
665
    /// The expression does not support reverse calculation
666
    NotSupported,
667
    /// The expression is different from the original expression
668
    Reversed(Arc<AggregateUDF>),
669
}
670
671
/// AggregateUDF that adds an alias to the underlying function. It is better to
672
/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
673
#[derive(Debug)]
674
struct AliasedAggregateUDFImpl {
675
    inner: Arc<dyn AggregateUDFImpl>,
676
    aliases: Vec<String>,
677
}
678
679
impl AliasedAggregateUDFImpl {
680
0
    pub fn new(
681
0
        inner: Arc<dyn AggregateUDFImpl>,
682
0
        new_aliases: impl IntoIterator<Item = &'static str>,
683
0
    ) -> Self {
684
0
        let mut aliases = inner.aliases().to_vec();
685
0
        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
686
0
687
0
        Self { inner, aliases }
688
0
    }
689
}
690
691
impl AggregateUDFImpl for AliasedAggregateUDFImpl {
692
0
    fn as_any(&self) -> &dyn Any {
693
0
        self
694
0
    }
695
696
0
    fn name(&self) -> &str {
697
0
        self.inner.name()
698
0
    }
699
700
0
    fn signature(&self) -> &Signature {
701
0
        self.inner.signature()
702
0
    }
703
704
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
705
0
        self.inner.return_type(arg_types)
706
0
    }
707
708
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
709
0
        self.inner.accumulator(acc_args)
710
0
    }
711
712
0
    fn aliases(&self) -> &[String] {
713
0
        &self.aliases
714
0
    }
715
716
0
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
717
0
        self.inner.state_fields(args)
718
0
    }
719
720
0
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
721
0
        self.inner.groups_accumulator_supported(args)
722
0
    }
723
724
0
    fn create_groups_accumulator(
725
0
        &self,
726
0
        args: AccumulatorArgs,
727
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
728
0
        self.inner.create_groups_accumulator(args)
729
0
    }
730
731
0
    fn create_sliding_accumulator(
732
0
        &self,
733
0
        args: AccumulatorArgs,
734
0
    ) -> Result<Box<dyn Accumulator>> {
735
0
        self.inner.accumulator(args)
736
0
    }
737
738
0
    fn with_beneficial_ordering(
739
0
        self: Arc<Self>,
740
0
        beneficial_ordering: bool,
741
0
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
742
0
        Arc::clone(&self.inner)
743
0
            .with_beneficial_ordering(beneficial_ordering)
744
0
            .map(|udf| {
745
0
                udf.map(|udf| {
746
0
                    Arc::new(AliasedAggregateUDFImpl {
747
0
                        inner: udf,
748
0
                        aliases: self.aliases.clone(),
749
0
                    }) as Arc<dyn AggregateUDFImpl>
750
0
                })
751
0
            })
752
0
    }
753
754
0
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
755
0
        self.inner.order_sensitivity()
756
0
    }
757
758
0
    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
759
0
        self.inner.simplify()
760
0
    }
761
762
0
    fn reverse_expr(&self) -> ReversedUDAF {
763
0
        self.inner.reverse_expr()
764
0
    }
765
766
0
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
767
0
        self.inner.coerce_types(arg_types)
768
0
    }
769
770
0
    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
771
0
        if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
772
0
            self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
773
        } else {
774
0
            false
775
        }
776
0
    }
777
778
0
    fn hash_value(&self) -> u64 {
779
0
        let hasher = &mut DefaultHasher::new();
780
0
        self.inner.hash_value().hash(hasher);
781
0
        self.aliases.hash(hasher);
782
0
        hasher.finish()
783
0
    }
784
785
0
    fn is_descending(&self) -> Option<bool> {
786
0
        self.inner.is_descending()
787
0
    }
788
789
0
    fn documentation(&self) -> Option<&Documentation> {
790
0
        self.inner.documentation()
791
0
    }
792
}
793
794
// Aggregate UDF doc sections for use in public documentation
795
pub mod aggregate_doc_sections {
796
    use crate::DocSection;
797
798
    pub fn doc_sections() -> Vec<DocSection> {
799
        vec![
800
            DOC_SECTION_GENERAL,
801
            DOC_SECTION_STATISTICAL,
802
            DOC_SECTION_APPROXIMATE,
803
        ]
804
    }
805
806
    pub const DOC_SECTION_GENERAL: DocSection = DocSection {
807
        include: true,
808
        label: "General Functions",
809
        description: None,
810
    };
811
812
    pub const DOC_SECTION_STATISTICAL: DocSection = DocSection {
813
        include: true,
814
        label: "Statistical Functions",
815
        description: None,
816
    };
817
818
    pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection {
819
        include: true,
820
        label: "Approximate Functions",
821
        description: None,
822
    };
823
}
824
825
#[cfg(test)]
826
mod test {
827
    use crate::{AggregateUDF, AggregateUDFImpl};
828
    use arrow::datatypes::{DataType, Field};
829
    use datafusion_common::Result;
830
    use datafusion_expr_common::accumulator::Accumulator;
831
    use datafusion_expr_common::signature::{Signature, Volatility};
832
    use datafusion_functions_aggregate_common::accumulator::{
833
        AccumulatorArgs, StateFieldsArgs,
834
    };
835
    use std::any::Any;
836
    use std::cmp::Ordering;
837
838
    #[derive(Debug, Clone)]
839
    struct AMeanUdf {
840
        signature: Signature,
841
    }
842
843
    impl AMeanUdf {
844
        fn new() -> Self {
845
            Self {
846
                signature: Signature::uniform(
847
                    1,
848
                    vec![DataType::Float64],
849
                    Volatility::Immutable,
850
                ),
851
            }
852
        }
853
    }
854
855
    impl AggregateUDFImpl for AMeanUdf {
856
        fn as_any(&self) -> &dyn Any {
857
            self
858
        }
859
        fn name(&self) -> &str {
860
            "a"
861
        }
862
        fn signature(&self) -> &Signature {
863
            &self.signature
864
        }
865
        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
866
            unimplemented!()
867
        }
868
        fn accumulator(
869
            &self,
870
            _acc_args: AccumulatorArgs,
871
        ) -> Result<Box<dyn Accumulator>> {
872
            unimplemented!()
873
        }
874
        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
875
            unimplemented!()
876
        }
877
    }
878
879
    #[derive(Debug, Clone)]
880
    struct BMeanUdf {
881
        signature: Signature,
882
    }
883
    impl BMeanUdf {
884
        fn new() -> Self {
885
            Self {
886
                signature: Signature::uniform(
887
                    1,
888
                    vec![DataType::Float64],
889
                    Volatility::Immutable,
890
                ),
891
            }
892
        }
893
    }
894
895
    impl AggregateUDFImpl for BMeanUdf {
896
        fn as_any(&self) -> &dyn Any {
897
            self
898
        }
899
        fn name(&self) -> &str {
900
            "b"
901
        }
902
        fn signature(&self) -> &Signature {
903
            &self.signature
904
        }
905
        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
906
            unimplemented!()
907
        }
908
        fn accumulator(
909
            &self,
910
            _acc_args: AccumulatorArgs,
911
        ) -> Result<Box<dyn Accumulator>> {
912
            unimplemented!()
913
        }
914
        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
915
            unimplemented!()
916
        }
917
    }
918
919
    #[test]
920
    fn test_partial_ord() {
921
        // Test validates that partial ord is defined for AggregateUDF using the name and signature,
922
        // not intended to exhaustively test all possibilities
923
        let a1 = AggregateUDF::from(AMeanUdf::new());
924
        let a2 = AggregateUDF::from(AMeanUdf::new());
925
        assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal));
926
927
        let b1 = AggregateUDF::from(BMeanUdf::new());
928
        assert!(a1 < b1);
929
        assert!(!(a1 == b1));
930
    }
931
}