Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr-common/src/type_coercion/aggregates.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use crate::signature::TypeSignature;
19
use arrow::datatypes::{
20
    DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
21
    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
22
};
23
24
use datafusion_common::{internal_err, plan_err, Result};
25
26
/// The two UTF-8 string types: `Utf8` and `LargeUtf8`.
pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];
27
28
/// The signed integer types, from `Int8` through `Int64`.
pub static SIGNED_INTEGERS: &[DataType] = &[
29
    DataType::Int8,
30
    DataType::Int16,
31
    DataType::Int32,
32
    DataType::Int64,
33
];
34
35
/// The unsigned integer types, from `UInt8` through `UInt64`.
pub static UNSIGNED_INTEGERS: &[DataType] = &[
36
    DataType::UInt8,
37
    DataType::UInt16,
38
    DataType::UInt32,
39
    DataType::UInt64,
40
];
41
42
/// All integer types: the signed types followed by the unsigned types.
pub static INTEGERS: &[DataType] = &[
43
    DataType::Int8,
44
    DataType::Int16,
45
    DataType::Int32,
46
    DataType::Int64,
47
    DataType::UInt8,
48
    DataType::UInt16,
49
    DataType::UInt32,
50
    DataType::UInt64,
51
];
52
53
/// All integer types plus `Float32` and `Float64`.
///
/// Decimal types are not part of this list.
pub static NUMERICS: &[DataType] = &[
54
    DataType::Int8,
55
    DataType::Int16,
56
    DataType::Int32,
57
    DataType::Int64,
58
    DataType::UInt8,
59
    DataType::UInt16,
60
    DataType::UInt32,
61
    DataType::UInt64,
62
    DataType::Float32,
63
    DataType::Float64,
64
];
65
66
/// One `Timestamp` type per `TimeUnit`, each with `None` as its second field.
pub static TIMESTAMPS: &[DataType] = &[
67
    DataType::Timestamp(TimeUnit::Second, None),
68
    DataType::Timestamp(TimeUnit::Millisecond, None),
69
    DataType::Timestamp(TimeUnit::Microsecond, None),
70
    DataType::Timestamp(TimeUnit::Nanosecond, None),
71
];
72
73
/// The two date types: `Date32` and `Date64`.
pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
74
75
/// The two binary types: `Binary` and `LargeBinary`.
pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary];
76
77
/// The time-of-day types: `Time32` in seconds and milliseconds, `Time64` in
/// microseconds and nanoseconds.
pub static TIMES: &[DataType] = &[
78
    DataType::Time32(TimeUnit::Second),
79
    DataType::Time32(TimeUnit::Millisecond),
80
    DataType::Time64(TimeUnit::Microsecond),
81
    DataType::Time64(TimeUnit::Nanosecond),
82
];
83
84
/// Validate the length of `input_types` matches the `signature` for `agg_fun`.
85
///
86
/// This method DOES NOT validate the argument types - only that (at least one,
87
/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
88
/// number of input types.
89
66
pub fn check_arg_count(
90
66
    func_name: &str,
91
66
    input_types: &[DataType],
92
66
    signature: &TypeSignature,
93
66
) -> Result<()> {
94
66
    match signature {
95
7
        TypeSignature::Uniform(
agg_count0
, _) | TypeSignature::Any(agg_count) => {
96
7
            if input_types.len() != *agg_count {
97
0
                return plan_err!(
98
0
                    "The function {func_name} expects {:?} arguments, but {:?} were provided",
99
0
                    agg_count,
100
0
                    input_types.len()
101
0
                );
102
7
            }
103
        }
104
0
        TypeSignature::Exact(types) => {
105
0
            if types.len() != input_types.len() {
106
0
                return plan_err!(
107
0
                    "The function {func_name} expects {:?} arguments, but {:?} were provided",
108
0
                    types.len(),
109
0
                    input_types.len()
110
0
                );
111
0
            }
112
        }
113
20
        TypeSignature::OneOf(variants) => {
114
20
            let ok = variants
115
20
                .iter()
116
30
                .any(|v| check_arg_count(func_name, input_types, v).is_ok());
117
20
            if !ok {
118
0
                return plan_err!(
119
0
                    "The function {func_name} does not accept {:?} function arguments.",
120
0
                    input_types.len()
121
0
                );
122
20
            }
123
        }
124
        TypeSignature::VariadicAny => {
125
10
            if input_types.is_empty() {
126
0
                return plan_err!(
127
0
                    "The function {func_name} expects at least one argument"
128
0
                );
129
10
            }
130
        }
131
        TypeSignature::UserDefined
132
        | TypeSignature::Numeric(_)
133
19
        | TypeSignature::Coercible(_) => {
134
19
            // User-defined signature is validated in `coerce_types`
135
19
            // Numeric and Coercible signature is validated in `get_valid_types`
136
19
        }
137
        _ => {
138
10
            return internal_err!(
139
10
                "Aggregate functions do not support this {signature:?}"
140
10
            );
141
        }
142
    }
143
56
    Ok(())
144
66
}
145
146
/// function return type of a sum
147
0
pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
148
0
    match arg_type {
149
0
        DataType::Int64 => Ok(DataType::Int64),
150
0
        DataType::UInt64 => Ok(DataType::UInt64),
151
0
        DataType::Float64 => Ok(DataType::Float64),
152
0
        DataType::Decimal128(precision, scale) => {
153
0
            // in the spark, the result type is DECIMAL(min(38,precision+10), s)
154
0
            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
155
0
            let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
156
0
            Ok(DataType::Decimal128(new_precision, *scale))
157
        }
158
0
        DataType::Decimal256(precision, scale) => {
159
0
            // in the spark, the result type is DECIMAL(min(38,precision+10), s)
160
0
            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
161
0
            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
162
0
            Ok(DataType::Decimal256(new_precision, *scale))
163
        }
164
0
        other => plan_err!("SUM does not support type \"{other:?}\""),
165
    }
