/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/min_max.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance |
6 | | // with the License. You may obtain a copy of the License at |
7 | | // |
8 | | // http://www.apache.org/licenses/LICENSE-2.0 |
9 | | // |
10 | | // Unless required by applicable law or agreed to in writing, |
11 | | // software distributed under the License is distributed on an |
12 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
13 | | // KIND, either express or implied. See the License for the |
14 | | // specific language governing permissions and limitations |
15 | | // under the License. |
16 | | |
17 | | //! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function |
18 | | //! [`Min`] and [`MinAccumulator`] accumulator for the `min` function |
19 | | |
20 | | // distributed with this work for additional information |
21 | | // regarding copyright ownership. The ASF licenses this file |
22 | | // to you under the Apache License, Version 2.0 (the |
23 | | // "License"); you may not use this file except in compliance |
24 | | // with the License. You may obtain a copy of the License at |
25 | | // |
26 | | // http://www.apache.org/licenses/LICENSE-2.0 |
27 | | // |
28 | | // Unless required by applicable law or agreed to in writing, |
29 | | // software distributed under the License is distributed on an |
30 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
31 | | // KIND, either express or implied. See the License for the |
32 | | // specific language governing permissions and limitations |
33 | | // under the License. |
34 | | |
35 | | use arrow::array::{ |
36 | | ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, |
37 | | Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array, |
38 | | Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, |
39 | | IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, |
40 | | LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray, |
41 | | Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, |
42 | | TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, |
43 | | TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, |
44 | | }; |
45 | | use arrow::compute; |
46 | | use arrow::datatypes::{ |
47 | | DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type, |
48 | | Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, |
49 | | UInt8Type, |
50 | | }; |
51 | | use arrow_schema::IntervalUnit; |
52 | | use datafusion_common::stats::Precision; |
53 | | use datafusion_common::{ |
54 | | downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result, |
55 | | }; |
56 | | use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; |
57 | | use datafusion_physical_expr::expressions; |
58 | | use std::fmt::Debug; |
59 | | |
60 | | use arrow::datatypes::i256; |
61 | | use arrow::datatypes::{ |
62 | | Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, |
63 | | Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, |
64 | | TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, |
65 | | }; |
66 | | |
67 | | use datafusion_common::ScalarValue; |
68 | | use datafusion_expr::{ |
69 | | function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, |
70 | | }; |
71 | | use datafusion_expr::{GroupsAccumulator, StatisticsArgs}; |
72 | | use half::f16; |
73 | | use std::ops::Deref; |
74 | | |
75 | 0 | fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> { |
76 | 0 | // make sure that the input types contain exactly one element. |
77 | 0 | if input_types.len() != 1 { |
78 | 0 | return exec_err!( |
79 | 0 | "min/max was called with {} arguments. It requires only 1.", |
80 | 0 | input_types.len() |
81 | 0 | ); |
82 | 0 | } |
83 | 0 | // min and max support the dictionary data type: |
84 | 0 | // unpack the dictionary and use its value type as the result type |
85 | 0 | match &input_types[0] { |
86 | 0 | DataType::Dictionary(_, dict_value_type) => { |
87 | 0 | // TODO add a check for the case where the value type is a complex data type |
88 | 0 | Ok(vec![dict_value_type.deref().clone()]) |
89 | | } |
90 | | // TODO add a check for the data types that min and max support. |
91 | | // For example, the `Struct` and `Map` types are not supported by the MIN and MAX functions |
92 | 0 | _ => Ok(input_types.to_vec()), |
93 | | } |
94 | 0 | } |
95 | | |
96 | | // MAX aggregate UDF |
97 | | #[derive(Debug)] |
98 | | pub struct Max { |
99 | | signature: Signature, |
100 | | } |
101 | | |
102 | | impl Max { |
103 | 0 | pub fn new() -> Self { |
104 | 0 | Self { |
105 | 0 | signature: Signature::user_defined(Volatility::Immutable), |
106 | 0 | } |
107 | 0 | } |
108 | | } |
109 | | |
110 | | impl Default for Max { |
111 | 0 | fn default() -> Self { |
112 | 0 | Self::new() |
113 | 0 | } |
114 | | } |
115 | | /// Creates a [`PrimitiveGroupsAccumulator`] for computing `MAX` over |
116 | | /// the specified [`ArrowPrimitiveType`]. |
117 | | /// |
118 | | /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType |
119 | | macro_rules! instantiate_max_accumulator { |
120 | | ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ |
121 | | Ok(Box::new( |
122 | 0 | PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| { |
123 | 0 | if *cur < new { |
124 | 0 | *cur = new |
125 | 0 | } |
126 | 0 | }) |
127 | | // Initialize each accumulator to $NATIVE::MIN |
128 | | .with_starting_value($NATIVE::MIN), |
129 | | )) |
130 | | }}; |
131 | | } |
132 | | |
133 | | /// Creates a [`PrimitiveGroupsAccumulator`] for computing `MIN` over |
134 | | /// the specified [`ArrowPrimitiveType`]. |
135 | | /// |
136 | | /// |
137 | | /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType |
138 | | macro_rules! instantiate_min_accumulator { |
139 | | ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ |
140 | | Ok(Box::new( |
141 | 0 | PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| { |
142 | 0 | if *cur > new { |
143 | 0 | *cur = new |
144 | 0 | } |
145 | 0 | }) |
146 | | // Initialize each accumulator to $NATIVE::MAX |
147 | | .with_starting_value($NATIVE::MAX), |
148 | | )) |
149 | | }}; |
150 | | } |
151 | | |
152 | | trait FromColumnStatistics { |
153 | | fn value_from_column_statistics( |
154 | | &self, |
155 | | stats: &ColumnStatistics, |
156 | | ) -> Option<ScalarValue>; |
157 | | |
158 | 0 | fn value_from_statistics( |
159 | 0 | &self, |
160 | 0 | statistics_args: &StatisticsArgs, |
161 | 0 | ) -> Option<ScalarValue> { |
162 | 0 | if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows { |
163 | 0 | match *num_rows { |
164 | 0 | 0 => return ScalarValue::try_from(statistics_args.return_type).ok(), |
165 | 0 | value if value > 0 => { |
166 | 0 | let col_stats = &statistics_args.statistics.column_statistics; |
167 | 0 | if statistics_args.exprs.len() == 1 { |
168 | | // TODO optimize with exprs other than Column |
169 | 0 | if let Some(col_expr) = statistics_args.exprs[0] |
170 | 0 | .as_any() |
171 | 0 | .downcast_ref::<expressions::Column>() |
172 | | { |
173 | 0 | return self.value_from_column_statistics( |
174 | 0 | &col_stats[col_expr.index()], |
175 | 0 | ); |
176 | 0 | } |
177 | 0 | } |
178 | | } |
179 | 0 | _ => {} |
180 | | } |
181 | 0 | } |
182 | 0 | None |
183 | 0 | } |
184 | | } |
185 | | |
186 | | impl FromColumnStatistics for Max { |
187 | 0 | fn value_from_column_statistics( |
188 | 0 | &self, |
189 | 0 | col_stats: &ColumnStatistics, |
190 | 0 | ) -> Option<ScalarValue> { |
191 | 0 | if let Precision::Exact(ref val) = col_stats.max_value { |
192 | 0 | if !val.is_null() { |
193 | 0 | return Some(val.clone()); |
194 | 0 | } |
195 | 0 | } |
196 | 0 | None |
197 | 0 | } |
198 | | } |
199 | | |
200 | | impl AggregateUDFImpl for Max { |
201 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
202 | 0 | self |
203 | 0 | } |
204 | | |
205 | 0 | fn name(&self) -> &str { |
206 | 0 | "max" |
207 | 0 | } |
208 | | |
209 | 0 | fn signature(&self) -> &Signature { |
210 | 0 | &self.signature |
211 | 0 | } |
212 | | |
213 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
214 | 0 | Ok(arg_types[0].to_owned()) |
215 | 0 | } |
216 | | |
217 | 0 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
218 | 0 | Ok(Box::new(MaxAccumulator::try_new(acc_args.return_type)?)) |
219 | 0 | } |
220 | | |
221 | 0 | fn aliases(&self) -> &[String] { |
222 | 0 | &[] |
223 | 0 | } |
224 | | |
225 | 0 | fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { |
226 | | use DataType::*; |
227 | 0 | matches!( |
228 | 0 | args.return_type, |
229 | | Int8 | Int16 |
230 | | | Int32 |
231 | | | Int64 |
232 | | | UInt8 |
233 | | | UInt16 |
234 | | | UInt32 |
235 | | | UInt64 |
236 | | | Float16 |
237 | | | Float32 |
238 | | | Float64 |
239 | | | Decimal128(_, _) |
240 | | | Decimal256(_, _) |
241 | | | Date32 |
242 | | | Date64 |
243 | | | Time32(_) |
244 | | | Time64(_) |
245 | | | Timestamp(_, _) |
246 | | ) |
247 | 0 | } |
248 | | |
249 | 0 | fn create_groups_accumulator( |
250 | 0 | &self, |
251 | 0 | args: AccumulatorArgs, |
252 | 0 | ) -> Result<Box<dyn GroupsAccumulator>> { |
253 | | use DataType::*; |
254 | | use TimeUnit::*; |
255 | 0 | let data_type = args.return_type; |
256 | 0 | match data_type { |
257 | 0 | Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type), |
258 | 0 | Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type), |
259 | 0 | Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type), |
260 | 0 | Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type), |
261 | 0 | UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type), |
262 | 0 | UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type), |
263 | 0 | UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type), |
264 | 0 | UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type), |
265 | | Float16 => { |
266 | 0 | instantiate_max_accumulator!(data_type, f16, Float16Type) |
267 | | } |
268 | | Float32 => { |
269 | 0 | instantiate_max_accumulator!(data_type, f32, Float32Type) |
270 | | } |
271 | | Float64 => { |
272 | 0 | instantiate_max_accumulator!(data_type, f64, Float64Type) |
273 | | } |
274 | 0 | Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type), |
275 | 0 | Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type), |
276 | | Time32(Second) => { |
277 | 0 | instantiate_max_accumulator!(data_type, i32, Time32SecondType) |
278 | | } |
279 | | Time32(Millisecond) => { |
280 | 0 | instantiate_max_accumulator!(data_type, i32, Time32MillisecondType) |
281 | | } |
282 | | Time64(Microsecond) => { |
283 | 0 | instantiate_max_accumulator!(data_type, i64, Time64MicrosecondType) |
284 | | } |
285 | | Time64(Nanosecond) => { |
286 | 0 | instantiate_max_accumulator!(data_type, i64, Time64NanosecondType) |
287 | | } |
288 | | Timestamp(Second, _) => { |
289 | 0 | instantiate_max_accumulator!(data_type, i64, TimestampSecondType) |
290 | | } |
291 | | Timestamp(Millisecond, _) => { |
292 | 0 | instantiate_max_accumulator!(data_type, i64, TimestampMillisecondType) |
293 | | } |
294 | | Timestamp(Microsecond, _) => { |
295 | 0 | instantiate_max_accumulator!(data_type, i64, TimestampMicrosecondType) |
296 | | } |
297 | | Timestamp(Nanosecond, _) => { |
298 | 0 | instantiate_max_accumulator!(data_type, i64, TimestampNanosecondType) |
299 | | } |
300 | | Decimal128(_, _) => { |
301 | 0 | instantiate_max_accumulator!(data_type, i128, Decimal128Type) |
302 | | } |
303 | | Decimal256(_, _) => { |
304 | 0 | instantiate_max_accumulator!(data_type, i256, Decimal256Type) |
305 | | } |
306 | | |
307 | | // It would be nice to have a fast implementation for Strings as well |
308 | | // https://github.com/apache/datafusion/issues/6906 |
309 | | |
310 | | // This is only reached if groups_accumulator_supported is out of sync |
311 | 0 | _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), |
312 | | } |
313 | 0 | } |
314 | | |
315 | 0 | fn create_sliding_accumulator( |
316 | 0 | &self, |
317 | 0 | args: AccumulatorArgs, |
318 | 0 | ) -> Result<Box<dyn Accumulator>> { |
319 | 0 | Ok(Box::new(SlidingMaxAccumulator::try_new(args.return_type)?)) |
320 | 0 | } |
321 | | |
322 | 0 | fn is_descending(&self) -> Option<bool> { |
323 | 0 | Some(true) |
324 | 0 | } |
325 | | |
326 | 0 | fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { |
327 | 0 | datafusion_expr::utils::AggregateOrderSensitivity::Insensitive |
328 | 0 | } |
329 | | |
330 | 0 | fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
331 | 0 | get_min_max_result_type(arg_types) |
332 | 0 | } |
333 | 0 | fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { |
334 | 0 | datafusion_expr::ReversedUDAF::Identical |
335 | 0 | } |
336 | 0 | fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> { |
337 | 0 | self.value_from_statistics(statistics_args) |
338 | 0 | } |
339 | | } |
340 | | |
341 | | // Statically-typed version of min/max(array) -> ScalarValue for string types |
342 | | macro_rules! typed_min_max_batch_string { |
343 | | ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ |
344 | | let array = downcast_value!($VALUES, $ARRAYTYPE); |
345 | | let value = compute::$OP(array); |
346 | 0 | let value = value.and_then(|e| Some(e.to_string())); |
347 | | ScalarValue::$SCALAR(value) |
348 | | }}; |
349 | | } |
350 | | // Statically-typed version of min/max(array) -> ScalarValue for binary types. |
351 | | macro_rules! typed_min_max_batch_binary { |
352 | | ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ |
353 | | let array = downcast_value!($VALUES, $ARRAYTYPE); |
354 | | let value = compute::$OP(array); |
355 | 0 | let value = value.and_then(|e| Some(e.to_vec())); |
356 | | ScalarValue::$SCALAR(value) |
357 | | }}; |
358 | | } |
359 | | |
360 | | // Statically-typed version of min/max(array) -> ScalarValue for non-string types. |
361 | | macro_rules! typed_min_max_batch { |
362 | | ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ |
363 | | let array = downcast_value!($VALUES, $ARRAYTYPE); |
364 | | let value = compute::$OP(array); |
365 | | ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*) |
366 | | }}; |
367 | | } |
368 | | |
369 | | // Dispatches on the array's runtime data type to the statically-typed min/max(array) -> ScalarValue helper for non-string types. |
370 | | // This is a macro so that it can support both operations (min and max). |
371 | | macro_rules! min_max_batch { |
372 | | ($VALUES:expr, $OP:ident) => {{ |
373 | | match $VALUES.data_type() { |
374 | | DataType::Null => ScalarValue::Null, |
375 | | DataType::Decimal128(precision, scale) => { |
376 | | typed_min_max_batch!( |
377 | | $VALUES, |
378 | | Decimal128Array, |
379 | | Decimal128, |
380 | | $OP, |
381 | | precision, |
382 | | scale |
383 | | ) |
384 | | } |
385 | | DataType::Decimal256(precision, scale) => { |
386 | | typed_min_max_batch!( |
387 | | $VALUES, |
388 | | Decimal256Array, |
389 | | Decimal256, |
390 | | $OP, |
391 | | precision, |
392 | | scale |
393 | | ) |
394 | | } |
395 | | // all types that have a natural order |
396 | | DataType::Float64 => { |
397 | | typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) |
398 | | } |
399 | | DataType::Float32 => { |
400 | | typed_min_max_batch!($VALUES, Float32Array, Float32, $OP) |
401 | | } |
402 | | DataType::Float16 => { |
403 | | typed_min_max_batch!($VALUES, Float16Array, Float16, $OP) |
404 | | } |
405 | | DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP), |
406 | | DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP), |
407 | | DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP), |
408 | | DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP), |
409 | | DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP), |
410 | | DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP), |
411 | | DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP), |
412 | | DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP), |
413 | | DataType::Timestamp(TimeUnit::Second, tz_opt) => { |
414 | | typed_min_max_batch!( |
415 | | $VALUES, |
416 | | TimestampSecondArray, |
417 | | TimestampSecond, |
418 | | $OP, |
419 | | tz_opt |
420 | | ) |
421 | | } |
422 | | DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!( |
423 | | $VALUES, |
424 | | TimestampMillisecondArray, |
425 | | TimestampMillisecond, |
426 | | $OP, |
427 | | tz_opt |
428 | | ), |
429 | | DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!( |
430 | | $VALUES, |
431 | | TimestampMicrosecondArray, |
432 | | TimestampMicrosecond, |
433 | | $OP, |
434 | | tz_opt |
435 | | ), |
436 | | DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!( |
437 | | $VALUES, |
438 | | TimestampNanosecondArray, |
439 | | TimestampNanosecond, |
440 | | $OP, |
441 | | tz_opt |
442 | | ), |
443 | | DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP), |
444 | | DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP), |
445 | | DataType::Time32(TimeUnit::Second) => { |
446 | | typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP) |
447 | | } |
448 | | DataType::Time32(TimeUnit::Millisecond) => { |
449 | | typed_min_max_batch!( |
450 | | $VALUES, |
451 | | Time32MillisecondArray, |
452 | | Time32Millisecond, |
453 | | $OP |
454 | | ) |
455 | | } |
456 | | DataType::Time64(TimeUnit::Microsecond) => { |
457 | | typed_min_max_batch!( |
458 | | $VALUES, |
459 | | Time64MicrosecondArray, |
460 | | Time64Microsecond, |
461 | | $OP |
462 | | ) |
463 | | } |
464 | | DataType::Time64(TimeUnit::Nanosecond) => { |
465 | | typed_min_max_batch!( |
466 | | $VALUES, |
467 | | Time64NanosecondArray, |
468 | | Time64Nanosecond, |
469 | | $OP |
470 | | ) |
471 | | } |
472 | | DataType::Interval(IntervalUnit::YearMonth) => { |
473 | | typed_min_max_batch!( |
474 | | $VALUES, |
475 | | IntervalYearMonthArray, |
476 | | IntervalYearMonth, |
477 | | $OP |
478 | | ) |
479 | | } |
480 | | DataType::Interval(IntervalUnit::DayTime) => { |
481 | | typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP) |
482 | | } |
483 | | DataType::Interval(IntervalUnit::MonthDayNano) => { |
484 | | typed_min_max_batch!( |
485 | | $VALUES, |
486 | | IntervalMonthDayNanoArray, |
487 | | IntervalMonthDayNano, |
488 | | $OP |
489 | | ) |
490 | | } |
491 | | other => { |
492 | | // This should have been handled before |
493 | | return internal_err!( |
494 | | "Min/Max accumulator not implemented for type {:?}", |
495 | | other |
496 | | ); |
497 | | } |
498 | | } |
499 | | }}; |
500 | | } |
501 | | |
502 | | /// dynamically-typed min(array) -> ScalarValue |
503 | 0 | fn min_batch(values: &ArrayRef) -> Result<ScalarValue> { |
504 | 0 | Ok(match values.data_type() { |
505 | | DataType::Utf8 => { |
506 | 0 | typed_min_max_batch_string!(values, StringArray, Utf8, min_string) |
507 | | } |
508 | | DataType::LargeUtf8 => { |
509 | 0 | typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) |
510 | | } |
511 | | DataType::Utf8View => { |
512 | 0 | typed_min_max_batch_string!( |
513 | 0 | values, |
514 | 0 | StringViewArray, |
515 | 0 | Utf8View, |
516 | 0 | min_string_view |
517 | 0 | ) |
518 | | } |
519 | | DataType::Boolean => { |
520 | 0 | typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) |
521 | | } |
522 | | DataType::Binary => { |
523 | 0 | typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) |
524 | | } |
525 | | DataType::LargeBinary => { |
526 | 0 | typed_min_max_batch_binary!( |
527 | 0 | &values, |
528 | 0 | LargeBinaryArray, |
529 | 0 | LargeBinary, |
530 | 0 | min_binary |
531 | 0 | ) |
532 | | } |
533 | | DataType::BinaryView => { |
534 | 0 | typed_min_max_batch_binary!( |
535 | 0 | &values, |
536 | 0 | BinaryViewArray, |
537 | 0 | BinaryView, |
538 | 0 | min_binary_view |
539 | 0 | ) |
540 | | } |
541 | 0 | _ => min_max_batch!(values, min), |
542 | | }) |
543 | 0 | } |
544 | | |
545 | | /// dynamically-typed max(array) -> ScalarValue |
546 | 0 | fn max_batch(values: &ArrayRef) -> Result<ScalarValue> { |
547 | 0 | Ok(match values.data_type() { |
548 | | DataType::Utf8 => { |
549 | 0 | typed_min_max_batch_string!(values, StringArray, Utf8, max_string) |
550 | | } |
551 | | DataType::LargeUtf8 => { |
552 | 0 | typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) |
553 | | } |
554 | | DataType::Utf8View => { |
555 | 0 | typed_min_max_batch_string!( |
556 | 0 | values, |
557 | 0 | StringViewArray, |
558 | 0 | Utf8View, |
559 | 0 | max_string_view |
560 | 0 | ) |
561 | | } |
562 | | DataType::Boolean => { |
563 | 0 | typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) |
564 | | } |
565 | | DataType::Binary => { |
566 | 0 | typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) |
567 | | } |
568 | | DataType::BinaryView => { |
569 | 0 | typed_min_max_batch_binary!( |
570 | 0 | &values, |
571 | 0 | BinaryViewArray, |
572 | 0 | BinaryView, |
573 | 0 | max_binary_view |
574 | 0 | ) |
575 | | } |
576 | | DataType::LargeBinary => { |
577 | 0 | typed_min_max_batch_binary!( |
578 | 0 | &values, |
579 | 0 | LargeBinaryArray, |
580 | 0 | LargeBinary, |
581 | 0 | max_binary |
582 | 0 | ) |
583 | | } |
584 | 0 | _ => min_max_batch!(values, max), |
585 | | }) |
586 | 0 | } |
587 | | |
588 | | // min/max of two non-string scalar values. |
589 | | macro_rules! typed_min_max { |
590 | | ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ |
591 | | ScalarValue::$SCALAR( |
592 | | match ($VALUE, $DELTA) { |
593 | | (None, None) => None, |
594 | | (Some(a), None) => Some(*a), |
595 | | (None, Some(b)) => Some(*b), |
596 | | (Some(a), Some(b)) => Some((*a).$OP(*b)), |
597 | | }, |
598 | | $($EXTRA_ARGS.clone()),* |
599 | | ) |
600 | | }}; |
601 | | } |
602 | | macro_rules! typed_min_max_float { |
603 | | ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ |
604 | | ScalarValue::$SCALAR(match ($VALUE, $DELTA) { |
605 | | (None, None) => None, |
606 | | (Some(a), None) => Some(*a), |
607 | | (None, Some(b)) => Some(*b), |
608 | | (Some(a), Some(b)) => match a.total_cmp(b) { |
609 | | choose_min_max!($OP) => Some(*b), |
610 | | _ => Some(*a), |
611 | | }, |
612 | | }) |
613 | | }}; |
614 | | } |
615 | | |
616 | | // min/max of two scalar string values. |
617 | | macro_rules! typed_min_max_string { |
618 | | ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ |
619 | | ScalarValue::$SCALAR(match ($VALUE, $DELTA) { |
620 | | (None, None) => None, |
621 | | (Some(a), None) => Some(a.clone()), |
622 | | (None, Some(b)) => Some(b.clone()), |
623 | | (Some(a), Some(b)) => Some((a).$OP(b).clone()), |
624 | | }) |
625 | | }}; |
626 | | } |
627 | | |
628 | | macro_rules! choose_min_max { |
629 | | (min) => { |
630 | | std::cmp::Ordering::Greater |
631 | | }; |
632 | | (max) => { |
633 | | std::cmp::Ordering::Less |
634 | | }; |
635 | | } |
636 | | |
637 | | macro_rules! interval_min_max { |
638 | | ($OP:tt, $LHS:expr, $RHS:expr) => {{ |
639 | | match $LHS.partial_cmp(&$RHS) { |
640 | | Some(choose_min_max!($OP)) => $RHS.clone(), |
641 | | Some(_) => $LHS.clone(), |
642 | | None => { |
643 | | return internal_err!("Comparison error while computing interval min/max") |
644 | | } |
645 | | } |
646 | | }}; |
647 | | } |
648 | | |
649 | | // min/max of two scalar values of the same type |
650 | | macro_rules! min_max { |
651 | | ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ |
652 | | Ok(match ($VALUE, $DELTA) { |
653 | | (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null, |
654 | | ( |
655 | | lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), |
656 | | rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) |
657 | | ) => { |
658 | | if lhsp.eq(rhsp) && lhss.eq(rhss) { |
659 | | typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) |
660 | | } else { |
661 | | return internal_err!( |
662 | | "MIN/MAX is not expected to receive scalars of incompatible types {:?}", |
663 | | (lhs, rhs) |
664 | | ); |
665 | | } |
666 | | } |
667 | | ( |
668 | | lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), |
669 | | rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) |
670 | | ) => { |
671 | | if lhsp.eq(rhsp) && lhss.eq(rhss) { |
672 | | typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) |
673 | | } else { |
674 | | return internal_err!( |
675 | | "MIN/MAX is not expected to receive scalars of incompatible types {:?}", |
676 | | (lhs, rhs) |
677 | | ); |
678 | | } |
679 | | } |
680 | | (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { |
681 | | typed_min_max!(lhs, rhs, Boolean, $OP) |
682 | | } |
683 | | (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { |
684 | | typed_min_max_float!(lhs, rhs, Float64, $OP) |
685 | | } |
686 | | (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { |
687 | | typed_min_max_float!(lhs, rhs, Float32, $OP) |
688 | | } |
689 | | (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { |
690 | | typed_min_max_float!(lhs, rhs, Float16, $OP) |
691 | | } |
692 | | (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { |
693 | | typed_min_max!(lhs, rhs, UInt64, $OP) |
694 | | } |
695 | | (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { |
696 | | typed_min_max!(lhs, rhs, UInt32, $OP) |
697 | | } |
698 | | (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { |
699 | | typed_min_max!(lhs, rhs, UInt16, $OP) |
700 | | } |
701 | | (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { |
702 | | typed_min_max!(lhs, rhs, UInt8, $OP) |
703 | | } |
704 | | (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { |
705 | | typed_min_max!(lhs, rhs, Int64, $OP) |
706 | | } |
707 | | (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { |
708 | | typed_min_max!(lhs, rhs, Int32, $OP) |
709 | | } |
710 | | (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { |
711 | | typed_min_max!(lhs, rhs, Int16, $OP) |
712 | | } |
713 | | (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { |
714 | | typed_min_max!(lhs, rhs, Int8, $OP) |
715 | | } |
716 | | (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { |
717 | | typed_min_max_string!(lhs, rhs, Utf8, $OP) |
718 | | } |
719 | | (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { |
720 | | typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) |
721 | | } |
722 | | (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { |
723 | | typed_min_max_string!(lhs, rhs, Utf8View, $OP) |
724 | | } |
725 | | (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { |
726 | | typed_min_max_string!(lhs, rhs, Binary, $OP) |
727 | | } |
728 | | (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { |
729 | | typed_min_max_string!(lhs, rhs, LargeBinary, $OP) |
730 | | } |
731 | | (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { |
732 | | typed_min_max_string!(lhs, rhs, BinaryView, $OP) |
733 | | } |
734 | | (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { |
735 | | typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) |
736 | | } |
737 | | ( |
738 | | ScalarValue::TimestampMillisecond(lhs, l_tz), |
739 | | ScalarValue::TimestampMillisecond(rhs, _), |
740 | | ) => { |
741 | | typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) |
742 | | } |
743 | | ( |
744 | | ScalarValue::TimestampMicrosecond(lhs, l_tz), |
745 | | ScalarValue::TimestampMicrosecond(rhs, _), |
746 | | ) => { |
747 | | typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) |
748 | | } |
749 | | ( |
750 | | ScalarValue::TimestampNanosecond(lhs, l_tz), |
751 | | ScalarValue::TimestampNanosecond(rhs, _), |
752 | | ) => { |
753 | | typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) |
754 | | } |
755 | | ( |
756 | | ScalarValue::Date32(lhs), |
757 | | ScalarValue::Date32(rhs), |
758 | | ) => { |
759 | | typed_min_max!(lhs, rhs, Date32, $OP) |
760 | | } |
761 | | ( |
762 | | ScalarValue::Date64(lhs), |
763 | | ScalarValue::Date64(rhs), |
764 | | ) => { |
765 | | typed_min_max!(lhs, rhs, Date64, $OP) |
766 | | } |
767 | | ( |
768 | | ScalarValue::Time32Second(lhs), |
769 | | ScalarValue::Time32Second(rhs), |
770 | | ) => { |
771 | | typed_min_max!(lhs, rhs, Time32Second, $OP) |
772 | | } |
773 | | ( |
774 | | ScalarValue::Time32Millisecond(lhs), |
775 | | ScalarValue::Time32Millisecond(rhs), |
776 | | ) => { |
777 | | typed_min_max!(lhs, rhs, Time32Millisecond, $OP) |
778 | | } |
779 | | ( |
780 | | ScalarValue::Time64Microsecond(lhs), |
781 | | ScalarValue::Time64Microsecond(rhs), |
782 | | ) => { |
783 | | typed_min_max!(lhs, rhs, Time64Microsecond, $OP) |
784 | | } |
785 | | ( |
786 | | ScalarValue::Time64Nanosecond(lhs), |
787 | | ScalarValue::Time64Nanosecond(rhs), |
788 | | ) => { |
789 | | typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) |
790 | | } |
791 | | ( |
792 | | ScalarValue::IntervalYearMonth(lhs), |
793 | | ScalarValue::IntervalYearMonth(rhs), |
794 | | ) => { |
795 | | typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) |
796 | | } |
797 | | ( |
798 | | ScalarValue::IntervalMonthDayNano(lhs), |
799 | | ScalarValue::IntervalMonthDayNano(rhs), |
800 | | ) => { |
801 | | typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) |
802 | | } |
803 | | ( |
804 | | ScalarValue::IntervalDayTime(lhs), |
805 | | ScalarValue::IntervalDayTime(rhs), |
806 | | ) => { |
807 | | typed_min_max!(lhs, rhs, IntervalDayTime, $OP) |
808 | | } |
809 | | ( |
810 | | ScalarValue::IntervalYearMonth(_), |
811 | | ScalarValue::IntervalMonthDayNano(_), |
812 | | ) | ( |
813 | | ScalarValue::IntervalYearMonth(_), |
814 | | ScalarValue::IntervalDayTime(_), |
815 | | ) | ( |
816 | | ScalarValue::IntervalMonthDayNano(_), |
817 | | ScalarValue::IntervalDayTime(_), |
818 | | ) | ( |
819 | | ScalarValue::IntervalMonthDayNano(_), |
820 | | ScalarValue::IntervalYearMonth(_), |
821 | | ) | ( |
822 | | ScalarValue::IntervalDayTime(_), |
823 | | ScalarValue::IntervalYearMonth(_), |
824 | | ) | ( |
825 | | ScalarValue::IntervalDayTime(_), |
826 | | ScalarValue::IntervalMonthDayNano(_), |
827 | | ) => { |
828 | | interval_min_max!($OP, $VALUE, $DELTA) |
829 | | } |
830 | | ( |
831 | | ScalarValue::DurationSecond(lhs), |
832 | | ScalarValue::DurationSecond(rhs), |
833 | | ) => { |
834 | | typed_min_max!(lhs, rhs, DurationSecond, $OP) |
835 | | } |
836 | | ( |
837 | | ScalarValue::DurationMillisecond(lhs), |
838 | | ScalarValue::DurationMillisecond(rhs), |
839 | | ) => { |
840 | | typed_min_max!(lhs, rhs, DurationMillisecond, $OP) |
841 | | } |
842 | | ( |
843 | | ScalarValue::DurationMicrosecond(lhs), |
844 | | ScalarValue::DurationMicrosecond(rhs), |
845 | | ) => { |
846 | | typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) |
847 | | } |
848 | | ( |
849 | | ScalarValue::DurationNanosecond(lhs), |
850 | | ScalarValue::DurationNanosecond(rhs), |
851 | | ) => { |
852 | | typed_min_max!(lhs, rhs, DurationNanosecond, $OP) |
853 | | } |
854 | | e => { |
855 | | return internal_err!( |
856 | | "MIN/MAX is not expected to receive scalars of incompatible types {:?}", |
857 | | e |
858 | | ) |
859 | | } |
860 | | }) |
861 | | }}; |
862 | | } |
863 | | |
864 | | /// An accumulator to compute the maximum value |
865 | | #[derive(Debug)] |
866 | | pub struct MaxAccumulator { |
867 | | max: ScalarValue, |
868 | | } |
869 | | |
870 | | impl MaxAccumulator { |
871 | | /// new max accumulator |
872 | 0 | pub fn try_new(datatype: &DataType) -> Result<Self> { |
873 | 0 | Ok(Self { |
874 | 0 | max: ScalarValue::try_from(datatype)?, |
875 | | }) |
876 | 0 | } |
877 | | } |
878 | | |
879 | | impl Accumulator for MaxAccumulator { |
880 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
881 | 0 | let values = &values[0]; |
882 | 0 | let delta = &max_batch(values)?; |
883 | 0 | let new_max: Result<ScalarValue, DataFusionError> = |
884 | 0 | min_max!(&self.max, delta, max); |
885 | 0 | self.max = new_max?; |
886 | 0 | Ok(()) |
887 | 0 | } |
888 | | |
889 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
890 | 0 | self.update_batch(states) |
891 | 0 | } |
892 | | |
893 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
894 | 0 | Ok(vec![self.evaluate()?]) |
895 | 0 | } |
896 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
897 | 0 | Ok(self.max.clone()) |
898 | 0 | } |
899 | | |
900 | 0 | fn size(&self) -> usize { |
901 | 0 | std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() |
902 | 0 | } |
903 | | } |
904 | | |
905 | | #[derive(Debug)] |
906 | | pub struct SlidingMaxAccumulator { |
907 | | max: ScalarValue, |
908 | | moving_max: MovingMax<ScalarValue>, |
909 | | } |
910 | | |
911 | | impl SlidingMaxAccumulator { |
912 | | /// new max accumulator |
913 | 0 | pub fn try_new(datatype: &DataType) -> Result<Self> { |
914 | 0 | Ok(Self { |
915 | 0 | max: ScalarValue::try_from(datatype)?, |
916 | 0 | moving_max: MovingMax::<ScalarValue>::new(), |
917 | | }) |
918 | 0 | } |
919 | | } |
920 | | |
921 | | impl Accumulator for SlidingMaxAccumulator { |
922 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
923 | 0 | for idx in 0..values[0].len() { |
924 | 0 | let val = ScalarValue::try_from_array(&values[0], idx)?; |
925 | 0 | self.moving_max.push(val); |
926 | | } |
927 | 0 | if let Some(res) = self.moving_max.max() { |
928 | 0 | self.max = res.clone(); |
929 | 0 | } |
930 | 0 | Ok(()) |
931 | 0 | } |
932 | | |
933 | 0 | fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
934 | 0 | for _idx in 0..values[0].len() { |
935 | 0 | (self.moving_max).pop(); |
936 | 0 | } |
937 | 0 | if let Some(res) = self.moving_max.max() { |
938 | 0 | self.max = res.clone(); |
939 | 0 | } |
940 | 0 | Ok(()) |
941 | 0 | } |
942 | | |
943 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
944 | 0 | self.update_batch(states) |
945 | 0 | } |
946 | | |
947 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
948 | 0 | Ok(vec![self.max.clone()]) |
949 | 0 | } |
950 | | |
951 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
952 | 0 | Ok(self.max.clone()) |
953 | 0 | } |
954 | | |
955 | 0 | fn supports_retract_batch(&self) -> bool { |
956 | 0 | true |
957 | 0 | } |
958 | | |
959 | 0 | fn size(&self) -> usize { |
960 | 0 | std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() |
961 | 0 | } |
962 | | } |
963 | | |
964 | | #[derive(Debug)] |
965 | | pub struct Min { |
966 | | signature: Signature, |
967 | | } |
968 | | |
969 | | impl Min { |
970 | 0 | pub fn new() -> Self { |
971 | 0 | Self { |
972 | 0 | signature: Signature::user_defined(Volatility::Immutable), |
973 | 0 | } |
974 | 0 | } |
975 | | } |
976 | | |
977 | | impl Default for Min { |
978 | 0 | fn default() -> Self { |
979 | 0 | Self::new() |
980 | 0 | } |
981 | | } |
982 | | |
983 | | impl FromColumnStatistics for Min { |
984 | 0 | fn value_from_column_statistics( |
985 | 0 | &self, |
986 | 0 | col_stats: &ColumnStatistics, |
987 | 0 | ) -> Option<ScalarValue> { |
988 | 0 | if let Precision::Exact(ref val) = col_stats.min_value { |
989 | 0 | if !val.is_null() { |
990 | 0 | return Some(val.clone()); |
991 | 0 | } |
992 | 0 | } |
993 | 0 | None |
994 | 0 | } |
995 | | } |
996 | | |
997 | | impl AggregateUDFImpl for Min { |
998 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
999 | 0 | self |
1000 | 0 | } |
1001 | | |
1002 | 0 | fn name(&self) -> &str { |
1003 | 0 | "min" |
1004 | 0 | } |
1005 | | |
1006 | 0 | fn signature(&self) -> &Signature { |
1007 | 0 | &self.signature |
1008 | 0 | } |
1009 | | |
1010 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
1011 | 0 | Ok(arg_types[0].to_owned()) |
1012 | 0 | } |
1013 | | |
1014 | 0 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
1015 | 0 | Ok(Box::new(MinAccumulator::try_new(acc_args.return_type)?)) |
1016 | 0 | } |
1017 | | |
1018 | 0 | fn aliases(&self) -> &[String] { |
1019 | 0 | &[] |
1020 | 0 | } |
1021 | | |
1022 | 0 | fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { |
1023 | | use DataType::*; |
1024 | 0 | matches!( |
1025 | 0 | args.return_type, |
1026 | | Int8 | Int16 |
1027 | | | Int32 |
1028 | | | Int64 |
1029 | | | UInt8 |
1030 | | | UInt16 |
1031 | | | UInt32 |
1032 | | | UInt64 |
1033 | | | Float16 |
1034 | | | Float32 |
1035 | | | Float64 |
1036 | | | Decimal128(_, _) |
1037 | | | Decimal256(_, _) |
1038 | | | Date32 |
1039 | | | Date64 |
1040 | | | Time32(_) |
1041 | | | Time64(_) |
1042 | | | Timestamp(_, _) |
1043 | | ) |
1044 | 0 | } |
1045 | | |
1046 | 0 | fn create_groups_accumulator( |
1047 | 0 | &self, |
1048 | 0 | args: AccumulatorArgs, |
1049 | 0 | ) -> Result<Box<dyn GroupsAccumulator>> { |
1050 | | use DataType::*; |
1051 | | use TimeUnit::*; |
1052 | 0 | let data_type = args.return_type; |
1053 | 0 | match data_type { |
1054 | 0 | Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type), |
1055 | 0 | Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type), |
1056 | 0 | Int32 => instantiate_min_accumulator!(data_type, i32, Int32Type), |
1057 | 0 | Int64 => instantiate_min_accumulator!(data_type, i64, Int64Type), |
1058 | 0 | UInt8 => instantiate_min_accumulator!(data_type, u8, UInt8Type), |
1059 | 0 | UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type), |
1060 | 0 | UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type), |
1061 | 0 | UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type), |
1062 | | Float16 => { |
1063 | 0 | instantiate_min_accumulator!(data_type, f16, Float16Type) |
1064 | | } |
1065 | | Float32 => { |
1066 | 0 | instantiate_min_accumulator!(data_type, f32, Float32Type) |
1067 | | } |
1068 | | Float64 => { |
1069 | 0 | instantiate_min_accumulator!(data_type, f64, Float64Type) |
1070 | | } |
1071 | 0 | Date32 => instantiate_min_accumulator!(data_type, i32, Date32Type), |
1072 | 0 | Date64 => instantiate_min_accumulator!(data_type, i64, Date64Type), |
1073 | | Time32(Second) => { |
1074 | 0 | instantiate_min_accumulator!(data_type, i32, Time32SecondType) |
1075 | | } |
1076 | | Time32(Millisecond) => { |
1077 | 0 | instantiate_min_accumulator!(data_type, i32, Time32MillisecondType) |
1078 | | } |
1079 | | Time64(Microsecond) => { |
1080 | 0 | instantiate_min_accumulator!(data_type, i64, Time64MicrosecondType) |
1081 | | } |
1082 | | Time64(Nanosecond) => { |
1083 | 0 | instantiate_min_accumulator!(data_type, i64, Time64NanosecondType) |
1084 | | } |
1085 | | Timestamp(Second, _) => { |
1086 | 0 | instantiate_min_accumulator!(data_type, i64, TimestampSecondType) |
1087 | | } |
1088 | | Timestamp(Millisecond, _) => { |
1089 | 0 | instantiate_min_accumulator!(data_type, i64, TimestampMillisecondType) |
1090 | | } |
1091 | | Timestamp(Microsecond, _) => { |
1092 | 0 | instantiate_min_accumulator!(data_type, i64, TimestampMicrosecondType) |
1093 | | } |
1094 | | Timestamp(Nanosecond, _) => { |
1095 | 0 | instantiate_min_accumulator!(data_type, i64, TimestampNanosecondType) |
1096 | | } |
1097 | | Decimal128(_, _) => { |
1098 | 0 | instantiate_min_accumulator!(data_type, i128, Decimal128Type) |
1099 | | } |
1100 | | Decimal256(_, _) => { |
1101 | 0 | instantiate_min_accumulator!(data_type, i256, Decimal256Type) |
1102 | | } |
1103 | | |
1104 | | // It would be nice to have a fast implementation for Strings as well |
1105 | | // https://github.com/apache/datafusion/issues/6906 |
1106 | | |
1107 | | // This is only reached if groups_accumulator_supported is out of sync |
1108 | 0 | _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), |
1109 | | } |
1110 | 0 | } |
1111 | | |
1112 | 0 | fn create_sliding_accumulator( |
1113 | 0 | &self, |
1114 | 0 | args: AccumulatorArgs, |
1115 | 0 | ) -> Result<Box<dyn Accumulator>> { |
1116 | 0 | Ok(Box::new(SlidingMinAccumulator::try_new(args.return_type)?)) |
1117 | 0 | } |
1118 | | |
1119 | 0 | fn is_descending(&self) -> Option<bool> { |
1120 | 0 | Some(false) |
1121 | 0 | } |
1122 | | |
1123 | 0 | fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> { |
1124 | 0 | self.value_from_statistics(statistics_args) |
1125 | 0 | } |
1126 | 0 | fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { |
1127 | 0 | datafusion_expr::utils::AggregateOrderSensitivity::Insensitive |
1128 | 0 | } |
1129 | | |
1130 | 0 | fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
1131 | 0 | get_min_max_result_type(arg_types) |
1132 | 0 | } |
1133 | | |
1134 | 0 | fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { |
1135 | 0 | datafusion_expr::ReversedUDAF::Identical |
1136 | 0 | } |
1137 | | } |
1138 | | /// An accumulator to compute the minimum value |
1139 | | #[derive(Debug)] |
1140 | | pub struct MinAccumulator { |
1141 | | min: ScalarValue, |
1142 | | } |
1143 | | |
1144 | | impl MinAccumulator { |
1145 | | /// new min accumulator |
1146 | 0 | pub fn try_new(datatype: &DataType) -> Result<Self> { |
1147 | 0 | Ok(Self { |
1148 | 0 | min: ScalarValue::try_from(datatype)?, |
1149 | | }) |
1150 | 0 | } |
1151 | | } |
1152 | | |
1153 | | impl Accumulator for MinAccumulator { |
1154 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
1155 | 0 | Ok(vec![self.evaluate()?]) |
1156 | 0 | } |
1157 | | |
1158 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
1159 | 0 | let values = &values[0]; |
1160 | 0 | let delta = &min_batch(values)?; |
1161 | 0 | let new_min: Result<ScalarValue, DataFusionError> = |
1162 | 0 | min_max!(&self.min, delta, min); |
1163 | 0 | self.min = new_min?; |
1164 | 0 | Ok(()) |
1165 | 0 | } |
1166 | | |
1167 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
1168 | 0 | self.update_batch(states) |
1169 | 0 | } |
1170 | | |
1171 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
1172 | 0 | Ok(self.min.clone()) |
1173 | 0 | } |
1174 | | |
1175 | 0 | fn size(&self) -> usize { |
1176 | 0 | std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() |
1177 | 0 | } |
1178 | | } |
1179 | | |
1180 | | #[derive(Debug)] |
1181 | | pub struct SlidingMinAccumulator { |
1182 | | min: ScalarValue, |
1183 | | moving_min: MovingMin<ScalarValue>, |
1184 | | } |
1185 | | |
1186 | | impl SlidingMinAccumulator { |
1187 | 0 | pub fn try_new(datatype: &DataType) -> Result<Self> { |
1188 | 0 | Ok(Self { |
1189 | 0 | min: ScalarValue::try_from(datatype)?, |
1190 | 0 | moving_min: MovingMin::<ScalarValue>::new(), |
1191 | | }) |
1192 | 0 | } |
1193 | | } |
1194 | | |
1195 | | impl Accumulator for SlidingMinAccumulator { |
1196 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
1197 | 0 | Ok(vec![self.min.clone()]) |
1198 | 0 | } |
1199 | | |
1200 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
1201 | 0 | for idx in 0..values[0].len() { |
1202 | 0 | let val = ScalarValue::try_from_array(&values[0], idx)?; |
1203 | 0 | if !val.is_null() { |
1204 | 0 | self.moving_min.push(val); |
1205 | 0 | } |
1206 | | } |
1207 | 0 | if let Some(res) = self.moving_min.min() { |
1208 | 0 | self.min = res.clone(); |
1209 | 0 | } |
1210 | 0 | Ok(()) |
1211 | 0 | } |
1212 | | |
1213 | 0 | fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
1214 | 0 | for idx in 0..values[0].len() { |
1215 | 0 | let val = ScalarValue::try_from_array(&values[0], idx)?; |
1216 | 0 | if !val.is_null() { |
1217 | 0 | (self.moving_min).pop(); |
1218 | 0 | } |
1219 | | } |
1220 | 0 | if let Some(res) = self.moving_min.min() { |
1221 | 0 | self.min = res.clone(); |
1222 | 0 | } |
1223 | 0 | Ok(()) |
1224 | 0 | } |
1225 | | |
1226 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
1227 | 0 | self.update_batch(states) |
1228 | 0 | } |
1229 | | |
1230 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
1231 | 0 | Ok(self.min.clone()) |
1232 | 0 | } |
1233 | | |
1234 | 0 | fn supports_retract_batch(&self) -> bool { |
1235 | 0 | true |
1236 | 0 | } |
1237 | | |
1238 | 0 | fn size(&self) -> usize { |
1239 | 0 | std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() |
1240 | 0 | } |
1241 | | } |
1242 | | |
1243 | | // |
1244 | | // Moving min and moving max |
1245 | | // The implementation is taken from https://github.com/spebern/moving_min_max/blob/master/src/lib.rs. |
1246 | | |
1247 | | // Keep track of the minimum or maximum value in a sliding window. |
1248 | | // |
1249 | | // `moving min max` provides one data structure for keeping track of the |
1250 | | // minimum value and one for keeping track of the maximum value in a sliding |
1251 | | // window. |
1252 | | // |
1253 | | // Each element is stored together with the running min/max of the stack it sits on. One stack receives pushes |
1254 | | // and the other serves pops. When the pop stack is empty, every element is moved from the push stack onto it, |
1255 | | // recomputing the running min/max along the way; elements are then popped from the pop stack, so the |
1256 | | // MovingMin/MovingMax struct works as a queue. To find the min/max of the queue, take the running min/max stored on top of each of the two stacks and pick the smaller/larger of those two values. |
1257 | | // |
1258 | | // The complexities of the operations are: |
1259 | | // - O(1) for getting the minimum/maximum |
1260 | | // - O(1) for push |
1261 | | // - amortized O(1) for pop |
1262 | | |
/// Tracks the minimum of a sliding window with queue (first-in, first-out) semantics.
///
/// Values enter through [`MovingMin::push`] and leave oldest-first through
/// [`MovingMin::pop`]; [`MovingMin::min`] reports the smallest value currently held.
///
/// Two stacks of `(value, running_min)` pairs are kept, where `running_min` is the
/// smallest value among that entry and all entries beneath it on the same stack.
///
/// ```
/// # use datafusion_functions_aggregate::min_max::MovingMin;
/// let mut moving_min = MovingMin::<i32>::new();
/// moving_min.push(2);
/// moving_min.push(1);
/// moving_min.push(3);
///
/// assert_eq!(moving_min.min(), Some(&1));
/// assert_eq!(moving_min.pop(), Some(2));
///
/// assert_eq!(moving_min.min(), Some(&1));
/// assert_eq!(moving_min.pop(), Some(1));
///
/// assert_eq!(moving_min.min(), Some(&3));
/// assert_eq!(moving_min.pop(), Some(3));
///
/// assert_eq!(moving_min.min(), None);
/// assert_eq!(moving_min.pop(), None);
/// ```
#[derive(Debug)]
pub struct MovingMin<T> {
    /// Receives new values; the top entry carries the minimum of this stack.
    push_stack: Vec<(T, T)>,
    /// Serves removals; the top entry is the oldest value of the window.
    pop_stack: Vec<(T, T)>,
}

