Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/min_max.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
6
// with the License.  You may obtain a copy of the License at
7
//
8
//   http://www.apache.org/licenses/LICENSE-2.0
9
//
10
// Unless required by applicable law or agreed to in writing,
11
// software distributed under the License is distributed on an
12
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13
// KIND, either express or implied.  See the License for the
14
// specific language governing permissions and limitations
15
// under the License.
16
17
//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function
18
//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function
19
20
// distributed with this work for additional information
21
// regarding copyright ownership.  The ASF licenses this file
22
// to you under the Apache License, Version 2.0 (the
23
// "License"); you may not use this file except in compliance
24
// with the License.  You may obtain a copy of the License at
25
//
26
//   http://www.apache.org/licenses/LICENSE-2.0
27
//
28
// Unless required by applicable law or agreed to in writing,
29
// software distributed under the License is distributed on an
30
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
31
// KIND, either express or implied.  See the License for the
32
// specific language governing permissions and limitations
33
// under the License.
34
35
use arrow::array::{
36
    ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
37
    Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array,
38
    Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
39
    IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
40
    LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray,
41
    Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
42
    TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
43
    TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
44
};
45
use arrow::compute;
46
use arrow::datatypes::{
47
    DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type,
48
    Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type,
49
    UInt8Type,
50
};
51
use arrow_schema::IntervalUnit;
52
use datafusion_common::stats::Precision;
53
use datafusion_common::{
54
    downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result,
55
};
56
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
57
use datafusion_physical_expr::expressions;
58
use std::fmt::Debug;
59
60
use arrow::datatypes::i256;
61
use arrow::datatypes::{
62
    Date32Type, Date64Type, Time32MillisecondType, Time32SecondType,
63
    Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
64
    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
65
};
66
67
use datafusion_common::ScalarValue;
68
use datafusion_expr::{
69
    function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility,
70
};
71
use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
72
use half::f16;
73
use std::ops::Deref;
74
75
0
fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
76
0
    // Make sure that the input types contain exactly one element.
77
0
    if input_types.len() != 1 {
78
0
        return exec_err!(
79
0
            "min/max was called with {} arguments. It requires only 1.",
80
0
            input_types.len()
81
0
        );
82
0
    }
83
0
    // min and max support the dictionary data type
84
0
    // unpack the dictionary to get the value
85
0
    match &input_types[0] {
86
0
        DataType::Dictionary(_, dict_value_type) => {
87
0
            // TODO: add a check for the case where the value type is a complex data type
88
0
            Ok(vec![dict_value_type.deref().clone()])
89
        }
90
        // TODO: add a check for the data types that min and max support
91
        // For example, the `Struct` and `Map` types are not supported by the MIN and MAX functions
92
0
        _ => Ok(input_types.to_vec()),
93
    }
94
0
}
95
96
// MAX aggregate UDF
97
#[derive(Debug)]
98
pub struct Max {
99
    signature: Signature,
100
}
101
102
impl Max {
103
0
    pub fn new() -> Self {
104
0
        Self {
105
0
            signature: Signature::user_defined(Volatility::Immutable),
106
0
        }
107
0
    }
108
}
109
110
impl Default for Max {
111
0
    fn default() -> Self {
112
0
        Self::new()
113
0
    }
114
}
115
/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MAX`
116
/// for the specified [`ArrowPrimitiveType`].
117
///
118
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
119
macro_rules! instantiate_max_accumulator {
120
    ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
121
        Ok(Box::new(
122
0
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| {
123
0
                if *cur < new {
124
0
                    *cur = new
125
0
                }
126
0
            })
127
            // Initialize each accumulator to $NATIVE::MIN
128
            .with_starting_value($NATIVE::MIN),
129
        ))
130
    }};
131
}
132
133
/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MIN`
134
/// for the specified [`ArrowPrimitiveType`].
135
///
136
///
137
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
138
macro_rules! instantiate_min_accumulator {
139
    ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
140
        Ok(Box::new(
141
0
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| {
142
0
                if *cur > new {
143
0
                    *cur = new
144
0
                }
145
0
            })
146
            // Initialize each accumulator to $NATIVE::MAX
147
            .with_starting_value($NATIVE::MAX),
148
        ))
149
    }};
150
}
151
152
trait FromColumnStatistics {
153
    fn value_from_column_statistics(
154
        &self,
155
        stats: &ColumnStatistics,
156
    ) -> Option<ScalarValue>;
157
158
0
    fn value_from_statistics(
159
0
        &self,
160
0
        statistics_args: &StatisticsArgs,
161
0
    ) -> Option<ScalarValue> {
162
0
        if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows {
163
0
            match *num_rows {
164
0
                0 => return ScalarValue::try_from(statistics_args.return_type).ok(),
165
0
                value if value > 0 => {
166
0
                    let col_stats = &statistics_args.statistics.column_statistics;
167
0
                    if statistics_args.exprs.len() == 1 {
168
                        // TODO optimize with exprs other than Column
169
0
                        if let Some(col_expr) = statistics_args.exprs[0]
170
0
                            .as_any()
171
0
                            .downcast_ref::<expressions::Column>()
172
                        {
173
0
                            return self.value_from_column_statistics(
174
0
                                &col_stats[col_expr.index()],
175
0
                            );
176
0
                        }
177
0
                    }
178
                }
179
0
                _ => {}
180
            }
181
0
        }
182
0
        None
183
0
    }
184
}
185
186
impl FromColumnStatistics for Max {
187
0
    fn value_from_column_statistics(
188
0
        &self,
189
0
        col_stats: &ColumnStatistics,
190
0
    ) -> Option<ScalarValue> {
191
0
        if let Precision::Exact(ref val) = col_stats.max_value {
192
0
            if !val.is_null() {
193
0
                return Some(val.clone());
194
0
            }
195
0
        }
196
0
        None
197
0
    }
198
}
199
200
impl AggregateUDFImpl for Max {
201
0
    fn as_any(&self) -> &dyn std::any::Any {
202
0
        self
203
0
    }
204
205
0
    fn name(&self) -> &str {
206
0
        "max"
207
0
    }
208
209
0
    fn signature(&self) -> &Signature {
210
0
        &self.signature
211
0
    }
212
213
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
214
0
        Ok(arg_types[0].to_owned())
215
0
    }
216
217
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
218
0
        Ok(Box::new(MaxAccumulator::try_new(acc_args.return_type)?))
219
0
    }
220
221
0
    fn aliases(&self) -> &[String] {
222
0
        &[]
223
0
    }
224
225
0
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
226
        use DataType::*;
227
0
        matches!(
228
0
            args.return_type,
229
            Int8 | Int16
230
                | Int32
231
                | Int64
232
                | UInt8
233
                | UInt16
234
                | UInt32
235
                | UInt64
236
                | Float16
237
                | Float32
238
                | Float64
239
                | Decimal128(_, _)
240
                | Decimal256(_, _)
241
                | Date32
242
                | Date64
243
                | Time32(_)
244
                | Time64(_)
245
                | Timestamp(_, _)
246
        )