166
0
}
167
168
/// function return type of variance
169
0
pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> {
170
0
    if NUMERICS.contains(arg_type) {
171
0
        Ok(DataType::Float64)
172
    } else {
173
0
        plan_err!("VAR does not support {arg_type:?}")
174
    }
175
0
}
176
177
/// function return type of covariance
178
0
pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> {
179
0
    if NUMERICS.contains(arg_type) {
180
0
        Ok(DataType::Float64)
181
    } else {
182
0
        plan_err!("COVAR does not support {arg_type:?}")
183
    }
184
0
}
185
186
/// function return type of correlation
187
0
pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
188
0
    if NUMERICS.contains(arg_type) {
189
0
        Ok(DataType::Float64)
190
    } else {
191
0
        plan_err!("CORR does not support {arg_type:?}")
192
    }
193
0
}
194
195
/// function return type of an average
196
7
pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType> {
197
0
    match arg_type {
198
0
        DataType::Decimal128(precision, scale) => {
199
0
            // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
200
0
            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
201
0
            let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
202
0
            let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
203
0
            Ok(DataType::Decimal128(new_precision, new_scale))
204
        }
205
0
        DataType::Decimal256(precision, scale) => {
206
0
            // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
207
0
            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
208
0
            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
209
0
            let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
210
0
            Ok(DataType::Decimal256(new_precision, new_scale))
211
        }
212
7
        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
213
0
        DataType::Dictionary(_, dict_value_type) => {
214
0
            avg_return_type(func_name, dict_value_type.as_ref())
215
        }
216
0
        other => plan_err!("{func_name} does not support {other:?}"),
217
    }
218
7
}
219
220
/// internal sum type of an average
221
0
pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
222
0
    match arg_type {
223
0
        DataType::Decimal128(precision, scale) => {
224
0
            // in the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
225
0
            let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
226
0
            Ok(DataType::Decimal128(new_precision, *scale))
227
        }
228
0
        DataType::Decimal256(precision, scale) => {
229
0
            // in Spark the sum type of avg is DECIMAL(min(38,precision+10), s)
230
0
            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
231
0
            Ok(DataType::Decimal256(new_precision, *scale))
232
        }
233
0
        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
234
0
        DataType::Dictionary(_, dict_value_type) => {
235
0
            avg_sum_type(dict_value_type.as_ref())
236
        }
237
0
        other => plan_err!("AVG does not support {other:?}"),
238
    }