impl<T: Clone + PartialOrd> Default for MovingMin<T> {
    fn default() -> Self {
        Self {
            push_stack: Vec::new(),
            pop_stack: Vec::new(),
        }
    }
}

impl<T: Clone + PartialOrd> MovingMin<T> {
    /// Creates a new `MovingMin` to keep track of the minimum in a sliding
    /// window.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new `MovingMin` to keep track of the minimum in a sliding
    /// window with `capacity` allocated slots.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            push_stack: Vec::with_capacity(capacity),
            pop_stack: Vec::with_capacity(capacity),
        }
    }

    /// Returns the minimum of the sliding window or `None` if the window is
    /// empty.
    #[inline]
    pub fn min(&self) -> Option<&T> {
        match (self.push_stack.last(), self.pop_stack.last()) {
            // Both stacks hold values: the window minimum is the smaller of the two tops.
            (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }),
            (Some((_, min)), None) | (None, Some((_, min))) => Some(min),
            (None, None) => None,
        }
    }

    /// Pushes a new element into the sliding window.
    #[inline]
    pub fn push(&mut self, val: T) {
        // The new running minimum is the previous one only if `val` is strictly larger.
        let min = match self.push_stack.last() {
            Some((_, min)) if val > *min => min.clone(),
            _ => val.clone(),
        };
        self.push_stack.push((val, min));
    }

    /// Removes and returns the oldest value of the sliding window, or `None`
    /// if the window is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.pop_stack.is_empty() {
            // Move everything across; popping reverses the order so the oldest
            // value ends up on top. Each value is moved, and only its running
            // minimum is cloned.
            self.pop_stack.reserve(self.push_stack.len());
            while let Some((val, _)) = self.push_stack.pop() {
                let min = match self.pop_stack.last() {
                    Some((_, below)) if *below < val => below.clone(),
                    _ => val.clone(),
                };
                self.pop_stack.push((val, min));
            }
        }
        self.pop_stack.pop().map(|(val, _)| val)
    }

    /// Returns the number of elements stored in the sliding window.
    #[inline]
    pub fn len(&self) -> usize {
        self.push_stack.len() + self.pop_stack.len()
    }

    /// Returns `true` if the moving window contains no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Tracks the maximum of a sliding window with queue (first-in, first-out) semantics.
