/Users/andrewlamb/Software/datafusion/datafusion/expr/src/test/function_stub.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Aggregate function stubs for test in expr / optimizer. |
19 | | //! |
20 | | //! These are used to avoid a dependence on `datafusion-functions-aggregate` which live in a different crate |
21 | | |
22 | | use std::any::Any; |
23 | | |
24 | | use arrow::datatypes::{ |
25 | | DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, |
26 | | }; |
27 | | |
28 | | use datafusion_common::{exec_err, not_impl_err, Result}; |
29 | | |
30 | | use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS}; |
31 | | use crate::Volatility::Immutable; |
32 | | use crate::{ |
33 | | expr::AggregateFunction, |
34 | | function::{AccumulatorArgs, StateFieldsArgs}, |
35 | | utils::AggregateOrderSensitivity, |
36 | | Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature, |
37 | | Volatility, |
38 | | }; |
39 | | |
40 | | macro_rules! create_func { |
41 | | ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { |
42 | | paste::paste! { |
43 | | /// Singleton instance of [$UDAF], ensures the UDAF is only created once |
44 | | /// named STATIC_$(UDAF). For example `STATIC_FirstValue` |
45 | | #[allow(non_upper_case_globals)] |
46 | | static [< STATIC_ $UDAF >]: std::sync::OnceLock<std::sync::Arc<crate::AggregateUDF>> = |
47 | | std::sync::OnceLock::new(); |
48 | | |
49 | | #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")] |
50 | 0 | pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<crate::AggregateUDF> { |
51 | 0 | [< STATIC_ $UDAF >] |
52 | 0 | .get_or_init(|| { |
53 | 0 | std::sync::Arc::new(crate::AggregateUDF::from(<$UDAF>::default())) |
54 | 0 | }) |
55 | 0 | .clone() |
56 | 0 | } |
57 | | } |
58 | | } |
59 | | } |
60 | | |
61 | | create_func!(Sum, sum_udaf); |
62 | | |
63 | 0 | pub fn sum(expr: Expr) -> Expr { |
64 | 0 | Expr::AggregateFunction(AggregateFunction::new_udf( |
65 | 0 | sum_udaf(), |
66 | 0 | vec![expr], |
67 | 0 | false, |
68 | 0 | None, |
69 | 0 | None, |
70 | 0 | None, |
71 | 0 | )) |
72 | 0 | } |
73 | | |
74 | | create_func!(Count, count_udaf); |
75 | | |
76 | 0 | pub fn count(expr: Expr) -> Expr { |
77 | 0 | Expr::AggregateFunction(AggregateFunction::new_udf( |
78 | 0 | count_udaf(), |
79 | 0 | vec![expr], |
80 | 0 | false, |
81 | 0 | None, |
82 | 0 | None, |
83 | 0 | None, |
84 | 0 | )) |
85 | 0 | } |
86 | | |
87 | | create_func!(Avg, avg_udaf); |
88 | | |
89 | 0 | pub fn avg(expr: Expr) -> Expr { |
90 | 0 | Expr::AggregateFunction(AggregateFunction::new_udf( |
91 | 0 | avg_udaf(), |
92 | 0 | vec![expr], |
93 | 0 | false, |
94 | 0 | None, |
95 | 0 | None, |
96 | 0 | None, |
97 | 0 | )) |
98 | 0 | } |
99 | | |
100 | | /// Stub `sum` used for optimizer testing |
101 | | #[derive(Debug)] |
102 | | pub struct Sum { |
103 | | signature: Signature, |
104 | | } |
105 | | |
106 | | impl Sum { |
107 | 0 | pub fn new() -> Self { |
108 | 0 | Self { |
109 | 0 | signature: Signature::user_defined(Volatility::Immutable), |
110 | 0 | } |
111 | 0 | } |
112 | | } |
113 | | |
114 | | impl Default for Sum { |
115 | 0 | fn default() -> Self { |
116 | 0 | Self::new() |
117 | 0 | } |
118 | | } |
119 | | |
120 | | impl AggregateUDFImpl for Sum { |
121 | 0 | fn as_any(&self) -> &dyn Any { |
122 | 0 | self |
123 | 0 | } |
124 | | |
125 | 0 | fn name(&self) -> &str { |
126 | 0 | "sum" |
127 | 0 | } |
128 | | |
129 | 0 | fn signature(&self) -> &Signature { |
130 | 0 | &self.signature |
131 | 0 | } |
132 | | |
133 | 0 | fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
134 | 0 | if arg_types.len() != 1 { |
135 | 0 | return exec_err!("SUM expects exactly one argument"); |
136 | 0 | } |
137 | | |
138 | | // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc |
139 | | // smallint, int, bigint, real, double precision, decimal, or interval. |
140 | | |
141 | 0 | fn coerced_type(data_type: &DataType) -> Result<DataType> { |
142 | 0 | match data_type { |
143 | 0 | DataType::Dictionary(_, v) => coerced_type(v), |
144 | | // in the spark, the result type is DECIMAL(min(38,precision+10), s) |
145 | | // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 |
146 | | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { |
147 | 0 | Ok(data_type.clone()) |
148 | | } |
149 | 0 | dt if dt.is_signed_integer() => Ok(DataType::Int64), |
150 | 0 | dt if dt.is_unsigned_integer() => Ok(DataType::UInt64), |
151 | 0 | dt if dt.is_floating() => Ok(DataType::Float64), |
152 | 0 | _ => exec_err!("Sum not supported for {}", data_type), |
153 | | } |
154 | 0 | } |
155 | | |
156 | 0 | Ok(vec![coerced_type(&arg_types[0])?]) |
157 | 0 | } |
158 | | |
159 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
160 | 0 | match &arg_types[0] { |
161 | 0 | DataType::Int64 => Ok(DataType::Int64), |
162 | 0 | DataType::UInt64 => Ok(DataType::UInt64), |
163 | 0 | DataType::Float64 => Ok(DataType::Float64), |
164 | 0 | DataType::Decimal128(precision, scale) => { |
165 | 0 | // in the spark, the result type is DECIMAL(min(38,precision+10), s) |
166 | 0 | // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 |
167 | 0 | let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); |
168 | 0 | Ok(DataType::Decimal128(new_precision, *scale)) |
169 | | } |
170 | 0 | DataType::Decimal256(precision, scale) => { |
171 | 0 | // in the spark, the result type is DECIMAL(min(38,precision+10), s) |
172 | 0 | // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 |
173 | 0 | let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); |
174 | 0 | Ok(DataType::Decimal256(new_precision, *scale)) |
175 | | } |
176 | 0 | other => { |
177 | 0 | exec_err!("[return_type] SUM not supported for {}", other) |
178 | | } |
179 | | } |
180 | 0 | } |
181 | | |
182 | 0 | fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
183 | 0 | unreachable!("stub should not have accumulate()") |
184 | | } |
185 | | |
186 | 0 | fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> { |
187 | 0 | unreachable!("stub should not have state_fields()") |
188 | | } |
189 | | |
190 | 0 | fn aliases(&self) -> &[String] { |
191 | 0 | &[] |
192 | 0 | } |
193 | | |
194 | 0 | fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { |
195 | 0 | false |
196 | 0 | } |
197 | | |
198 | 0 | fn create_groups_accumulator( |
199 | 0 | &self, |
200 | 0 | _args: AccumulatorArgs, |
201 | 0 | ) -> Result<Box<dyn GroupsAccumulator>> { |
202 | 0 | unreachable!("stub should not have accumulate()") |
203 | | } |
204 | | |
205 | 0 | fn reverse_expr(&self) -> ReversedUDAF { |
206 | 0 | ReversedUDAF::Identical |
207 | 0 | } |
208 | | |
209 | 0 | fn order_sensitivity(&self) -> AggregateOrderSensitivity { |
210 | 0 | AggregateOrderSensitivity::Insensitive |
211 | 0 | } |
212 | | } |
213 | | |
214 | | /// Testing stub implementation of COUNT aggregate |
215 | | pub struct Count { |
216 | | signature: Signature, |
217 | | aliases: Vec<String>, |
218 | | } |
219 | | |
220 | | impl std::fmt::Debug for Count { |
221 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
222 | 0 | f.debug_struct("Count") |
223 | 0 | .field("name", &self.name()) |
224 | 0 | .field("signature", &self.signature) |
225 | 0 | .finish() |
226 | 0 | } |
227 | | } |
228 | | |
229 | | impl Default for Count { |
230 | 0 | fn default() -> Self { |
231 | 0 | Self::new() |
232 | 0 | } |
233 | | } |
234 | | |
235 | | impl Count { |
236 | 0 | pub fn new() -> Self { |
237 | 0 | Self { |
238 | 0 | aliases: vec!["count".to_string()], |
239 | 0 | signature: Signature::variadic_any(Volatility::Immutable), |
240 | 0 | } |
241 | 0 | } |
242 | | } |
243 | | |
244 | | impl AggregateUDFImpl for Count { |
245 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
246 | 0 | self |
247 | 0 | } |
248 | | |
249 | 0 | fn name(&self) -> &str { |
250 | 0 | "COUNT" |
251 | 0 | } |
252 | | |
253 | 0 | fn signature(&self) -> &Signature { |
254 | 0 | &self.signature |
255 | 0 | } |
256 | | |
257 | 0 | fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
258 | 0 | Ok(DataType::Int64) |
259 | 0 | } |
260 | | |
261 | 0 | fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> { |
262 | 0 | not_impl_err!("no impl for stub") |
263 | 0 | } |
264 | | |
265 | 0 | fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
266 | 0 | not_impl_err!("no impl for stub") |
267 | 0 | } |
268 | | |
269 | 0 | fn aliases(&self) -> &[String] { |
270 | 0 | &self.aliases |
271 | 0 | } |
272 | | |
273 | 0 | fn create_groups_accumulator( |
274 | 0 | &self, |
275 | 0 | _args: AccumulatorArgs, |
276 | 0 | ) -> Result<Box<dyn GroupsAccumulator>> { |
277 | 0 | not_impl_err!("no impl for stub") |
278 | 0 | } |
279 | | |
280 | 0 | fn reverse_expr(&self) -> ReversedUDAF { |
281 | 0 | ReversedUDAF::Identical |
282 | 0 | } |
283 | | } |
284 | | |
285 | | create_func!(Min, min_udaf); |
286 | | |
287 | 0 | pub fn min(expr: Expr) -> Expr { |
288 | 0 | Expr::AggregateFunction(AggregateFunction::new_udf( |
289 | 0 | min_udaf(), |
290 | 0 | vec![expr], |
291 | 0 | false, |
292 | 0 | None, |
293 | 0 | None, |
294 | 0 | None, |
295 | 0 | )) |
296 | 0 | } |
297 | | |
298 | | /// Testing stub implementation of Min aggregate |
299 | | pub struct Min { |
300 | | signature: Signature, |
301 | | } |
302 | | |
303 | | impl std::fmt::Debug for Min { |
304 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
305 | 0 | f.debug_struct("Min") |
306 | 0 | .field("name", &self.name()) |
307 | 0 | .field("signature", &self.signature) |
308 | 0 | .finish() |
309 | 0 | } |
310 | | } |
311 | | |
312 | | impl Default for Min { |
313 | 0 | fn default() -> Self { |
314 | 0 | Self::new() |
315 | 0 | } |
316 | | } |
317 | | |
318 | | impl Min { |
319 | 0 | pub fn new() -> Self { |
320 | 0 | Self { |
321 | 0 | signature: Signature::variadic_any(Volatility::Immutable), |
322 | 0 | } |
323 | 0 | } |
324 | | } |
325 | | |
326 | | impl AggregateUDFImpl for Min { |
327 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
328 | 0 | self |
329 | 0 | } |
330 | | |
331 | 0 | fn name(&self) -> &str { |
332 | 0 | "min" |
333 | 0 | } |
334 | | |
335 | 0 | fn signature(&self) -> &Signature { |
336 | 0 | &self.signature |
337 | 0 | } |
338 | | |
339 | 0 | fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
340 | 0 | Ok(DataType::Int64) |
341 | 0 | } |
342 | | |
343 | 0 | fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> { |
344 | 0 | not_impl_err!("no impl for stub") |
345 | 0 | } |
346 | | |
347 | 0 | fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
348 | 0 | not_impl_err!("no impl for stub") |
349 | 0 | } |
350 | | |
351 | 0 | fn aliases(&self) -> &[String] { |
352 | 0 | &[] |
353 | 0 | } |
354 | | |
355 | 0 | fn create_groups_accumulator( |
356 | 0 | &self, |
357 | 0 | _args: AccumulatorArgs, |
358 | 0 | ) -> Result<Box<dyn GroupsAccumulator>> { |
359 | 0 | not_impl_err!("no impl for stub") |
360 | 0 | } |
361 | | |
362 | 0 | fn reverse_expr(&self) -> ReversedUDAF { |
363 | 0 | ReversedUDAF::Identical |
364 | 0 | } |
365 | 0 | fn is_descending(&self) -> Option<bool> { |
366 | 0 | Some(false) |
367 | 0 | } |
368 | | } |
369 | | |
370 | | create_func!(Max, max_udaf); |
371 | | |
372 | 0 | pub fn max(expr: Expr) -> Expr { |
373 | 0 | Expr::AggregateFunction(AggregateFunction::new_udf( |
374 | 0 | max_udaf(), |
375 | 0 | vec![expr], |
376 | 0 | false, |
377 | 0 | None, |
378 | 0 | None, |
379 | 0 | None, |
380 | 0 | )) |
381 | 0 | } |
382 | | |
383 | | /// Testing stub implementation of MAX aggregate |
384 | | pub struct Max { |
385 | | signature: Signature, |
386 | | } |
387 | | |
388 | | impl std::fmt::Debug for Max { |
389 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
390 | 0 | f.debug_struct("Max") |
391 | 0 | .field("name", &self.name()) |
392 | 0 | .field("signature", &self.signature) |
393 | 0 | .finish() |
394 | 0 | } |
395 | | } |
396 | | |
397 | | impl Default for Max { |
398 | 0 | fn default() -> Self { |
399 | 0 | Self::new() |
400 | 0 | } |
401 | | } |
402 | | |
403 | | impl Max { |
404 | 0 | pub fn new() -> Self { |
405 | 0 | Self { |
406 | 0 | signature: Signature::variadic_any(Volatility::Immutable), |
407 | 0 | } |
408 | 0 | } |
409 | | } |
410 | | |
411 | | impl AggregateUDFImpl for Max { |
412 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
413 | 0 | self |
414 | 0 | } |
415 | | |
416 | 0 | fn name(&self) -> &str { |
417 | 0 | "max" |
418 | 0 | } |
419 | | |
420 | 0 | fn signature(&self) -> &Signature { |
421 | 0 | &self.signature |
422 | 0 | } |
423 | | |
424 | 0 | fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
425 | 0 | Ok(DataType::Int64) |
426 | 0 | } |
427 | | |
428 | 0 | fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> { |
429 | 0 | not_impl_err!("no impl for stub") |
430 | 0 | } |
431 | | |
432 | 0 | fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
433 | 0 | not_impl_err!("no impl for stub") |
434 | 0 | } |
435 | | |
436 | 0 | fn aliases(&self) -> &[String] { |
437 | 0 | &[] |
438 | 0 | } |
439 | | |
440 | 0 | fn create_groups_accumulator( |
441 | 0 | &self, |
442 | 0 | _args: AccumulatorArgs, |
443 | 0 | ) -> Result<Box<dyn GroupsAccumulator>> { |
444 | 0 | not_impl_err!("no impl for stub") |
445 | 0 | } |
446 | | |
447 | 0 | fn reverse_expr(&self) -> ReversedUDAF { |
448 | 0 | ReversedUDAF::Identical |
449 | 0 | } |
450 | 0 | fn is_descending(&self) -> Option<bool> { |
451 | 0 | Some(true) |
452 | 0 | } |
453 | | } |
454 | | |
455 | | /// Testing stub implementation of avg aggregate |
456 | | #[derive(Debug)] |
457 | | pub struct Avg { |
458 | | signature: Signature, |
459 | | aliases: Vec<String>, |
460 | | } |
461 | | |
462 | | impl Avg { |
463 | 0 | pub fn new() -> Self { |
464 | 0 | Self { |
465 | 0 | aliases: vec![String::from("mean")], |
466 | 0 | signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable), |
467 | 0 | } |
468 | 0 | } |
469 | | } |
470 | | |
471 | | impl Default for Avg { |
472 | 0 | fn default() -> Self { |
473 | 0 | Self::new() |
474 | 0 | } |
475 | | } |
476 | | |
477 | | impl AggregateUDFImpl for Avg { |
478 | 0 | fn as_any(&self) -> &dyn Any { |
479 | 0 | self |
480 | 0 | } |
481 | | |
482 | 0 | fn name(&self) -> &str { |
483 | 0 | "avg" |
484 | 0 | } |
485 | | |
486 | 0 | fn signature(&self) -> &Signature { |
487 | 0 | &self.signature |
488 | 0 | } |
489 | | |
490 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
491 | 0 | avg_return_type(self.name(), &arg_types[0]) |
492 | 0 | } |
493 | | |
494 | 0 | fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
495 | 0 | not_impl_err!("no impl for stub") |
496 | 0 | } |
497 | | |
498 | 0 | fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> { |
499 | 0 | not_impl_err!("no impl for stub") |
500 | 0 | } |
501 | 0 | fn aliases(&self) -> &[String] { |
502 | 0 | &self.aliases |
503 | 0 | } |
504 | | |
505 | 0 | fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
506 | 0 | coerce_avg_type(self.name(), arg_types) |
507 | 0 | } |
508 | | } |