239
0
}
240
241
0
pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
242
0
    match arg_type {
243
0
        DataType::Dictionary(_, dict_value_type) => {
244
0
            is_sum_support_arg_type(dict_value_type.as_ref())
245
        }
246
0
        _ => matches!(
247
0
            arg_type,
248
0
            arg_type if NUMERICS.contains(arg_type)
249
0
            || matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
250
        ),
251
    }
252
0
}
253
254
0
pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
255
0
    match arg_type {
256
0
        DataType::Dictionary(_, dict_value_type) => {
257
0
            is_avg_support_arg_type(dict_value_type.as_ref())
258
        }
259
0
        _ => matches!(
260
0
            arg_type,
261
0
            arg_type if NUMERICS.contains(arg_type)
262
0
                || matches!(arg_type, DataType::Decimal128(_, _)| DataType::Decimal256(_, _))
263
        ),
264
    }
265
0
}
266
267
0
pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool {
268
0
    matches!(
269
0
        arg_type,
270
0
        arg_type if NUMERICS.contains(arg_type)
271
    )
272
0
}
273
274
0
pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
275
0
    matches!(
276
0
        arg_type,
277
0
        arg_type if NUMERICS.contains(arg_type)
278
    )
279
0
}
280
281
0
pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
282
0
    matches!(
283
0
        arg_type,
284
0
        arg_type if NUMERICS.contains(arg_type)
285
    )
286
0
}
287
288
0
pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
289
0
    arg_type.is_integer()
290
0
}
291
292
0
pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<DataType>> {
293
    // Supported types smallint, int, bigint, real, double precision, decimal, or interval
294
    // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
295
0
    fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType> {
296
0
        return match &data_type {
297
0
            DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
298
0
            DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
299
0
            d if d.is_numeric() => Ok(DataType::Float64),
300
0
            DataType::Dictionary(_, v) => return coerced_type(func_name, v.as_ref()),
301
            _ => {
302
0
                return plan_err!(
303
0
                    "The function {:?} does not support inputs of type {:?}.",
304
0
                    func_name,
305
0
                    data_type
306
0
                )
307
            }
308
        };
309
0
    }
310
0
    Ok(vec![coerced_type(func_name, &arg_types[0])?])
311
0
}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Variance accepts `Float64` and rejects decimals.
    #[test]
    fn test_variance_return_data_type() -> Result<()> {
        assert_eq!(
            DataType::Float64,
            variance_return_type(&DataType::Float64)?
        );
        assert!(variance_return_type(&DataType::Decimal128(36, 10)).is_err());
        Ok(())
    }

    /// Sum widens decimal precision by 10 and caps it at the maximum.
    #[test]
    fn test_sum_return_data_type() -> Result<()> {
        assert_eq!(
            DataType::Decimal128(20, 5),
            sum_return_type(&DataType::Decimal128(10, 5))?
        );
        assert_eq!(
            DataType::Decimal128(38, 10),
            sum_return_type(&DataType::Decimal128(36, 10))?
        );
        Ok(())
    }

    /// Covariance accepts `Float64` and rejects decimals.
    #[test]
    fn test_covariance_return_data_type() -> Result<()> {
        assert_eq!(
            DataType::Float64,
            covariance_return_type(&DataType::Float64)?
        );
        assert!(covariance_return_type(&DataType::Decimal128(36, 10)).is_err());
        Ok(())
    }

    /// Correlation accepts `Float64` and rejects decimals.
    #[test]
    fn test_correlation_return_data_type() -> Result<()> {
        assert_eq!(
            DataType::Float64,
            correlation_return_type(&DataType::Float64)?
        );
        assert!(correlation_return_type(&DataType::Decimal128(36, 10)).is_err());
        Ok(())
    }
}