247
0
    }
248
249
0
    fn create_groups_accumulator(
250
0
        &self,
251
0
        args: AccumulatorArgs,
252
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
253
        use DataType::*;
254
        use TimeUnit::*;
255
0
        let data_type = args.return_type;
256
0
        match data_type {
257
0
            Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type),
258
0
            Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type),
259
0
            Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type),
260
0
            Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type),
261
0
            UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type),
262
0
            UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type),
263
0
            UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type),
264
0
            UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type),
265
            Float16 => {
266
0
                instantiate_max_accumulator!(data_type, f16, Float16Type)
267
            }
268
            Float32 => {
269
0
                instantiate_max_accumulator!(data_type, f32, Float32Type)
270
            }
271
            Float64 => {
272
0
                instantiate_max_accumulator!(data_type, f64, Float64Type)
273
            }
274
0
            Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type),
275
0
            Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type),
276
            Time32(Second) => {
277
0
                instantiate_max_accumulator!(data_type, i32, Time32SecondType)
278
            }
279
            Time32(Millisecond) => {
280
0
                instantiate_max_accumulator!(data_type, i32, Time32MillisecondType)
281
            }
282
            Time64(Microsecond) => {
283
0
                instantiate_max_accumulator!(data_type, i64, Time64MicrosecondType)
284
            }
285
            Time64(Nanosecond) => {
286
0
                instantiate_max_accumulator!(data_type, i64, Time64NanosecondType)
287
            }
288
            Timestamp(Second, _) => {
289
0
                instantiate_max_accumulator!(data_type, i64, TimestampSecondType)
290
            }
291
            Timestamp(Millisecond, _) => {
292
0
                instantiate_max_accumulator!(data_type, i64, TimestampMillisecondType)
293
            }
294
            Timestamp(Microsecond, _) => {
295
0
                instantiate_max_accumulator!(data_type, i64, TimestampMicrosecondType)
296
            }
297
            Timestamp(Nanosecond, _) => {
298
0
                instantiate_max_accumulator!(data_type, i64, TimestampNanosecondType)
299
            }
300
            Decimal128(_, _) => {
301
0
                instantiate_max_accumulator!(data_type, i128, Decimal128Type)
302
            }
303
            Decimal256(_, _) => {
304
0
                instantiate_max_accumulator!(data_type, i256, Decimal256Type)
305
            }
306
307
            // It would be nice to have a fast implementation for Strings as well
308
            // https://github.com/apache/datafusion/issues/6906
309
310
            // This is only reached if groups_accumulator_supported is out of sync
311
0
            _ => internal_err!("GroupsAccumulator not supported for max({})", data_type),
312
        }
313
0
    }
314
315
0
    fn create_sliding_accumulator(
316
0
        &self,
317
0
        args: AccumulatorArgs,
318
0
    ) -> Result<Box<dyn Accumulator>> {
319
0
        Ok(Box::new(SlidingMaxAccumulator::try_new(args.return_type)?))
320
0
    }
321
322
0
    fn is_descending(&self) -> Option<bool> {
323
0
        Some(true)
324
0
    }
325
326
0
    fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
327
0
        datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
328
0
    }
329
330
0
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
331
0
        get_min_max_result_type(arg_types)
332
0
    }
333
0
    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
334
0
        datafusion_expr::ReversedUDAF::Identical
335
0
    }
336
0
    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
337
0
        self.value_from_statistics(statistics_args)
338
0
    }
339
}
340
341
// Statically-typed version of min/max(array) -> ScalarValue for string types
342
macro_rules! typed_min_max_batch_string {
343
    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
344
        let array = downcast_value!($VALUES, $ARRAYTYPE);
345
        let value = compute::$OP(array);
346
0
        let value = value.and_then(|e| Some(e.to_string()));
347
        ScalarValue::$SCALAR(value)
348
    }};
349
}
350
// Statically-typed version of min/max(array) -> ScalarValue for binary types.
351
macro_rules! typed_min_max_batch_binary {
352
    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
353
        let array = downcast_value!($VALUES, $ARRAYTYPE);
354
        let value = compute::$OP(array);
355
0
        let value = value.and_then(|e| Some(e.to_vec()));
356
        ScalarValue::$SCALAR(value)
357
    }};
358
}
359
360
// Statically-typed version of min/max(array) -> ScalarValue for non-string types.
361
macro_rules! typed_min_max_batch {
362
    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{
363
        let array = downcast_value!($VALUES, $ARRAYTYPE);
364
        let value = compute::$OP(array);
365
        ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*)
366
    }};
367
}
368
369
// Statically-typed version of min/max(array) -> ScalarValue for non-string types.
370
// this is a macro to support both operations (min and max).
371
macro_rules! min_max_batch {
372
    ($VALUES:expr, $OP:ident) => {{
373
        match $VALUES.data_type() {
374
            DataType::Null => ScalarValue::Null,
375
            DataType::Decimal128(precision, scale) => {
376
                typed_min_max_batch!(
377
                    $VALUES,
378
                    Decimal128Array,
379
                    Decimal128,
380
                    $OP,
381
                    precision,
382
                    scale
383
                )
384
            }
385
            DataType::Decimal256(precision, scale) => {
386
                typed_min_max_batch!(
387
                    $VALUES,
388
                    Decimal256Array,
389
                    Decimal256,
390
                    $OP,
391
                    precision,
392
                    scale
393
                )
394
            }
395
            // all types that have a natural order
396
            DataType::Float64 => {
397
                typed_min_max_batch!($VALUES, Float64Array, Float64, $OP)
398
            }
399
            DataType::Float32 => {
400
                typed_min_max_batch!($VALUES, Float32Array, Float32, $OP)
401
            }
402
            DataType::Float16 => {
403
                typed_min_max_batch!($VALUES, Float16Array, Float16, $OP)
404
            }
405
            DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP),
406
            DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP),
407
            DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP),
408
            DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP),
409
            DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP),
410
            DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP),
411
            DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP),
412
            DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP),
413
            DataType::Timestamp(TimeUnit::Second, tz_opt) => {
414
                typed_min_max_batch!(
415
                    $VALUES,
416
                    TimestampSecondArray,
417
                    TimestampSecond,
418
                    $OP,
419
                    tz_opt
420
                )
421
            }
422
            DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!(
423
                $VALUES,
424
                TimestampMillisecondArray,
425
                TimestampMillisecond,
426
                $OP,
427
                tz_opt
428
            ),
429
            DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!(
430
                $VALUES,
431
                TimestampMicrosecondArray,
432
                TimestampMicrosecond,
433
                $OP,
434
                tz_opt
435
            ),
436
            DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!(
437
                $VALUES,
438
                TimestampNanosecondArray,
439
                TimestampNanosecond,
440
                $OP,
441
                tz_opt
442
            ),
443
            DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP),
444
            DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP),
445
            DataType::Time32(TimeUnit::Second) => {
446
                typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP)
447
            }
