/Users/andrewlamb/Software/datafusion/datafusion/expr/src/udaf.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! [`AggregateUDF`]: User Defined Aggregate Functions |
19 | | |
20 | | use std::any::Any; |
21 | | use std::cmp::Ordering; |
22 | | use std::fmt::{self, Debug, Formatter}; |
23 | | use std::hash::{DefaultHasher, Hash, Hasher}; |
24 | | use std::sync::Arc; |
25 | | use std::vec; |
26 | | |
27 | | use arrow::datatypes::{DataType, Field}; |
28 | | |
29 | | use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; |
30 | | use datafusion_physical_expr_common::physical_expr::PhysicalExpr; |
31 | | |
32 | | use crate::expr::AggregateFunction; |
33 | | use crate::function::{ |
34 | | AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, |
35 | | }; |
36 | | use crate::groups_accumulator::GroupsAccumulator; |
37 | | use crate::utils::format_state_name; |
38 | | use crate::utils::AggregateOrderSensitivity; |
39 | | use crate::{Accumulator, Expr}; |
40 | | use crate::{Documentation, Signature}; |
41 | | |
42 | | /// Logical representation of a user-defined [aggregate function] (UDAF). |
43 | | /// |
44 | | /// An aggregate function combines the values from multiple input rows |
45 | | /// into a single output "aggregate" (summary) row. It is different |
46 | | /// from a scalar function because it is stateful across batches. User |
47 | | /// defined aggregate functions can be used as normal SQL aggregate |
48 | | /// functions (`GROUP BY` clause) as well as window functions (`OVER` |
49 | | /// clause). |
50 | | /// |
51 | | /// `AggregateUDF` provides DataFusion the information needed to plan and call |
52 | | /// aggregate functions, including name, type information, and a factory |
53 | | /// function to create an [`Accumulator`] instance, to perform the actual |
54 | | /// aggregation. |
55 | | /// |
56 | | /// For more information, please see [the examples]: |
57 | | /// |
58 | | /// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]). |
59 | | /// |
60 | | /// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API |
61 | | /// access (examples in [`advanced_udaf.rs`]). |
62 | | /// |
63 | | /// # API Note |
64 | | /// This is a separate struct from `AggregateUDFImpl` to maintain backwards |
65 | | /// compatibility with the older API. |
66 | | /// |
67 | | /// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process |
68 | | /// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function |
69 | | /// [`Accumulator`]: crate::Accumulator |
70 | | /// [`create_udaf`]: crate::expr_fn::create_udaf |
71 | | /// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs |
72 | | /// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs |
73 | | #[derive(Debug, Clone, PartialOrd)] |
74 | | pub struct AggregateUDF { |
75 | | inner: Arc<dyn AggregateUDFImpl>, |
76 | | } |
77 | | |
78 | | impl PartialEq for AggregateUDF { |
79 | 0 | fn eq(&self, other: &Self) -> bool { |
80 | 0 | self.inner.equals(other.inner.as_ref()) |
81 | 0 | } |
82 | | } |
83 | | |
84 | | impl Eq for AggregateUDF {} |
85 | | |
86 | | impl Hash for AggregateUDF { |
87 | 0 | fn hash<H: Hasher>(&self, state: &mut H) { |
88 | 0 | self.inner.hash_value().hash(state) |
89 | 0 | } |
90 | | } |
91 | | |
92 | | impl fmt::Display for AggregateUDF { |
93 | 1 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { |
94 | 1 | write!(f, "{}", self.name()) |
95 | 1 | } |
96 | | } |
97 | | |
98 | | /// Arguments passed to [`AggregateUDFImpl::value_from_stats`] |
99 | | pub struct StatisticsArgs<'a> { |
100 | | /// The statistics of the aggregate input |
101 | | pub statistics: &'a Statistics, |
102 | | /// The resolved return type of the aggregate function |
103 | | pub return_type: &'a DataType, |
104 | | /// Whether the aggregate function is distinct. |
105 | | /// |
106 | | /// ```sql |
107 | | /// SELECT COUNT(DISTINCT column1) FROM t; |
108 | | /// ``` |
109 | | pub is_distinct: bool, |
110 | | /// The physical expression of arguments the aggregate function takes. |
111 | | pub exprs: &'a [Arc<dyn PhysicalExpr>], |
112 | | } |
113 | | |
114 | | impl AggregateUDF { |
115 | | /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object |
116 | | /// |
117 | | /// Note this is the same as using the `From` impl (`AggregateUDF::from`) |
118 | 7 | pub fn new_from_impl<F>(fun: F) -> AggregateUDF |
119 | 7 | where |
120 | 7 | F: AggregateUDFImpl + 'static, |
121 | 7 | { |
122 | 7 | Self { |
123 | 7 | inner: Arc::new(fun), |
124 | 7 | } |
125 | 7 | } |
126 | | |
127 | | /// Return the underlying [`AggregateUDFImpl`] trait object for this function |
128 | 0 | pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> { |
129 | 0 | &self.inner |
130 | 0 | } |
131 | | |
132 | | /// Adds additional names that can be used to invoke this function, in |
133 | | /// addition to `name` |
134 | | /// |
135 | | /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly. |
136 | 0 | pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self { |
137 | 0 | Self::new_from_impl(AliasedAggregateUDFImpl::new( |
138 | 0 | Arc::clone(&self.inner), |
139 | 0 | aliases, |
140 | 0 | )) |
141 | 0 | } |
142 | | |
143 | | /// creates an [`Expr`] that calls the aggregate function. |
144 | | /// |
145 | | /// This utility allows using the UDAF without requiring access to |
146 | | /// the registry, such as with the DataFrame API. |
147 | 0 | pub fn call(&self, args: Vec<Expr>) -> Expr { |
148 | 0 | Expr::AggregateFunction(AggregateFunction::new_udf( |
149 | 0 | Arc::new(self.clone()), |
150 | 0 | args, |
151 | 0 | false, |
152 | 0 | None, |
153 | 0 | None, |
154 | 0 | None, |
155 | 0 | )) |
156 | 0 | } |
157 | | |
158 | | /// Returns this function's name |
159 | | /// |
160 | | /// See [`AggregateUDFImpl::name`] for more details. |
161 | 49 | pub fn name(&self) -> &str { |
162 | 49 | self.inner.name() |
163 | 49 | } |
164 | | |
165 | 36 | pub fn is_nullable(&self) -> bool { |
166 | 36 | self.inner.is_nullable() |
167 | 36 | } |
168 | | |
169 | | /// Returns the aliases for this function. |
170 | 0 | pub fn aliases(&self) -> &[String] { |
171 | 0 | self.inner.aliases() |
172 | 0 | } |
173 | | |
174 | | /// Returns this function's signature (what input types are accepted) |
175 | | /// |
176 | | /// See [`AggregateUDFImpl::signature`] for more details. |
177 | 36 | pub fn signature(&self) -> &Signature { |
178 | 36 | self.inner.signature() |
179 | 36 | } |
180 | | |
181 | | /// Return the type of the function given its input types |
182 | | /// |
183 | | /// See [`AggregateUDFImpl::return_type`] for more details. |
184 | 36 | pub fn return_type(&self, args: &[DataType]) -> Result<DataType> { |
185 | 36 | self.inner.return_type(args) |
186 | 36 | } |
187 | | |
188 | | /// Return an accumulator the given aggregate, given its return datatype |
189 | 132 | pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
190 | 132 | self.inner.accumulator(acc_args) |
191 | 132 | } |
192 | | |
193 | | /// Return the fields used to store the intermediate state for this aggregator, given |
194 | | /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`] |
195 | | /// for more details. |
196 | | /// |
197 | | /// This is used to support multi-phase aggregations |
198 | 112 | pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
199 | 112 | self.inner.state_fields(args) |
200 | 112 | } |
201 | | |
202 | | /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details. |
203 | 70 | pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { |
204 | 70 | self.inner.groups_accumulator_supported(args) |
205 | 70 | } |
206 | | |
207 | | /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details. |
208 | 30 | pub fn create_groups_accumulator( |
209 | 30 | &self, |
210 | 30 | args: AccumulatorArgs, |
211 | 30 | ) -> Result<Box<dyn GroupsAccumulator>> { |
212 | 30 | self.inner.create_groups_accumulator(args) |
213 | 30 | } |
214 | | |
215 | 3 | pub fn create_sliding_accumulator( |
216 | 3 | &self, |
217 | 3 | args: AccumulatorArgs, |
218 | 3 | ) -> Result<Box<dyn Accumulator>> { |
219 | 3 | self.inner.create_sliding_accumulator(args) |
220 | 3 | } |
221 | | |
222 | 0 | pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
223 | 0 | self.inner.coerce_types(arg_types) |
224 | 0 | } |
225 | | |
226 | | /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details. |
227 | 0 | pub fn with_beneficial_ordering( |
228 | 0 | self, |
229 | 0 | beneficial_ordering: bool, |
230 | 0 | ) -> Result<Option<AggregateUDF>> { |
231 | 0 | self.inner |
232 | 0 | .with_beneficial_ordering(beneficial_ordering) |
233 | 0 | .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf })) |
234 | 0 | } |
235 | | |
236 | | /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`] |
237 | | /// for possible options. |
238 | 70 | pub fn order_sensitivity(&self) -> AggregateOrderSensitivity { |
239 | 70 | self.inner.order_sensitivity() |
240 | 70 | } |
241 | | |
242 | | /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will |
243 | | /// generate same result with this `AggregateUDF` when iterated in reverse |
244 | | /// order, and `None` if there is no such `AggregateUDF`). |
245 | 3 | pub fn reverse_udf(&self) -> ReversedUDAF { |
246 | 3 | self.inner.reverse_expr() |
247 | 3 | } |
248 | | |
249 | | /// Do the function rewrite |
250 | | /// |
251 | | /// See [`AggregateUDFImpl::simplify`] for more details. |
252 | 0 | pub fn simplify(&self) -> Option<AggregateFunctionSimplification> { |
253 | 0 | self.inner.simplify() |
254 | 0 | } |
255 | | |
256 | | /// Returns true if the function is max, false if the function is min |
257 | | /// None in all other cases, used in certain optimizations for |
258 | | /// or aggregate |
259 | 0 | pub fn is_descending(&self) -> Option<bool> { |
260 | 0 | self.inner.is_descending() |
261 | 0 | } |
262 | | |
263 | | /// Return the value of this aggregate function if it can be determined |
264 | | /// entirely from statistics and arguments. |
265 | | /// |
266 | | /// See [`AggregateUDFImpl::value_from_stats`] for more details. |
267 | 0 | pub fn value_from_stats( |
268 | 0 | &self, |
269 | 0 | statistics_args: &StatisticsArgs, |
270 | 0 | ) -> Option<ScalarValue> { |
271 | 0 | self.inner.value_from_stats(statistics_args) |
272 | 0 | } |
273 | | |
274 | | /// See [`AggregateUDFImpl::default_value`] for more details. |
275 | 0 | pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> { |
276 | 0 | self.inner.default_value(data_type) |
277 | 0 | } |
278 | | |
279 | | /// Returns the documentation for this Aggregate UDF. |
280 | | /// |
281 | | /// Documentation can be accessed programmatically as well as |
282 | | /// generating publicly facing documentation. |
283 | 0 | pub fn documentation(&self) -> Option<&Documentation> { |
284 | 0 | self.inner.documentation() |
285 | 0 | } |
286 | | } |
287 | | |
288 | | impl<F> From<F> for AggregateUDF |
289 | | where |
290 | | F: AggregateUDFImpl + Send + Sync + 'static, |
291 | | { |
292 | 7 | fn from(fun: F) -> Self { |
293 | 7 | Self::new_from_impl(fun) |
294 | 7 | } |
295 | | } |
296 | | |
297 | | /// Trait for implementing [`AggregateUDF`]. |
298 | | /// |
299 | | /// This trait exposes the full API for implementing user defined aggregate functions and |
300 | | /// can be used to implement any function. |
301 | | /// |
302 | | /// See [`advanced_udaf.rs`] for a full example with complete implementation and |
303 | | /// [`AggregateUDF`] for other available options. |
304 | | /// |
305 | | /// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs |
306 | | /// |
307 | | /// # Basic Example |
308 | | /// ``` |
309 | | /// # use std::any::Any; |
310 | | /// # use std::sync::OnceLock; |
311 | | /// # use arrow::datatypes::DataType; |
312 | | /// # use datafusion_common::{DataFusionError, plan_err, Result}; |
313 | | /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation}; |
314 | | /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; |
315 | | /// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE; |
316 | | /// # use arrow::datatypes::Schema; |
317 | | /// # use arrow::datatypes::Field; |
318 | | /// |
319 | | /// #[derive(Debug, Clone)] |
320 | | /// struct GeoMeanUdf { |
321 | | /// signature: Signature, |
322 | | /// } |
323 | | /// |
324 | | /// impl GeoMeanUdf { |
325 | | /// fn new() -> Self { |
326 | | /// Self { |
327 | | /// signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), |
328 | | /// } |
329 | | /// } |
330 | | /// } |
331 | | /// |
332 | | /// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); |
333 | | /// |
334 | | /// fn get_doc() -> &'static Documentation { |
335 | | /// DOCUMENTATION.get_or_init(|| { |
336 | | /// Documentation::builder() |
337 | | /// .with_doc_section(DOC_SECTION_AGGREGATE) |
338 | | /// .with_description("calculates a geometric mean") |
339 | | /// .with_syntax_example("geo_mean(2.0)") |
340 | | /// .with_argument("arg1", "The Float64 number for the geometric mean") |
341 | | /// .build() |
342 | | /// .unwrap() |
343 | | /// }) |
344 | | /// } |
345 | | /// |
346 | | /// /// Implement the AggregateUDFImpl trait for GeoMeanUdf |
347 | | /// impl AggregateUDFImpl for GeoMeanUdf { |
348 | | /// fn as_any(&self) -> &dyn Any { self } |
349 | | /// fn name(&self) -> &str { "geo_mean" } |
350 | | /// fn signature(&self) -> &Signature { &self.signature } |
351 | | /// fn return_type(&self, args: &[DataType]) -> Result<DataType> { |
352 | | /// if !matches!(args.get(0), Some(&DataType::Float64)) { |
353 | | /// return plan_err!("geo_mean only accepts Float64 arguments"); |
354 | | /// } |
355 | | /// Ok(DataType::Float64) |
356 | | /// } |
357 | | /// // This is the accumulator factory; DataFusion uses it to create new accumulators. |
358 | | /// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() } |
359 | | /// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
360 | | /// Ok(vec![ |
361 | | /// Field::new("value", args.return_type.clone(), true), |
362 | | /// Field::new("ordering", DataType::UInt32, true) |
363 | | /// ]) |
364 | | /// } |
365 | | /// fn documentation(&self) -> Option<&Documentation> { |
366 | | /// Some(get_doc()) |
367 | | /// } |
368 | | /// } |
369 | | /// |
370 | | /// // Create a new AggregateUDF from the implementation |
371 | | /// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new()); |
372 | | /// |
373 | | /// // Call the function `geo_mean(col)` |
374 | | /// let expr = geometric_mean.call(vec![col("a")]); |
375 | | /// ``` |
376 | | pub trait AggregateUDFImpl: Debug + Send + Sync { |
377 | | // Note: When adding any methods (with default implementations), remember to add them also |
378 | | // into the AliasedAggregateUDFImpl below! |
379 | | |
380 | | /// Returns this object as an [`Any`] trait object |
381 | | fn as_any(&self) -> &dyn Any; |
382 | | |
383 | | /// Returns this function's name |
384 | | fn name(&self) -> &str; |
385 | | |
386 | | /// Returns the function's [`Signature`] for information about what input |
387 | | /// types are accepted and the function's Volatility. |
388 | | fn signature(&self) -> &Signature; |
389 | | |
390 | | /// What [`DataType`] will be returned by this function, given the types of |
391 | | /// the arguments |
392 | | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>; |
393 | | |
394 | | /// Whether the aggregate function is nullable. |
395 | | /// |
396 | | /// Nullable means that that the function could return `null` for any inputs. |
397 | | /// For example, aggregate functions like `COUNT` always return a non null value |
398 | | /// but others like `MIN` will return `NULL` if there is nullable input. |
399 | | /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null` |
400 | 26 | fn is_nullable(&self) -> bool { |
401 | 26 | true |
402 | 26 | } |
403 | | |
404 | | /// Return a new [`Accumulator`] that aggregates values for a specific |
405 | | /// group during query execution. |
406 | | /// |
407 | | /// acc_args: [`AccumulatorArgs`] contains information about how the |
408 | | /// aggregate function was called. |
409 | | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>; |
410 | | |
411 | | /// Return the fields used to store the intermediate state of this accumulator. |
412 | | /// |
413 | | /// See [`Accumulator::state`] for background information. |
414 | | /// |
415 | | /// args: [`StateFieldsArgs`] contains arguments passed to the |
416 | | /// aggregate function's accumulator. |
417 | | /// |
418 | | /// # Notes: |
419 | | /// |
420 | | /// The default implementation returns a single state field named `name` |
421 | | /// with the same type as `value_type`. This is suitable for aggregates such |
422 | | /// as `SUM` or `MIN` where partial state can be combined by applying the |
423 | | /// same aggregate. |
424 | | /// |
425 | | /// For aggregates such as `AVG` where the partial state is more complex |
426 | | /// (e.g. a COUNT and a SUM), this method is used to define the additional |
427 | | /// fields. |
428 | | /// |
429 | | /// The name of the fields must be unique within the query and thus should |
430 | | /// be derived from `name`. See [`format_state_name`] for a utility function |
431 | | /// to generate a unique name. |
432 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
433 | 0 | let fields = vec![Field::new( |
434 | 0 | format_state_name(args.name, "value"), |
435 | 0 | args.return_type.clone(), |
436 | 0 | true, |
437 | 0 | )]; |
438 | 0 |
|
439 | 0 | Ok(fields |
440 | 0 | .into_iter() |
441 | 0 | .chain(args.ordering_fields.to_vec()) |
442 | 0 | .collect()) |
443 | 0 | } |
444 | | |
445 | | /// If the aggregate expression has a specialized |
446 | | /// [`GroupsAccumulator`] implementation. If this returns true, |
447 | | /// `[Self::create_groups_accumulator]` will be called. |
448 | | /// |
449 | | /// # Notes |
450 | | /// |
451 | | /// Even if this function returns true, DataFusion will still use |
452 | | /// [`Self::accumulator`] for certain queries, such as when this aggregate is |
453 | | /// used as a window function or when there no GROUP BY columns in the |
454 | | /// query. |
455 | 40 | fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { |
456 | 40 | false |
457 | 40 | } |
458 | | |
459 | | /// Return a specialized [`GroupsAccumulator`] that manages state |
460 | | /// for all groups. |
461 | | /// |
462 | | /// For maximum performance, a [`GroupsAccumulator`] should be |
463 | | /// implemented in addition to [`Accumulator`]. |
464 | 0 | fn create_groups_accumulator( |
465 | 0 | &self, |
466 | 0 | _args: AccumulatorArgs, |
467 | 0 | ) -> Result<Box<dyn GroupsAccumulator>> { |
468 | 0 | not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") |
469 | 0 | } |
470 | | |
471 | | /// Returns any aliases (alternate names) for this function. |
472 | | /// |
473 | | /// Note: `aliases` should only include names other than [`Self::name`]. |
474 | | /// Defaults to `[]` (no aliases) |
475 | 0 | fn aliases(&self) -> &[String] { |
476 | 0 | &[] |
477 | 0 | } |
478 | | |
479 | | /// Sliding accumulator is an alternative accumulator that can be used for |
480 | | /// window functions. It has retract method to revert the previous update. |
481 | | /// |
482 | | /// See [retract_batch] for more details. |
483 | | /// |
484 | | /// [retract_batch]: datafusion_expr_common::accumulator::Accumulator::retract_batch |
485 | 3 | fn create_sliding_accumulator( |
486 | 3 | &self, |
487 | 3 | args: AccumulatorArgs, |
488 | 3 | ) -> Result<Box<dyn Accumulator>> { |
489 | 3 | self.accumulator(args) |
490 | 3 | } |
491 | | |
492 | | /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is |
493 | | /// satisfied by its input. If this is not the case, UDFs with order |
494 | | /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce |
495 | | /// the correct result with possibly more work internally. |
496 | | /// |
497 | | /// # Returns |
498 | | /// |
499 | | /// Returns `Ok(Some(updated_udf))` if the process completes successfully. |
500 | | /// If the expression can benefit from existing input ordering, but does |
501 | | /// not implement the method, returns an error. Order insensitive and hard |
502 | | /// requirement aggregators return `Ok(None)`. |
503 | 0 | fn with_beneficial_ordering( |
504 | 0 | self: Arc<Self>, |
505 | 0 | _beneficial_ordering: bool, |
506 | 0 | ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> { |
507 | 0 | if self.order_sensitivity().is_beneficial() { |
508 | 0 | return exec_err!( |
509 | 0 | "Should implement with satisfied for aggregator :{:?}", |
510 | 0 | self.name() |
511 | 0 | ); |
512 | 0 | } |
513 | 0 | Ok(None) |
514 | 0 | } |
515 | | |
516 | | /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`] |
517 | | /// for possible options. |
518 | 18 | fn order_sensitivity(&self) -> AggregateOrderSensitivity { |
519 | 18 | // We have hard ordering requirements by default, meaning that order |
520 | 18 | // sensitive UDFs need their input orderings to satisfy their ordering |
521 | 18 | // requirements to generate correct results. |
522 | 18 | AggregateOrderSensitivity::HardRequirement |
523 | 18 | } |
524 | | |
525 | | /// Optionally apply per-UDaF simplification / rewrite rules. |
526 | | /// |
527 | | /// This can be used to apply function specific simplification rules during |
528 | | /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default |
529 | | /// implementation does nothing. |
530 | | /// |
531 | | /// Note that DataFusion handles simplifying arguments and "constant |
532 | | /// folding" (replacing a function call with constant arguments such as |
533 | | /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such |
534 | | /// optimizations manually for specific UDFs. |
535 | | /// |
536 | | /// # Returns |
537 | | /// |
538 | | /// [None] if simplify is not defined or, |
539 | | /// |
540 | | /// Or, a closure with two arguments: |
541 | | /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked |
542 | | /// * 'info': [crate::simplify::SimplifyInfo] |
543 | | /// |
544 | | /// closure returns simplified [Expr] or an error. |
545 | | /// |
546 | 0 | fn simplify(&self) -> Option<AggregateFunctionSimplification> { |
547 | 0 | None |
548 | 0 | } |
549 | | |
550 | | /// Returns the reverse expression of the aggregate function. |
551 | 0 | fn reverse_expr(&self) -> ReversedUDAF { |
552 | 0 | ReversedUDAF::NotSupported |
553 | 0 | } |
554 | | |
555 | | /// Coerce arguments of a function call to types that the function can evaluate. |
556 | | /// |
557 | | /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most |
558 | | /// UDAFs should return one of the other variants of `TypeSignature` which handle common |
559 | | /// cases |
560 | | /// |
561 | | /// See the [type coercion module](crate::type_coercion) |
562 | | /// documentation for more details on type coercion |
563 | | /// |
564 | | /// For example, if your function requires a floating point arguments, but the user calls |
565 | | /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]` |
566 | | /// to ensure the argument was cast to `1::double` |
567 | | /// |
568 | | /// # Parameters |
569 | | /// * `arg_types`: The argument types of the arguments this function with |
570 | | /// |
571 | | /// # Return value |
572 | | /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call |
573 | | /// arguments to these specific types. |
574 | 0 | fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> { |
575 | 0 | not_impl_err!("Function {} does not implement coerce_types", self.name()) |
576 | 0 | } |
577 | | |
578 | | /// Return true if this aggregate UDF is equal to the other. |
579 | | /// |
580 | | /// Allows customizing the equality of aggregate UDFs. |
581 | | /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: |
582 | | /// |
583 | | /// - reflexive: `a.equals(a)`; |
584 | | /// - symmetric: `a.equals(b)` implies `b.equals(a)`; |
585 | | /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. |
586 | | /// |
587 | | /// By default, compares [`Self::name`] and [`Self::signature`]. |
588 | 0 | fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { |
589 | 0 | self.name() == other.name() && self.signature() == other.signature() |
590 | 0 | } |
591 | | |
592 | | /// Returns a hash value for this aggregate UDF. |
593 | | /// |
594 | | /// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`], |
595 | | /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. |
596 | | /// |
597 | | /// By default, hashes [`Self::name`] and [`Self::signature`]. |
598 | 0 | fn hash_value(&self) -> u64 { |
599 | 0 | let hasher = &mut DefaultHasher::new(); |
600 | 0 | self.name().hash(hasher); |
601 | 0 | self.signature().hash(hasher); |
602 | 0 | hasher.finish() |
603 | 0 | } |
604 | | |
605 | | /// If this function is max, return true |
606 | | /// if the function is min, return false |
607 | | /// otherwise return None (the default) |
608 | | /// |
609 | | /// |
610 | | /// Note: this is used to use special aggregate implementations in certain conditions |
611 | 0 | fn is_descending(&self) -> Option<bool> { |
612 | 0 | None |
613 | 0 | } |
614 | | |
615 | | /// Return the value of this aggregate function if it can be determined |
616 | | /// entirely from statistics and arguments. |
617 | | /// |
618 | | /// Using a [`ScalarValue`] rather than a runtime computation can significantly |
619 | | /// improving query performance. |
620 | | /// |
621 | | /// For example, if the minimum value of column `x` is known to be `42` from |
622 | | /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))` |
623 | 0 | fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> { |
624 | 0 | None |
625 | 0 | } |
626 | | |
627 | | /// Returns default value of the function given the input is all `null`. |
628 | | /// |
629 | | /// Most of the aggregate function return Null if input is Null, |
630 | | /// while `count` returns 0 if input is Null |
631 | 0 | fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> { |
632 | 0 | ScalarValue::try_from(data_type) |
633 | 0 | } |
634 | | |
635 | | /// Returns the documentation for this Aggregate UDF. |
636 | | /// |
637 | | /// Documentation can be accessed programmatically as well as |
638 | | /// generating publicly facing documentation. |
639 | 0 | fn documentation(&self) -> Option<&Documentation> { |
640 | 0 | None |
641 | 0 | } |
642 | | } |
643 | | |
644 | | impl PartialEq for dyn AggregateUDFImpl { |
645 | 0 | fn eq(&self, other: &Self) -> bool { |
646 | 0 | self.equals(other) |
647 | 0 | } |
648 | | } |
649 | | |
650 | | // manual implementation of `PartialOrd` |
651 | | // There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl |
652 | | // https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5 |
653 | | impl PartialOrd for dyn AggregateUDFImpl { |
654 | 0 | fn partial_cmp(&self, other: &Self) -> Option<Ordering> { |
655 | 0 | match self.name().partial_cmp(other.name()) { |
656 | 0 | Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()), |
657 | 0 | cmp => cmp, |
658 | | } |
659 | 0 | } |
660 | | } |
661 | | |
662 | | pub enum ReversedUDAF { |
663 | | /// The expression is the same as the original expression, like SUM, COUNT |
664 | | Identical, |
665 | | /// The expression does not support reverse calculation |
666 | | NotSupported, |
667 | | /// The expression is different from the original expression |
668 | | Reversed(Arc<AggregateUDF>), |
669 | | } |
670 | | |
671 | | /// AggregateUDF that adds an alias to the underlying function. It is better to |
672 | | /// implement [`AggregateUDFImpl`], which supports aliases, directly if possible. |
673 | | #[derive(Debug)] |
674 | | struct AliasedAggregateUDFImpl { |
675 | | inner: Arc<dyn AggregateUDFImpl>, |
676 | | aliases: Vec<String>, |
677 | | } |
678 | | |
679 | | impl AliasedAggregateUDFImpl { |
680 | 0 | pub fn new( |
681 | 0 | inner: Arc<dyn AggregateUDFImpl>, |
682 | 0 | new_aliases: impl IntoIterator<Item = &'static str>, |
683 | 0 | ) -> Self { |
684 | 0 | let mut aliases = inner.aliases().to_vec(); |
685 | 0 | aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); |
686 | 0 |
|
687 | 0 | Self { inner, aliases } |
688 | 0 | } |
689 | | } |
690 | | |
691 | | impl AggregateUDFImpl for AliasedAggregateUDFImpl { |
692 | 0 | fn as_any(&self) -> &dyn Any { |
693 | 0 | self |
694 | 0 | } |
695 | | |
696 | 0 | fn name(&self) -> &str { |
697 | 0 | self.inner.name() |
698 | 0 | } |
699 | | |
700 | 0 | fn signature(&self) -> &Signature { |
701 | 0 | self.inner.signature() |
702 | 0 | } |
703 | | |
704 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
705 | 0 | self.inner.return_type(arg_types) |
706 | 0 | } |
707 | | |
708 | 0 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
709 | 0 | self.inner.accumulator(acc_args) |
710 | 0 | } |
711 | | |
712 | 0 | fn aliases(&self) -> &[String] { |
713 | 0 | &self.aliases |
714 | 0 | } |
715 | | |
716 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
717 | 0 | self.inner.state_fields(args) |
718 | 0 | } |
719 | | |
720 | 0 | fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { |
721 | 0 | self.inner.groups_accumulator_supported(args) |
722 | 0 | } |
723 | | |
724 | 0 | fn create_groups_accumulator( |
725 | 0 | &self, |
726 | 0 | args: AccumulatorArgs, |
727 | 0 | ) -> Result<Box<dyn GroupsAccumulator>> { |
728 | 0 | self.inner.create_groups_accumulator(args) |
729 | 0 | } |
730 | | |
731 | 0 | fn create_sliding_accumulator( |
732 | 0 | &self, |
733 | 0 | args: AccumulatorArgs, |
734 | 0 | ) -> Result<Box<dyn Accumulator>> { |
735 | 0 | self.inner.accumulator(args) |
736 | 0 | } |
737 | | |
738 | 0 | fn with_beneficial_ordering( |
739 | 0 | self: Arc<Self>, |
740 | 0 | beneficial_ordering: bool, |
741 | 0 | ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> { |
742 | 0 | Arc::clone(&self.inner) |
743 | 0 | .with_beneficial_ordering(beneficial_ordering) |
744 | 0 | .map(|udf| { |
745 | 0 | udf.map(|udf| { |
746 | 0 | Arc::new(AliasedAggregateUDFImpl { |
747 | 0 | inner: udf, |
748 | 0 | aliases: self.aliases.clone(), |
749 | 0 | }) as Arc<dyn AggregateUDFImpl> |
750 | 0 | }) |
751 | 0 | }) |
752 | 0 | } |
753 | | |
754 | 0 | fn order_sensitivity(&self) -> AggregateOrderSensitivity { |
755 | 0 | self.inner.order_sensitivity() |
756 | 0 | } |
757 | | |
758 | 0 | fn simplify(&self) -> Option<AggregateFunctionSimplification> { |
759 | 0 | self.inner.simplify() |
760 | 0 | } |
761 | | |
762 | 0 | fn reverse_expr(&self) -> ReversedUDAF { |
763 | 0 | self.inner.reverse_expr() |
764 | 0 | } |
765 | | |
766 | 0 | fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
767 | 0 | self.inner.coerce_types(arg_types) |
768 | 0 | } |
769 | | |
770 | 0 | fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { |
771 | 0 | if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() { |
772 | 0 | self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases |
773 | | } else { |
774 | 0 | false |
775 | | } |
776 | 0 | } |
777 | | |
778 | 0 | fn hash_value(&self) -> u64 { |
779 | 0 | let hasher = &mut DefaultHasher::new(); |
780 | 0 | self.inner.hash_value().hash(hasher); |
781 | 0 | self.aliases.hash(hasher); |
782 | 0 | hasher.finish() |
783 | 0 | } |
784 | | |
785 | 0 | fn is_descending(&self) -> Option<bool> { |
786 | 0 | self.inner.is_descending() |
787 | 0 | } |
788 | | |
789 | 0 | fn documentation(&self) -> Option<&Documentation> { |
790 | 0 | self.inner.documentation() |
791 | 0 | } |
792 | | } |
793 | | |
794 | | // Aggregate UDF doc sections for use in public documentation |
795 | | pub mod aggregate_doc_sections { |
796 | | use crate::DocSection; |
797 | | |
798 | | pub fn doc_sections() -> Vec<DocSection> { |
799 | | vec![ |
800 | | DOC_SECTION_GENERAL, |
801 | | DOC_SECTION_STATISTICAL, |
802 | | DOC_SECTION_APPROXIMATE, |
803 | | ] |
804 | | } |
805 | | |
806 | | pub const DOC_SECTION_GENERAL: DocSection = DocSection { |
807 | | include: true, |
808 | | label: "General Functions", |
809 | | description: None, |
810 | | }; |
811 | | |
812 | | pub const DOC_SECTION_STATISTICAL: DocSection = DocSection { |
813 | | include: true, |
814 | | label: "Statistical Functions", |
815 | | description: None, |
816 | | }; |
817 | | |
818 | | pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection { |
819 | | include: true, |
820 | | label: "Approximate Functions", |
821 | | description: None, |
822 | | }; |
823 | | } |
824 | | |
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::any::Any;
    use std::cmp::Ordering;

    /// Signature shared by both test UDFs: a single immutable `Float64` argument.
    fn float64_signature() -> Signature {
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable)
    }

    /// Stub aggregate named `"a"`; only its name and signature are ever used.
    #[derive(Debug, Clone)]
    struct AMeanUdf {
        signature: Signature,
    }

    impl AMeanUdf {
        fn new() -> Self {
            let signature = float64_signature();
            Self { signature }
        }
    }

    impl AggregateUDFImpl for AMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "a"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
            unimplemented!()
        }
    }

    /// Stub aggregate named `"b"`; same signature as [`AMeanUdf`], different name.
    #[derive(Debug, Clone)]
    struct BMeanUdf {
        signature: Signature,
    }

    impl BMeanUdf {
        fn new() -> Self {
            let signature = float64_signature();
            Self { signature }
        }
    }

    impl AggregateUDFImpl for BMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "b"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_ord() {
        // Sanity check that `AggregateUDF` ordering is driven by name and
        // signature; this is not meant to cover every combination.
        let first_a = AggregateUDF::from(AMeanUdf::new());
        let second_a = AggregateUDF::from(AMeanUdf::new());
        assert_eq!(first_a.partial_cmp(&second_a), Some(Ordering::Equal));

        let only_b = AggregateUDF::from(BMeanUdf::new());
        assert!(first_a < only_b);
        assert!(first_a != only_b);
    }
}