///
/// Values enter through [`MovingMax::push`] and leave oldest-first through
/// [`MovingMax::pop`]; [`MovingMax::max`] reports the largest value currently held.
///
/// Two stacks of `(value, running_max)` pairs are kept, where `running_max` is the
/// largest value among that entry and all entries beneath it on the same stack.
///
/// ```
/// # use datafusion_functions_aggregate::min_max::MovingMax;
/// let mut moving_max = MovingMax::<i32>::new();
/// moving_max.push(2);
/// moving_max.push(3);
/// moving_max.push(1);
///
/// assert_eq!(moving_max.max(), Some(&3));
/// assert_eq!(moving_max.pop(), Some(2));
///
/// assert_eq!(moving_max.max(), Some(&3));
/// assert_eq!(moving_max.pop(), Some(3));
///
/// assert_eq!(moving_max.max(), Some(&1));
/// assert_eq!(moving_max.pop(), Some(1));
///
/// assert_eq!(moving_max.max(), None);
/// assert_eq!(moving_max.pop(), None);
/// ```
#[derive(Debug)]
pub struct MovingMax<T> {
    /// Receives new values; the top entry carries the maximum of this stack.
    push_stack: Vec<(T, T)>,
    /// Serves removals; the top entry is the oldest value of the window.
    pop_stack: Vec<(T, T)>,
}