448
            DataType::Time32(TimeUnit::Millisecond) => {
449
                typed_min_max_batch!(
450
                    $VALUES,
451
                    Time32MillisecondArray,
452
                    Time32Millisecond,
453
                    $OP
454
                )
455
            }
456
            DataType::Time64(TimeUnit::Microsecond) => {
457
                typed_min_max_batch!(
458
                    $VALUES,
459
                    Time64MicrosecondArray,
460
                    Time64Microsecond,
461
                    $OP
462
                )
463
            }
464
            DataType::Time64(TimeUnit::Nanosecond) => {
465
                typed_min_max_batch!(
466
                    $VALUES,
467
                    Time64NanosecondArray,
468
                    Time64Nanosecond,
469
                    $OP
470
                )
471
            }
472
            DataType::Interval(IntervalUnit::YearMonth) => {
473
                typed_min_max_batch!(
474
                    $VALUES,
475
                    IntervalYearMonthArray,
476
                    IntervalYearMonth,
477
                    $OP
478
                )
479
            }
480
            DataType::Interval(IntervalUnit::DayTime) => {
481
                typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP)
482
            }
483
            DataType::Interval(IntervalUnit::MonthDayNano) => {
484
                typed_min_max_batch!(
485
                    $VALUES,
486
                    IntervalMonthDayNanoArray,
487
                    IntervalMonthDayNano,
488
                    $OP
489
                )
490
            }
491
            other => {
492
                // This should have been handled before
493
                return internal_err!(
494
                    "Min/Max accumulator not implemented for type {:?}",
495
                    other
496
                );
497
            }
498
        }
499
    }};
500
}
501
502
/// dynamically-typed min(array) -> ScalarValue
503
0
fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
504
0
    Ok(match values.data_type() {
505
        DataType::Utf8 => {
506
0
            typed_min_max_batch_string!(values, StringArray, Utf8, min_string)
507
        }
508
        DataType::LargeUtf8 => {
509
0
            typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string)
510
        }
511
        DataType::Utf8View => {
512
0
            typed_min_max_batch_string!(
513
0
                values,
514
0
                StringViewArray,
515
0
                Utf8View,
516
0
                min_string_view
517
0
            )
518
        }
519
        DataType::Boolean => {
520
0
            typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean)
521
        }
522
        DataType::Binary => {
523
0
            typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary)
524
        }
525
        DataType::LargeBinary => {
526
0
            typed_min_max_batch_binary!(
527
0
                &values,
528
0
                LargeBinaryArray,
529
0
                LargeBinary,
530
0
                min_binary
531
0
            )
532
        }
533
        DataType::BinaryView => {
534
0
            typed_min_max_batch_binary!(
535
0
                &values,
536
0
                BinaryViewArray,
537
0
                BinaryView,
538
0
                min_binary_view
539
0
            )
540
        }
541
0
        _ => min_max_batch!(values, min),
542
    })
543
0
}
544
545
/// dynamically-typed max(array) -> ScalarValue
546
0
fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
547
0
    Ok(match values.data_type() {
548
        DataType::Utf8 => {
549
0
            typed_min_max_batch_string!(values, StringArray, Utf8, max_string)
550
        }
551
        DataType::LargeUtf8 => {
552
0
            typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string)
553
        }
554
        DataType::Utf8View => {
555
0
            typed_min_max_batch_string!(
556
0
                values,
557
0
                StringViewArray,
558
0
                Utf8View,
559
0
                max_string_view
560
0
            )
561
        }
562
        DataType::Boolean => {
563
0
            typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean)
564
        }
565
        DataType::Binary => {
566
0
            typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary)
567
        }
568
        DataType::BinaryView => {
569
0
            typed_min_max_batch_binary!(
570
0
                &values,
571
0
                BinaryViewArray,
572
0
                BinaryView,
573
0
                max_binary_view
574
0
            )
575
        }
576
        DataType::LargeBinary => {
577
0
            typed_min_max_batch_binary!(
578
0
                &values,
579
0
                LargeBinaryArray,
580
0
                LargeBinary,
581
0
                max_binary
582
0
            )
583
        }
584
0
        _ => min_max_batch!(values, max),
585
    })
586
0
}
587
588
// min/max of two non-string scalar values.
589
macro_rules! typed_min_max {
590
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{
591
        ScalarValue::$SCALAR(
592
            match ($VALUE, $DELTA) {
593
                (None, None) => None,
594
                (Some(a), None) => Some(*a),
595
                (None, Some(b)) => Some(*b),
596
                (Some(a), Some(b)) => Some((*a).$OP(*b)),
597
            },
598
            $($EXTRA_ARGS.clone()),*
599
        )
600
    }};
601
}
602
macro_rules! typed_min_max_float {
603
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
604
        ScalarValue::$SCALAR(match ($VALUE, $DELTA) {
605
            (None, None) => None,
606
            (Some(a), None) => Some(*a),
607
            (None, Some(b)) => Some(*b),
608
            (Some(a), Some(b)) => match a.total_cmp(b) {
609
                choose_min_max!($OP) => Some(*b),
610
                _ => Some(*a),
611
            },
612
        })
613
    }};
614
}
615
616
// min/max of two scalar string values.
617
macro_rules! typed_min_max_string {
618
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
619
        ScalarValue::$SCALAR(match ($VALUE, $DELTA) {
620
            (None, None) => None,
621
            (Some(a), None) => Some(a.clone()),
622
            (None, Some(b)) => Some(b.clone()),
623
            (Some(a), Some(b)) => Some((a).$OP(b).clone()),
624
        })
625
    }};
626
}
627
628
macro_rules! choose_min_max {
629
    (min) => {
630
        std::cmp::Ordering::Greater
631
    };
632
    (max) => {
633
        std::cmp::Ordering::Less
634
    };
635
}
636
637
macro_rules! interval_min_max {
638
    ($OP:tt, $LHS:expr, $RHS:expr) => {{
639
        match $LHS.partial_cmp(&$RHS) {
640
            Some(choose_min_max!($OP)) => $RHS.clone(),
641
            Some(_) => $LHS.clone(),
642
            None => {
643
                return internal_err!("Comparison error while computing interval min/max")
644
            }
645
        }
646
    }};
647
}
648
649
// min/max of two scalar values of the same type
650
macro_rules! min_max {
651
    ($VALUE:expr, $DELTA:expr, $OP:ident) => {{
652
        Ok(match ($VALUE, $DELTA) {
653
            (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null,
654
            (
655
                lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss),
656
                rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss)
657
            ) => {
658
                if lhsp.eq(rhsp) && lhss.eq(rhss) {
659
                    typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss)
660
                } else {
661
                    return internal_err!(
662
                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
663
                    (lhs, rhs)
664
                );
665
                }
666
            }
667
            (
668
                lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss),
669
                rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss)
670
            ) => {
671
                if lhsp.eq(rhsp) && lhss.eq(rhss) {
672
                    typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss)
673
                } else {
674
                    return internal_err!(
675
                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
676
                    (lhs, rhs)
677
                );
678
                }
679
            }
680
            (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => {
681
                typed_min_max!(lhs, rhs, Boolean, $OP)
682
            }
683
            (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
684
                typed_min_max_float!(lhs, rhs, Float64, $OP)
685
            }
686
            (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
687
                typed_min_max_float!(lhs, rhs, Float32, $OP)
688
            }
689
            (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => {
690
                typed_min_max_float!(lhs, rhs, Float16, $OP)
691
            }
692
            (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
693
                typed_min_max!(lhs, rhs, UInt64, $OP)
694
            }
695
            (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
696
                typed_min_max!(lhs, rhs, UInt32, $OP)
697
            }
698
            (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
699
                typed_min_max!(lhs, rhs, UInt16, $OP)
700
            }
701
            (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
702
                typed_min_max!(lhs, rhs, UInt8, $OP)
703
            }
704
            (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
705
                typed_min_max!(lhs, rhs, Int64, $OP)
706
            }
707
            (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
708
                typed_min_max!(lhs, rhs, Int32, $OP)
709
            }
710
            (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
711
                typed_min_max!(lhs, rhs, Int16, $OP)
712
            }
713
            (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
714
                typed_min_max!(lhs, rhs, Int8, $OP)
715
            }
716
            (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => {
717
                typed_min_max_string!(lhs, rhs, Utf8, $OP)
718
            }
719
            (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => {
720
                typed_min_max_string!(lhs, rhs, LargeUtf8, $OP)
721
            }
722
            (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => {
723
                typed_min_max_string!(lhs, rhs, Utf8View, $OP)
724
            }
725
            (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => {
726
                typed_min_max_string!(lhs, rhs, Binary, $OP)
727
            }
728
            (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => {
729
                typed_min_max_string!(lhs, rhs, LargeBinary, $OP)
730
            }
731
            (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => {
732
                typed_min_max_string!(lhs, rhs, BinaryView, $OP)
733
            }
734
            (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => {
735
                typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz)
736
            }
737
            (
738
                ScalarValue::TimestampMillisecond(lhs, l_tz),
739
                ScalarValue::TimestampMillisecond(rhs, _),
740
            ) => {
741
                typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz)
742
            }
743
            (
744
                ScalarValue::TimestampMicrosecond(lhs, l_tz),
745
                ScalarValue::TimestampMicrosecond(rhs, _),
746
            ) => {
747
                typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz)
748
            }
749
            (
750
                ScalarValue::TimestampNanosecond(lhs, l_tz),
751
                ScalarValue::TimestampNanosecond(rhs, _),
752
            ) => {
753
                typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz)
754
            }
755
            (
756
                ScalarValue::Date32(lhs),
757
                ScalarValue::Date32(rhs),
758
            ) => {
759
                typed_min_max!(lhs, rhs, Date32, $OP)
760
            }
761
            (
762
                ScalarValue::Date64(lhs),
763
                ScalarValue::Date64(rhs),
764
            ) => {
765
                typed_min_max!(lhs, rhs, Date64, $OP)
766
            }
767
            (
768
                ScalarValue::Time32Second(lhs),
769
                ScalarValue::Time32Second(rhs),
770
            ) => {
771
                typed_min_max!(lhs, rhs, Time32Second, $OP)
772
            }
773
            (
774
                ScalarValue::Time32Millisecond(lhs),
775
                ScalarValue::Time32Millisecond(rhs),
776
            ) => {
777
                typed_min_max!(lhs, rhs, Time32Millisecond, $OP)
778
            }
779
            (
780
                ScalarValue::Time64Microsecond(lhs),
781
                ScalarValue::Time64Microsecond(rhs),
782
            ) => {
783
                typed_min_max!(lhs, rhs, Time64Microsecond, $OP)
784
            }
785
            (
786
                ScalarValue::Time64Nanosecond(lhs),
787
                ScalarValue::Time64Nanosecond(rhs),
788
            ) => {
789
                typed_min_max!(lhs, rhs, Time64Nanosecond, $OP)
790
            }
791
            (
792
                ScalarValue::IntervalYearMonth(lhs),
793
                ScalarValue::IntervalYearMonth(rhs),
794
            ) => {
795
                typed_min_max!(lhs, rhs, IntervalYearMonth, $OP)
796
            }
797
            (
798
                ScalarValue::IntervalMonthDayNano(lhs),
799
                ScalarValue::IntervalMonthDayNano(rhs),
800
            ) => {
801
                typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP)
802
            }
803
            (
804
                ScalarValue::IntervalDayTime(lhs),
805
                ScalarValue::IntervalDayTime(rhs),
806
            ) => {
807
                typed_min_max!(lhs, rhs, IntervalDayTime, $OP)
808
            }
809
            (
810
                ScalarValue::IntervalYearMonth(_),
811
                ScalarValue::IntervalMonthDayNano(_),
812
            ) | (
813
                ScalarValue::IntervalYearMonth(_),
814
                ScalarValue::IntervalDayTime(_),
815
            ) | (
816
                ScalarValue::IntervalMonthDayNano(_),
817
                ScalarValue::IntervalDayTime(_),
818
            ) | (
819
                ScalarValue::IntervalMonthDayNano(_),
820
                ScalarValue::IntervalYearMonth(_),
821
            ) | (
822
                ScalarValue::IntervalDayTime(_),
823
                ScalarValue::IntervalYearMonth(_),
824
            ) | (
825
                ScalarValue::IntervalDayTime(_),
826
                ScalarValue::IntervalMonthDayNano(_),
827
            ) => {
828
                interval_min_max!($OP, $VALUE, $DELTA)
829
            }
830
                    (
831
                ScalarValue::DurationSecond(lhs),
832
                ScalarValue::DurationSecond(rhs),
833
            ) => {
834
                typed_min_max!(lhs, rhs, DurationSecond, $OP)
835
            }
836
                                (
837
                ScalarValue::DurationMillisecond(lhs),
838
                ScalarValue::DurationMillisecond(rhs),
839
            ) => {
840
                typed_min_max!(lhs, rhs, DurationMillisecond, $OP)
841
            }
842
                                (
843
                ScalarValue::DurationMicrosecond(lhs),
844
                ScalarValue::DurationMicrosecond(rhs),
845
            ) => {
846
                typed_min_max!(lhs, rhs, DurationMicrosecond, $OP)
847
            }
848
                                        (
849
                ScalarValue::DurationNanosecond(lhs),
850
                ScalarValue::DurationNanosecond(rhs),
851
            ) => {
852
                typed_min_max!(lhs, rhs, DurationNanosecond, $OP)
853
            }
854
            e => {
855
                return internal_err!(
856
                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
857
                    e
858
                )
859
            }
860
        })
861
    }};