impl<T: Clone + PartialOrd> Default for MovingMax<T> {
    fn default() -> Self {
        Self {
            push_stack: Vec::new(),
            pop_stack: Vec::new(),
        }
    }
}

impl<T: Clone + PartialOrd> MovingMax<T> {
    /// Creates a new `MovingMax` to keep track of the maximum in a sliding window.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new `MovingMax` to keep track of the maximum in a sliding window with
    /// `capacity` allocated slots.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            push_stack: Vec::with_capacity(capacity),
            pop_stack: Vec::with_capacity(capacity),
        }
    }

    /// Returns the maximum of the sliding window or `None` if the window is empty.
    #[inline]
    pub fn max(&self) -> Option<&T> {
        match (self.push_stack.last(), self.pop_stack.last()) {
            // Both stacks hold values: the window maximum is the larger of the two tops.
            (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }),
            (Some((_, max)), None) | (None, Some((_, max))) => Some(max),
            (None, None) => None,
        }
    }

    /// Pushes a new element into the sliding window.
    #[inline]
    pub fn push(&mut self, val: T) {
        // The new running maximum is the previous one only if `val` is strictly smaller.
        let max = match self.push_stack.last() {
            Some((_, max)) if val < *max => max.clone(),
            _ => val.clone(),
        };
        self.push_stack.push((val, max));
    }

    /// Removes and returns the oldest value of the sliding window, or `None`
    /// if the window is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.pop_stack.is_empty() {
            // Move everything across; popping reverses the order so the oldest
            // value ends up on top. Each value is moved, and only its running
            // maximum is cloned.
            self.pop_stack.reserve(self.push_stack.len());
            while let Some((val, _)) = self.push_stack.pop() {
                let max = match self.pop_stack.last() {
                    Some((_, below)) if *below > val => below.clone(),
                    _ => val.clone(),
                };
                self.pop_stack.push((val, max));
            }
        }
        self.pop_stack.pop().map(|(val, _)| val)
    }

    /// Returns the number of elements stored in the sliding window.
    #[inline]
    pub fn len(&self) -> usize {
        self.push_stack.len() + self.pop_stack.len()
    }

    /// Returns `true` if the moving window contains no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