862
}
863
864
/// An accumulator to compute the maximum value
865
#[derive(Debug)]
866
pub struct MaxAccumulator {
867
    max: ScalarValue,
868
}
869
870
impl MaxAccumulator {
871
    /// new max accumulator
872
0
    pub fn try_new(datatype: &DataType) -> Result<Self> {
873
0
        Ok(Self {
874
0
            max: ScalarValue::try_from(datatype)?,
875
        })
876
0
    }
877
}
878
879
impl Accumulator for MaxAccumulator {
880
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
881
0
        let values = &values[0];
882
0
        let delta = &max_batch(values)?;
883
0
        let new_max: Result<ScalarValue, DataFusionError> =
884
0
            min_max!(&self.max, delta, max);
885
0
        self.max = new_max?;
886
0
        Ok(())
887
0
    }
888
889
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
890
0
        self.update_batch(states)
891
0
    }
892
893
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
894
0
        Ok(vec![self.evaluate()?])
895
0
    }
896
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
897
0
        Ok(self.max.clone())
898
0
    }
899
900
0
    fn size(&self) -> usize {
901
0
        std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size()
902
0
    }
903
}
904
905
#[derive(Debug)]
906
pub struct SlidingMaxAccumulator {
907
    max: ScalarValue,
908
    moving_max: MovingMax<ScalarValue>,
909
}
910
911
impl SlidingMaxAccumulator {
912
    /// new max accumulator
913
0
    pub fn try_new(datatype: &DataType) -> Result<Self> {
914
0
        Ok(Self {
915
0
            max: ScalarValue::try_from(datatype)?,
916
0
            moving_max: MovingMax::<ScalarValue>::new(),
917
        })
918
0
    }
919
}
920
921
impl Accumulator for SlidingMaxAccumulator {
922
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
923
0
        for idx in 0..values[0].len() {
924
0
            let val = ScalarValue::try_from_array(&values[0], idx)?;
925
0
            self.moving_max.push(val);
926
        }
927
0
        if let Some(res) = self.moving_max.max() {
928
0
            self.max = res.clone();
929
0
        }
930
0
        Ok(())
931
0
    }
932
933
0
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
934
0
        for _idx in 0..values[0].len() {
935
0
            (self.moving_max).pop();
936
0
        }
937
0
        if let Some(res) = self.moving_max.max() {
938
0
            self.max = res.clone();
939
0
        }
940
0
        Ok(())
941
0
    }
942
943
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
944
0
        self.update_batch(states)
945
0
    }
946
947
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
948
0
        Ok(vec![self.max.clone()])
949
0
    }
950
951
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
952
0
        Ok(self.max.clone())
953
0
    }
954
955
0
    fn supports_retract_batch(&self) -> bool {
956
0
        true
957
0
    }
958
959
0
    fn size(&self) -> usize {
960
0
        std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size()
961
0
    }
962
}
963
964
#[derive(Debug)]
965
pub struct Min {
966
    signature: Signature,
967
}
968
969
impl Min {
970
0
    pub fn new() -> Self {
971
0
        Self {
972
0
            signature: Signature::user_defined(Volatility::Immutable),
973
0
        }
974
0
    }
975
}
976
977
impl Default for Min {
978
0
    fn default() -> Self {
979
0
        Self::new()
980
0
    }
981
}
982
983
impl FromColumnStatistics for Min {
984
0
    fn value_from_column_statistics(
985
0
        &self,
986
0
        col_stats: &ColumnStatistics,
987
0
    ) -> Option<ScalarValue> {
988
0
        if let Precision::Exact(ref val) = col_stats.min_value {
989
0
            if !val.is_null() {
990
0
                return Some(val.clone());
991
0
            }
992
0
        }
993
0
        None
994
0
    }
995
}
996
997
impl AggregateUDFImpl for Min {
998
0
    fn as_any(&self) -> &dyn std::any::Any {
999
0
        self
1000
0
    }
1001
1002
0
    fn name(&self) -> &str {
1003
0
        "min"
1004
0
    }
1005
1006
0
    fn signature(&self) -> &Signature {
1007
0
        &self.signature
1008
0
    }
1009
1010
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1011
0
        Ok(arg_types[0].to_owned())
1012
0
    }
1013
1014
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1015
0
        Ok(Box::new(MinAccumulator::try_new(acc_args.return_type)?))
1016
0
    }
1017
1018
0
    fn aliases(&self) -> &[String] {
1019
0
        &[]
1020
0
    }
1021
1022
0
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1023
        use DataType::*;
1024
0
        matches!(
1025
0
            args.return_type,
1026
            Int8 | Int16
1027
                | Int32
1028
                | Int64
1029
                | UInt8
1030
                | UInt16
1031
                | UInt32
1032
                | UInt64
1033
                | Float16
1034
                | Float32
1035
                | Float64
1036
                | Decimal128(_, _)
1037
                | Decimal256(_, _)
1038
                | Date32
1039
                | Date64
1040
                | Time32(_)
1041
                | Time64(_)
1042
                | Timestamp(_, _)
1043
        )
1044
0
    }
1045
1046
0
    fn create_groups_accumulator(
1047
0
        &self,
1048
0
        args: AccumulatorArgs,
1049
0
    ) -> Result<Box<dyn GroupsAccumulator>> {
1050
        use DataType::*;
1051
        use TimeUnit::*;
1052
0
        let data_type = args.return_type;
1053
0
        match data_type {
1054
0
            Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type),
1055
0
            Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type),
1056
0
            Int32 => instantiate_min_accumulator!(data_type, i32, Int32Type),
1057
0
            Int64 => instantiate_min_accumulator!(data_type, i64, Int64Type),
1058
0
            UInt8 => instantiate_min_accumulator!(data_type, u8, UInt8Type),
1059
0
            UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type),
1060
0
            UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type),
1061
0
            UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type),
1062
            Float16 => {
1063
0
                instantiate_min_accumulator!(data_type, f16, Float16Type)
1064
            }
1065
            Float32 => {
1066
0
                instantiate_min_accumulator!(data_type, f32, Float32Type)
1067
            }
1068
            Float64 => {
1069
0
                instantiate_min_accumulator!(data_type, f64, Float64Type)
1070
            }
1071
0
            Date32 => instantiate_min_accumulator!(data_type, i32, Date32Type),
1072
0
            Date64 => instantiate_min_accumulator!(data_type, i64, Date64Type),
1073
            Time32(Second) => {
1074
0
                instantiate_min_accumulator!(data_type, i32, Time32SecondType)
1075
            }
1076
            Time32(Millisecond) => {
1077
0
                instantiate_min_accumulator!(data_type, i32, Time32MillisecondType)
1078
            }
1079
            Time64(Microsecond) => {
1080
0
                instantiate_min_accumulator!(data_type, i64, Time64MicrosecondType)
1081
            }
1082
            Time64(Nanosecond) => {
1083
0
                instantiate_min_accumulator!(data_type, i64, Time64NanosecondType)
1084
            }
1085
            Timestamp(Second, _) => {
1086
0
                instantiate_min_accumulator!(data_type, i64, TimestampSecondType)
1087
            }
1088
            Timestamp(Millisecond, _) => {
1089
0
                instantiate_min_accumulator!(data_type, i64, TimestampMillisecondType)
1090
            }
1091
            Timestamp(Microsecond, _) => {
1092
0
                instantiate_min_accumulator!(data_type, i64, TimestampMicrosecondType)
1093
            }