1491 | | |
1492 | | make_udaf_expr_and_func!( |
1493 | | Max, |
1494 | | max, |
1495 | | expression, |
1496 | | "Returns the maximum of a group of values.", |
1497 | | max_udaf |
1498 | | ); |
1499 | | |
1500 | | make_udaf_expr_and_func!( |
1501 | | Min, |
1502 | | min, |
1503 | | expression, |
1504 | | "Returns the minimum of a group of values.", |
1505 | | min_udaf |
1506 | | ); |
1507 | | |
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{
        IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
    };
    use datafusion_common::Result;
    use rand::Rng;
    use std::sync::Arc;

    /// Feeds `array` as a single batch into a fresh `MinAccumulator` and a fresh
    /// `MaxAccumulator` of `data_type` and checks both results.
    fn assert_batch_min_max(
        array: ArrayRef,
        data_type: &DataType,
        expected_min: ScalarValue,
        expected_max: ScalarValue,
    ) {
        let mut min_acc = MinAccumulator::try_new(data_type).unwrap();
        min_acc.update_batch(&[Arc::clone(&array)]).unwrap();
        assert_eq!(min_acc.evaluate().unwrap(), expected_min);

        let mut max_acc = MaxAccumulator::try_new(data_type).unwrap();
        max_acc.update_batch(&[Arc::clone(&array)]).unwrap();
        assert_eq!(max_acc.evaluate().unwrap(), expected_max);
    }

    #[test]
    fn interval_min_max() {
        // IntervalYearMonth
        let year_month: ArrayRef = Arc::new(IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(0, 1),
            IntervalYearMonthType::make_value(5, 34),
            IntervalYearMonthType::make_value(-2, 4),
            IntervalYearMonthType::make_value(7, -4),
            IntervalYearMonthType::make_value(0, 1),
        ]));
        assert_batch_min_max(
            year_month,
            &DataType::Interval(IntervalUnit::YearMonth),
            ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
                -2, 4,
            ))),
            ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
                5, 34,
            ))),
        );

        // IntervalDayTime
        let day_time: ArrayRef = Arc::new(IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(0, 0),
            IntervalDayTimeType::make_value(5, 454000),
            IntervalDayTimeType::make_value(-34, 0),
            IntervalDayTimeType::make_value(7, -4000),
            IntervalDayTimeType::make_value(1, 0),
        ]));
        assert_batch_min_max(
            day_time,
            &DataType::Interval(IntervalUnit::DayTime),
            ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(-34, 0))),
            ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(
                7, -4000,
            ))),
        );

        // IntervalMonthDayNano
        let month_day_nano: ArrayRef = Arc::new(IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(1, 0, 0),
            IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000),
            IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000),
            IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000),
            IntervalMonthDayNanoType::make_value(1, 0, 0),
        ]));
        assert_batch_min_max(
            month_day_nano,
            &DataType::Interval(IntervalUnit::MonthDayNano),
            ScalarValue::IntervalMonthDayNano(Some(
                IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000),
            )),
            ScalarValue::IntervalMonthDayNano(Some(
                IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000),
            )),
        );
    }

    #[test]
    fn float_min_max_with_nans() {
        let pos_nan = f32::NAN;
        let zero = 0_f32;
        let neg_inf = f32::NEG_INFINITY;

        // Feeds each inner slice as its own batch, then compares the final result.
        let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| {
            for batch in values {
                let array =
                    Arc::new(Float32Array::from_iter_values(batch.iter().copied()));
                acc.update_batch(&[array]).unwrap();
            }
            assert_eq!(acc.evaluate().unwrap(), ScalarValue::Float32(Some(expected)));
        };

        // This test checks both comparison between batches (which uses the min_max macro
        // defined above) and within a batch (which uses the arrow min/max compute function
        // and verifies both respect the total order comparison for floats)

        let min = || MinAccumulator::try_new(&DataType::Float32).unwrap();
        let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap();

        check(&mut min(), &[&[zero], &[pos_nan]], zero);
        check(&mut min(), &[&[zero, pos_nan]], zero);
        check(&mut min(), &[&[zero], &[neg_inf]], neg_inf);
        check(&mut min(), &[&[zero, neg_inf]], neg_inf);
        check(&mut max(), &[&[zero], &[pos_nan]], pos_nan);
        check(&mut max(), &[&[zero, pos_nan]], pos_nan);
        check(&mut max(), &[&[zero], &[neg_inf]], zero);
        check(&mut max(), &[&[zero, neg_inf]], zero);
    }

    /// Returns `len` random values drawn from `0..100`.
    fn get_random_vec_i32(len: usize) -> Vec<i32> {
        let mut rng = rand::thread_rng();
        (0..len).map(|_| rng.gen_range(0..100)).collect()
    }

    /// Compares `MovingMin` against a brute-force minimum over the trailing
    /// `n_sliding_window + 1` values of random data.
    fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> {
        let data = get_random_vec_i32(len);
        let mut moving_min = MovingMin::<i32>::new();
        let mut expected = Vec::with_capacity(len);
        let mut actual = Vec::with_capacity(len);
        for (i, &value) in data.iter().enumerate() {
            let start = i.saturating_sub(n_sliding_window);
            expected.push(*data[start..=i].iter().min().unwrap());

            moving_min.push(value);
            if i > n_sliding_window {
                moving_min.pop();
            }
            actual.push(*moving_min.min().unwrap());
        }
        assert_eq!(actual, expected);
        Ok(())
    }

    /// Compares `MovingMax` against a brute-force maximum over the trailing
    /// `n_sliding_window + 1` values of random data.
    fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> {
        let data = get_random_vec_i32(len);
        let mut moving_max = MovingMax::<i32>::new();
        let mut expected = Vec::with_capacity(len);
        let mut actual = Vec::with_capacity(len);
        for (i, &value) in data.iter().enumerate() {
            let start = i.saturating_sub(n_sliding_window);
            expected.push(*data[start..=i].iter().max().unwrap());

            moving_max.push(value);
            if i > n_sliding_window {
                moving_max.pop();
            }
            actual.push(*moving_max.max().unwrap());
        }
        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn moving_min_tests() -> Result<()> {
        for &window in &[10, 20, 50, 100] {
            moving_min_i32(100, window)?;
        }
        Ok(())
    }

    #[test]
    fn moving_max_tests() -> Result<()> {
        for &window in &[10, 20, 50, 100] {
            moving_max_i32(100, window)?;
        }
        Ok(())
    }

    #[test]
    fn test_min_max_coerce_types() {
        // the coerced types is same with input types
        let funs: Vec<Box<dyn AggregateUDFImpl>> =
            vec![Box::new(Min::new()), Box::new(Max::new())];
        let input_types = vec![
            vec![DataType::Int32],
            vec![DataType::Decimal128(10, 2)],
            vec![DataType::Decimal256(1, 1)],
            vec![DataType::Utf8],
        ];
        for fun in funs {
            for input_type in &input_types {
                let coerced = fun.coerce_types(input_type).unwrap();
                assert_eq!(*input_type, coerced);
            }
        }
    }

    #[test]
    fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> {
        let data_type =
            DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32));
        let result = get_min_max_result_type(&[data_type])?;
        assert_eq!(result, vec![DataType::Int32]);
        Ok(())
    }
}