1094
            Timestamp(Nanosecond, _) => {
1095
0
                instantiate_min_accumulator!(data_type, i64, TimestampNanosecondType)
1096
            }
1097
            Decimal128(_, _) => {
1098
0
                instantiate_min_accumulator!(data_type, i128, Decimal128Type)
1099
            }
1100
            Decimal256(_, _) => {
1101
0
                instantiate_min_accumulator!(data_type, i256, Decimal256Type)
1102
            }
1103
1104
            // It would be nice to have a fast implementation for Strings as well
1105
            // https://github.com/apache/datafusion/issues/6906
1106
1107
            // This is only reached if groups_accumulator_supported is out of sync
1108
0
            _ => internal_err!("GroupsAccumulator not supported for min({})", data_type),
1109
        }
1110
0
    }
1111
1112
0
    fn create_sliding_accumulator(
1113
0
        &self,
1114
0
        args: AccumulatorArgs,
1115
0
    ) -> Result<Box<dyn Accumulator>> {
1116
0
        Ok(Box::new(SlidingMinAccumulator::try_new(args.return_type)?))
1117
0
    }
1118
1119
0
    fn is_descending(&self) -> Option<bool> {
1120
0
        Some(false)
1121
0
    }
1122
1123
0
    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
1124
0
        self.value_from_statistics(statistics_args)
1125
0
    }
1126
0
    fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
1127
0
        datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
1128
0
    }
1129
1130
0
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1131
0
        get_min_max_result_type(arg_types)
1132
0
    }
1133
1134
0
    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
1135
0
        datafusion_expr::ReversedUDAF::Identical
1136
0
    }
1137
}
1138
/// An accumulator to compute the minimum value
1139
#[derive(Debug)]
1140
pub struct MinAccumulator {
1141
    min: ScalarValue,
1142
}
1143
1144
impl MinAccumulator {
1145
    /// new min accumulator
1146
0
    pub fn try_new(datatype: &DataType) -> Result<Self> {
1147
0
        Ok(Self {
1148
0
            min: ScalarValue::try_from(datatype)?,
1149
        })
1150
0
    }
1151
}
1152
1153
impl Accumulator for MinAccumulator {
1154
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1155
0
        Ok(vec![self.evaluate()?])
1156
0
    }
1157
1158
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1159
0
        let values = &values[0];
1160
0
        let delta = &min_batch(values)?;
1161
0
        let new_min: Result<ScalarValue, DataFusionError> =
1162
0
            min_max!(&self.min, delta, min);
1163
0
        self.min = new_min?;
1164
0
        Ok(())
1165
0
    }
1166
1167
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1168
0
        self.update_batch(states)
1169
0
    }
1170
1171
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
1172
0
        Ok(self.min.clone())
1173
0
    }
1174
1175
0
    fn size(&self) -> usize {
1176
0
        std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size()
1177
0
    }
1178
}
1179
1180
#[derive(Debug)]
1181
pub struct SlidingMinAccumulator {
1182
    min: ScalarValue,
1183
    moving_min: MovingMin<ScalarValue>,
1184
}
1185
1186
impl SlidingMinAccumulator {
1187
0
    pub fn try_new(datatype: &DataType) -> Result<Self> {
1188
0
        Ok(Self {
1189
0
            min: ScalarValue::try_from(datatype)?,
1190
0
            moving_min: MovingMin::<ScalarValue>::new(),
1191
        })
1192
0
    }
1193
}
1194
1195
impl Accumulator for SlidingMinAccumulator {
1196
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1197
0
        Ok(vec![self.min.clone()])
1198
0
    }
1199
1200
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1201
0
        for idx in 0..values[0].len() {
1202
0
            let val = ScalarValue::try_from_array(&values[0], idx)?;
1203
0
            if !val.is_null() {
1204
0
                self.moving_min.push(val);
1205
0
            }
1206
        }
1207
0
        if let Some(res) = self.moving_min.min() {
1208
0
            self.min = res.clone();
1209
0
        }
1210
0
        Ok(())
1211
0
    }
1212
1213
0
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1214
0
        for idx in 0..values[0].len() {
1215
0
            let val = ScalarValue::try_from_array(&values[0], idx)?;
1216
0
            if !val.is_null() {
1217
0
                (self.moving_min).pop();
1218
0
            }
1219
        }
1220
0
        if let Some(res) = self.moving_min.min() {
1221
0
            self.min = res.clone();
1222
0
        }
1223
0
        Ok(())
1224
0
    }
1225
1226
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1227
0
        self.update_batch(states)
1228
0
    }
1229
1230
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
1231
0
        Ok(self.min.clone())
1232
0
    }
1233
1234
0
    fn supports_retract_batch(&self) -> bool {
1235
0
        true
1236
0
    }
1237
1238
0
    fn size(&self) -> usize {
1239
0
        std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size()
1240
0
    }
1241
}
1242
1243
//
1244
// Moving min and moving max
1245
// The implementation is taken from https://github.com/spebern/moving_min_max/blob/master/src/lib.rs.
1246
1247
// Keep track of the minimum or maximum value in a sliding window.
1248
//
1249
// `moving min max` provides one data structure for keeping track of the
1250
// minimum value and one for keeping track of the maximum value in a sliding
1251
// window.
1252
//
1253
// Each element is stored with the current min/max. One stack to push and another one for pop. If pop stack is empty,
1254
// push to this stack all elements popped from first stack while updating their current min/max. Now pop from
1255
// the second stack (MovingMin/Max struct works as a queue). To find the minimum element of the queue,
1256
// look at the running minimum stored on top of each of the two stacks, then take the smaller of those two values (and likewise for the maximum).
1257
//
1258
// The complexities of the operations are:
1259
// - O(1) for getting the minimum/maximum
1260
// - O(1) for push
1261
// - amortized O(1) for pop
1262
1263
/// Tracks the minimum of a sliding window that behaves like a FIFO queue:
/// values enter with [`MovingMin::push`] and leave, oldest first, with
/// [`MovingMin::pop`].
///
/// ```
/// # use datafusion_functions_aggregate::min_max::MovingMin;
/// let mut moving_min = MovingMin::<i32>::new();
/// moving_min.push(2);
/// moving_min.push(1);
/// moving_min.push(3);
///
/// assert_eq!(moving_min.min(), Some(&1));
/// assert_eq!(moving_min.pop(), Some(2));
///
/// assert_eq!(moving_min.min(), Some(&1));
/// assert_eq!(moving_min.pop(), Some(1));
///
/// assert_eq!(moving_min.min(), Some(&3));
/// assert_eq!(moving_min.pop(), Some(3));
///
/// assert_eq!(moving_min.min(), None);
/// assert_eq!(moving_min.pop(), None);
/// ```
#[derive(Debug)]
pub struct MovingMin<T> {
    /// Incoming values, newest on top, each paired with the minimum of this
    /// stack from the bottom up to and including that value.
    push_stack: Vec<(T, T)>,
    /// Outgoing values, oldest on top, each paired with the minimum of this
    /// stack from the bottom up to and including that value.
    pop_stack: Vec<(T, T)>,
}

impl<T: Clone + PartialOrd> Default for MovingMin<T> {
    /// An empty window with no preallocated storage.
    fn default() -> Self {
        Self {
            push_stack: Vec::new(),
            pop_stack: Vec::new(),
        }
    }
}

impl<T: Clone + PartialOrd> MovingMin<T> {
    /// Creates a new `MovingMin` to keep track of the minimum in a sliding
    /// window.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new `MovingMin` to keep track of the minimum in a sliding
    /// window, with room for `capacity` entries reserved in each of the two
    /// internal stacks.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            push_stack: Vec::with_capacity(capacity),
            pop_stack: Vec::with_capacity(capacity),
        }
    }

    /// Returns the minimum of the sliding window or `None` if the window is
    /// empty.
    #[inline]
    pub fn min(&self) -> Option<&T> {
        // Each stack top carries the minimum of its whole stack, so the window
        // minimum is the smaller of the (up to) two tops.
        match (self.push_stack.last(), self.pop_stack.last()) {
            (None, None) => None,
            (Some((_, min)), None) | (None, Some((_, min))) => Some(min),
            (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }),
        }
    }

    /// Pushes a new element into the sliding window.
    #[inline]
    pub fn push(&mut self, val: T) {
        // The running minimum stays unchanged only when `val` is strictly
        // greater than it; otherwise `val` becomes the new minimum.
        let min = match self.push_stack.last() {
            Some((_, min)) if val > *min => min.clone(),
            _ => val.clone(),
        };
        self.push_stack.push((val, min));
    }

    /// Removes and returns the oldest value of the sliding window (first in,
    /// first out), or `None` if the window is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.pop_stack.is_empty() {
            // Move everything across. Popping reverses the order, so the
            // oldest value ends up on top of `pop_stack`. The running minimum
            // is rebuilt along the way with a single clone per element.
            while let Some((val, _)) = self.push_stack.pop() {
                let min = match self.pop_stack.last() {
                    Some((_, running)) if *running < val => running.clone(),
                    _ => val.clone(),
                };
                self.pop_stack.push((val, min));
            }
        }
        self.pop_stack.pop().map(|(val, _)| val)
    }

    /// Returns the number of elements stored in the sliding window.
    #[inline]
    pub fn len(&self) -> usize {
        self.push_stack.len() + self.pop_stack.len()
    }

    /// Returns `true` if the moving window contains no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
1378
/// Tracks the maximum of a sliding window that behaves like a FIFO queue:
/// values enter with [`MovingMax::push`] and leave, oldest first, with
/// [`MovingMax::pop`].
///
/// ```
/// # use datafusion_functions_aggregate::min_max::MovingMax;
/// let mut moving_max = MovingMax::<i32>::new();
/// moving_max.push(2);
/// moving_max.push(3);
/// moving_max.push(1);
///
/// assert_eq!(moving_max.max(), Some(&3));
/// assert_eq!(moving_max.pop(), Some(2));
///
/// assert_eq!(moving_max.max(), Some(&3));
/// assert_eq!(moving_max.pop(), Some(3));
///
/// assert_eq!(moving_max.max(), Some(&1));
/// assert_eq!(moving_max.pop(), Some(1));
///
/// assert_eq!(moving_max.max(), None);
/// assert_eq!(moving_max.pop(), None);
/// ```
#[derive(Debug)]
pub struct MovingMax<T> {
    /// Incoming values, newest on top, each paired with the maximum of this
    /// stack from the bottom up to and including that value.
    push_stack: Vec<(T, T)>,
    /// Outgoing values, oldest on top, each paired with the maximum of this
    /// stack from the bottom up to and including that value.
    pop_stack: Vec<(T, T)>,
}

impl<T: Clone + PartialOrd> Default for MovingMax<T> {
    /// An empty window with no preallocated storage.
    fn default() -> Self {
        Self {
            push_stack: Vec::new(),
            pop_stack: Vec::new(),
        }
    }
}

impl<T: Clone + PartialOrd> MovingMax<T> {
    /// Creates a new `MovingMax` to keep track of the maximum in a sliding window.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new `MovingMax` to keep track of the maximum in a sliding
    /// window, with room for `capacity` entries reserved in each of the two
    /// internal stacks.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            push_stack: Vec::with_capacity(capacity),
            pop_stack: Vec::with_capacity(capacity),
        }
    }

    /// Returns the maximum of the sliding window or `None` if the window is empty.
    #[inline]
    pub fn max(&self) -> Option<&T> {
        // Each stack top carries the maximum of its whole stack, so the window
        // maximum is the larger of the (up to) two tops.
        match (self.push_stack.last(), self.pop_stack.last()) {
            (None, None) => None,
            (Some((_, max)), None) | (None, Some((_, max))) => Some(max),
            (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }),
        }
    }

    /// Pushes a new element into the sliding window.
    #[inline]
    pub fn push(&mut self, val: T) {
        // The running maximum stays unchanged only when `val` is strictly
        // smaller than it; otherwise `val` becomes the new maximum.
        let max = match self.push_stack.last() {
            Some((_, max)) if val < *max => max.clone(),
            _ => val.clone(),
        };
        self.push_stack.push((val, max));
    }

    /// Removes and returns the oldest value of the sliding window (first in,
    /// first out), or `None` if the window is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.pop_stack.is_empty() {
            // Move everything across. Popping reverses the order, so the
            // oldest value ends up on top of `pop_stack`. The running maximum
            // is rebuilt along the way with a single clone per element.
            while let Some((val, _)) = self.push_stack.pop() {
                let max = match self.pop_stack.last() {
                    Some((_, running)) if *running > val => running.clone(),
                    _ => val.clone(),
                };
                self.pop_stack.push((val, max));
            }
        }
        self.pop_stack.pop().map(|(val, _)| val)
    }

    /// Returns the number of elements stored in the sliding window.
    #[inline]
    pub fn len(&self) -> usize {
        self.push_stack.len() + self.pop_stack.len()
    }

    /// Returns `true` if the moving window contains no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
1491
1492
// Registers `Max` through `make_udaf_expr_and_func!`, passing the names `max`
// and `max_udaf`, one argument called `expression`, and the doc string below.
// NOTE(review): presumably this generates a `max(expression)` expression helper
// and a `max_udaf()` accessor; confirm against the macro definition.
make_udaf_expr_and_func!(
1493
    Max,
1494
    max,
1495
    expression,
1496
    "Returns the maximum of a group of values.",
1497
    max_udaf
1498
);
1499
1500
// Same registration for `Min`, with the names `min` and `min_udaf`.
make_udaf_expr_and_func!(
1501
    Min,
1502
    min,
1503
    expression,
1504
    "Returns the minimum of a group of values.",
1505
    min_udaf
1506
);
1507
1508
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{
        IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
    };
    use datafusion_common::Result;
    use rand::Rng;
    use std::sync::Arc;

    /// Feeds `array` to a fresh min and a fresh max accumulator of
    /// `data_type` and returns `(min, max)`.
    fn min_and_max(array: &ArrayRef, data_type: &DataType) -> (ScalarValue, ScalarValue) {
        let mut min_acc = MinAccumulator::try_new(data_type).unwrap();
        min_acc.update_batch(&[Arc::clone(array)]).unwrap();
        let min_res = min_acc.evaluate().unwrap();

        let mut max_acc = MaxAccumulator::try_new(data_type).unwrap();
        max_acc.update_batch(&[Arc::clone(array)]).unwrap();
        let max_res = max_acc.evaluate().unwrap();

        (min_res, max_res)
    }

    #[test]
    fn interval_min_max() {
        // IntervalYearMonth
        let year_month: ArrayRef = Arc::new(IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(0, 1),
            IntervalYearMonthType::make_value(5, 34),
            IntervalYearMonthType::make_value(-2, 4),
            IntervalYearMonthType::make_value(7, -4),
            IntervalYearMonthType::make_value(0, 1),
        ]));
        let (min_res, max_res) =
            min_and_max(&year_month, &DataType::Interval(IntervalUnit::YearMonth));
        assert_eq!(
            min_res,
            ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
                -2, 4
            )))
        );
        assert_eq!(
            max_res,
            ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
                5, 34
            )))
        );

        // IntervalDayTime
        let day_time: ArrayRef = Arc::new(IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(0, 0),
            IntervalDayTimeType::make_value(5, 454000),
            IntervalDayTimeType::make_value(-34, 0),
            IntervalDayTimeType::make_value(7, -4000),
            IntervalDayTimeType::make_value(1, 0),
        ]));
        let (min_res, max_res) =
            min_and_max(&day_time, &DataType::Interval(IntervalUnit::DayTime));
        assert_eq!(
            min_res,
            ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(-34, 0)))
        );
        assert_eq!(
            max_res,
            ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(7, -4000)))
        );

        // IntervalMonthDayNano
        let month_day_nano: ArrayRef = Arc::new(IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(1, 0, 0),
            IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000),
            IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000),
            IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000),
            IntervalMonthDayNanoType::make_value(1, 0, 0),
        ]));
        let (min_res, max_res) = min_and_max(
            &month_day_nano,
            &DataType::Interval(IntervalUnit::MonthDayNano),
        );
        assert_eq!(
            min_res,
            ScalarValue::IntervalMonthDayNano(Some(
                IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000)
            ))
        );
        assert_eq!(
            max_res,
            ScalarValue::IntervalMonthDayNano(Some(
                IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000)
            ))
        );
    }

    /// Feeds each slice in `batches` to `acc` as its own `Float32` batch and
    /// asserts the final result.
    fn check_f32(acc: &mut dyn Accumulator, batches: &[&[f32]], expected: f32) {
        for values in batches {
            let batch = Arc::new(Float32Array::from_iter_values(values.iter().copied()));
            acc.update_batch(&[batch]).unwrap();
        }
        let result = acc.evaluate().unwrap();
        assert_eq!(result, ScalarValue::Float32(Some(expected)));
    }

    #[test]
    fn float_min_max_with_nans() {
        let pos_nan = f32::NAN;
        let zero = 0_f32;
        let neg_inf = f32::NEG_INFINITY;

        // This test checks both comparison between batches (which uses the min_max macro
        // defined above) and within a batch (which uses the arrow min/max compute function
        // and verifies both respect the total order comparison for floats)

        let min = || MinAccumulator::try_new(&DataType::Float32).unwrap();
        let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap();

        check_f32(&mut min(), &[&[zero], &[pos_nan]], zero);
        check_f32(&mut min(), &[&[zero, pos_nan]], zero);
        check_f32(&mut min(), &[&[zero], &[neg_inf]], neg_inf);
        check_f32(&mut min(), &[&[zero, neg_inf]], neg_inf);
        check_f32(&mut max(), &[&[zero], &[pos_nan]], pos_nan);
        check_f32(&mut max(), &[&[zero, pos_nan]], pos_nan);
        check_f32(&mut max(), &[&[zero], &[neg_inf]], zero);
        check_f32(&mut max(), &[&[zero, neg_inf]], zero);
    }

    /// `len` random integers drawn from `0..100`.
    fn get_random_vec_i32(len: usize) -> Vec<i32> {
        let mut rng = rand::thread_rng();
        (0..len).map(|_| rng.gen_range(0..100)).collect()
    }

    /// Compares `MovingMin` against a brute-force minimum over the slice
    /// `data[i - n_sliding_window..=i]` (clamped at the start).
    fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> {
        let data = get_random_vec_i32(len);
        let mut expected = Vec::with_capacity(len);
        let mut actual = Vec::with_capacity(len);
        let mut moving_min = MovingMin::<i32>::new();
        for (i, &value) in data.iter().enumerate() {
            let start = i.saturating_sub(n_sliding_window);
            expected.push(*data[start..=i].iter().min().unwrap());

            moving_min.push(value);
            if i > n_sliding_window {
                moving_min.pop();
            }
            actual.push(*moving_min.min().unwrap());
        }
        assert_eq!(actual, expected);
        Ok(())
    }

    /// Compares `MovingMax` against a brute-force maximum over the slice
    /// `data[i - n_sliding_window..=i]` (clamped at the start).
    fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> {
        let data = get_random_vec_i32(len);
        let mut expected = Vec::with_capacity(len);
        let mut actual = Vec::with_capacity(len);
        let mut moving_max = MovingMax::<i32>::new();
        for (i, &value) in data.iter().enumerate() {
            let start = i.saturating_sub(n_sliding_window);
            expected.push(*data[start..=i].iter().max().unwrap());

            moving_max.push(value);
            if i > n_sliding_window {
                moving_max.pop();
            }
            actual.push(*moving_max.max().unwrap());
        }
        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn moving_min_tests() -> Result<()> {
        for window in [10, 20, 50, 100] {
            moving_min_i32(100, window)?;
        }
        Ok(())
    }

    #[test]
    fn moving_max_tests() -> Result<()> {
        for window in [10, 20, 50, 100] {
            moving_max_i32(100, window)?;
        }
        Ok(())
    }

    #[test]
    fn test_min_max_coerce_types() {
        // the coerced types is same with input types
        let funs: Vec<Box<dyn AggregateUDFImpl>> =
            vec![Box::new(Min::new()), Box::new(Max::new())];
        let input_types = vec![
            vec![DataType::Int32],
            vec![DataType::Decimal128(10, 2)],
            vec![DataType::Decimal256(1, 1)],
            vec![DataType::Utf8],
        ];
        for fun in funs {
            for input_type in &input_types {
                let coerced = fun.coerce_types(input_type).unwrap();
                assert_eq!(*input_type, coerced);
            }
        }
    }

    #[test]
    fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> {
        let dictionary =
            DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32));
        let coerced = get_min_max_result_type(&[dictionary])?;
        assert_eq!(coerced, vec![DataType::Int32]);
        Ok(())
    }
}