/Users/andrewlamb/Software/datafusion/datafusion/expr-common/src/type_coercion/aggregates.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use crate::signature::TypeSignature; |
19 | | use arrow::datatypes::{ |
20 | | DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, |
21 | | DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, |
22 | | }; |
23 | | |
24 | | use datafusion_common::{internal_err, plan_err, Result}; |
25 | | |
/// UTF-8 string types: `Utf8` and `LargeUtf8`.
pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];

/// Signed integer types, from 8 to 64 bits wide.
pub static SIGNED_INTEGERS: &[DataType] = &[
    DataType::Int8,
    DataType::Int16,
    DataType::Int32,
    DataType::Int64,
];

/// Unsigned integer types, from 8 to 64 bits wide.
pub static UNSIGNED_INTEGERS: &[DataType] = &[
    DataType::UInt8,
    DataType::UInt16,
    DataType::UInt32,
    DataType::UInt64,
];

/// All integer types: the signed types followed by the unsigned types.
pub static INTEGERS: &[DataType] = &[
    DataType::Int8,
    DataType::Int16,
    DataType::Int32,
    DataType::Int64,
    DataType::UInt8,
    DataType::UInt16,
    DataType::UInt32,
    DataType::UInt64,
];

/// Integer types plus `Float32` and `Float64`.
///
/// Decimal types are NOT part of this list; functions in this module that
/// accept decimals match on them separately.
pub static NUMERICS: &[DataType] = &[
    DataType::Int8,
    DataType::Int16,
    DataType::Int32,
    DataType::Int64,
    DataType::UInt8,
    DataType::UInt16,
    DataType::UInt32,
    DataType::UInt64,
    DataType::Float32,
    DataType::Float64,
];

/// Timestamp types, one per [`TimeUnit`], each with `None` as the second field.
pub static TIMESTAMPS: &[DataType] = &[
    DataType::Timestamp(TimeUnit::Second, None),
    DataType::Timestamp(TimeUnit::Millisecond, None),
    DataType::Timestamp(TimeUnit::Microsecond, None),
    DataType::Timestamp(TimeUnit::Nanosecond, None),
];

/// Date types: `Date32` and `Date64`.
pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];

/// Binary types: `Binary` and `LargeBinary`.
pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary];

/// Time-of-day types: `Time32` for second/millisecond units and `Time64` for
/// microsecond/nanosecond units.
pub static TIMES: &[DataType] = &[
    DataType::Time32(TimeUnit::Second),
    DataType::Time32(TimeUnit::Millisecond),
    DataType::Time64(TimeUnit::Microsecond),
    DataType::Time64(TimeUnit::Nanosecond),
];
83 | | |
84 | | /// Validate the length of `input_types` matches the `signature` for `agg_fun`. |
85 | | /// |
86 | | /// This method DOES NOT validate the argument types - only that (at least one, |
87 | | /// in the case of [`TypeSignature::OneOf`]) signature matches the desired |
88 | | /// number of input types. |
89 | 66 | pub fn check_arg_count( |
90 | 66 | func_name: &str, |
91 | 66 | input_types: &[DataType], |
92 | 66 | signature: &TypeSignature, |
93 | 66 | ) -> Result<()> { |
94 | 66 | match signature { |
95 | 7 | TypeSignature::Uniform(agg_count0 , _) | TypeSignature::Any(agg_count) => { |
96 | 7 | if input_types.len() != *agg_count { |
97 | 0 | return plan_err!( |
98 | 0 | "The function {func_name} expects {:?} arguments, but {:?} were provided", |
99 | 0 | agg_count, |
100 | 0 | input_types.len() |
101 | 0 | ); |
102 | 7 | } |
103 | | } |
104 | 0 | TypeSignature::Exact(types) => { |
105 | 0 | if types.len() != input_types.len() { |
106 | 0 | return plan_err!( |
107 | 0 | "The function {func_name} expects {:?} arguments, but {:?} were provided", |
108 | 0 | types.len(), |
109 | 0 | input_types.len() |
110 | 0 | ); |
111 | 0 | } |
112 | | } |
113 | 20 | TypeSignature::OneOf(variants) => { |
114 | 20 | let ok = variants |
115 | 20 | .iter() |
116 | 30 | .any(|v| check_arg_count(func_name, input_types, v).is_ok()); |
117 | 20 | if !ok { |
118 | 0 | return plan_err!( |
119 | 0 | "The function {func_name} does not accept {:?} function arguments.", |
120 | 0 | input_types.len() |
121 | 0 | ); |
122 | 20 | } |
123 | | } |
124 | | TypeSignature::VariadicAny => { |
125 | 10 | if input_types.is_empty() { |
126 | 0 | return plan_err!( |
127 | 0 | "The function {func_name} expects at least one argument" |
128 | 0 | ); |
129 | 10 | } |
130 | | } |
131 | | TypeSignature::UserDefined |
132 | | | TypeSignature::Numeric(_) |
133 | 19 | | TypeSignature::Coercible(_) => { |
134 | 19 | // User-defined signature is validated in `coerce_types` |
135 | 19 | // Numeric and Coercible signature is validated in `get_valid_types` |
136 | 19 | } |
137 | | _ => { |
138 | 10 | return internal_err!( |
139 | 10 | "Aggregate functions do not support this {signature:?}" |
140 | 10 | ); |
141 | | } |
142 | | } |
143 | 56 | Ok(()) |
144 | 66 | } |
145 | | |
146 | | /// function return type of a sum |
147 | 0 | pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> { |
148 | 0 | match arg_type { |
149 | 0 | DataType::Int64 => Ok(DataType::Int64), |
150 | 0 | DataType::UInt64 => Ok(DataType::UInt64), |
151 | 0 | DataType::Float64 => Ok(DataType::Float64), |
152 | 0 | DataType::Decimal128(precision, scale) => { |
153 | 0 | // in the spark, the result type is DECIMAL(min(38,precision+10), s) |
154 | 0 | // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 |
155 | 0 | let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); |
156 | 0 | Ok(DataType::Decimal128(new_precision, *scale)) |
157 | | } |
158 | 0 | DataType::Decimal256(precision, scale) => { |
159 | 0 | // in the spark, the result type is DECIMAL(min(38,precision+10), s) |
160 | 0 | // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 |
161 | 0 | let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); |
162 | 0 | Ok(DataType::Decimal256(new_precision, *scale)) |
163 | | } |
164 | 0 | other => plan_err!("SUM does not support type \"{other:?}\""), |
165 | | } |
166 | 0 | } |
167 | | |
168 | | /// function return type of variance |
169 | 0 | pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> { |
170 | 0 | if NUMERICS.contains(arg_type) { |
171 | 0 | Ok(DataType::Float64) |
172 | | } else { |
173 | 0 | plan_err!("VAR does not support {arg_type:?}") |
174 | | } |
175 | 0 | } |
176 | | |
177 | | /// function return type of covariance |
178 | 0 | pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> { |
179 | 0 | if NUMERICS.contains(arg_type) { |
180 | 0 | Ok(DataType::Float64) |
181 | | } else { |
182 | 0 | plan_err!("COVAR does not support {arg_type:?}") |
183 | | } |
184 | 0 | } |
185 | | |
186 | | /// function return type of correlation |
187 | 0 | pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> { |
188 | 0 | if NUMERICS.contains(arg_type) { |
189 | 0 | Ok(DataType::Float64) |
190 | | } else { |
191 | 0 | plan_err!("CORR does not support {arg_type:?}") |
192 | | } |
193 | 0 | } |
194 | | |
195 | | /// function return type of an average |
196 | 7 | pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType> { |
197 | 0 | match arg_type { |
198 | 0 | DataType::Decimal128(precision, scale) => { |
199 | 0 | // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). |
200 | 0 | // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 |
201 | 0 | let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4); |
202 | 0 | let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); |
203 | 0 | Ok(DataType::Decimal128(new_precision, new_scale)) |
204 | | } |
205 | 0 | DataType::Decimal256(precision, scale) => { |
206 | 0 | // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). |
207 | 0 | // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 |
208 | 0 | let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); |
209 | 0 | let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); |
210 | 0 | Ok(DataType::Decimal256(new_precision, new_scale)) |
211 | | } |
212 | 7 | arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), |
213 | 0 | DataType::Dictionary(_, dict_value_type) => { |
214 | 0 | avg_return_type(func_name, dict_value_type.as_ref()) |
215 | | } |
216 | 0 | other => plan_err!("{func_name} does not support {other:?}"), |
217 | | } |
218 | 7 | } |
219 | | |
220 | | /// internal sum type of an average |
221 | 0 | pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> { |
222 | 0 | match arg_type { |
223 | 0 | DataType::Decimal128(precision, scale) => { |
224 | 0 | // in the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) |
225 | 0 | let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); |
226 | 0 | Ok(DataType::Decimal128(new_precision, *scale)) |
227 | | } |
228 | 0 | DataType::Decimal256(precision, scale) => { |
229 | 0 | // in Spark the sum type of avg is DECIMAL(min(38,precision+10), s) |
230 | 0 | let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); |
231 | 0 | Ok(DataType::Decimal256(new_precision, *scale)) |
232 | | } |
233 | 0 | arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), |
234 | 0 | DataType::Dictionary(_, dict_value_type) => { |
235 | 0 | avg_sum_type(dict_value_type.as_ref()) |
236 | | } |
237 | 0 | other => plan_err!("AVG does not support {other:?}"), |
238 | | } |
239 | 0 | } |
240 | | |
241 | 0 | pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { |
242 | 0 | match arg_type { |
243 | 0 | DataType::Dictionary(_, dict_value_type) => { |
244 | 0 | is_sum_support_arg_type(dict_value_type.as_ref()) |
245 | | } |
246 | 0 | _ => matches!( |
247 | 0 | arg_type, |
248 | 0 | arg_type if NUMERICS.contains(arg_type) |
249 | 0 | || matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) |
250 | | ), |
251 | | } |
252 | 0 | } |
253 | | |
254 | 0 | pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { |
255 | 0 | match arg_type { |
256 | 0 | DataType::Dictionary(_, dict_value_type) => { |
257 | 0 | is_avg_support_arg_type(dict_value_type.as_ref()) |
258 | | } |
259 | 0 | _ => matches!( |
260 | 0 | arg_type, |
261 | 0 | arg_type if NUMERICS.contains(arg_type) |
262 | 0 | || matches!(arg_type, DataType::Decimal128(_, _)| DataType::Decimal256(_, _)) |
263 | | ), |
264 | | } |
265 | 0 | } |
266 | | |
267 | 0 | pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool { |
268 | 0 | matches!( |
269 | 0 | arg_type, |
270 | 0 | arg_type if NUMERICS.contains(arg_type) |
271 | | ) |
272 | 0 | } |
273 | | |
274 | 0 | pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool { |
275 | 0 | matches!( |
276 | 0 | arg_type, |
277 | 0 | arg_type if NUMERICS.contains(arg_type) |
278 | | ) |
279 | 0 | } |
280 | | |
281 | 0 | pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool { |
282 | 0 | matches!( |
283 | 0 | arg_type, |
284 | 0 | arg_type if NUMERICS.contains(arg_type) |
285 | | ) |
286 | 0 | } |
287 | | |
288 | 0 | pub fn is_integer_arg_type(arg_type: &DataType) -> bool { |
289 | 0 | arg_type.is_integer() |
290 | 0 | } |
291 | | |
292 | 0 | pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
293 | | // Supported types smallint, int, bigint, real, double precision, decimal, or interval |
294 | | // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc |
295 | 0 | fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType> { |
296 | 0 | return match &data_type { |
297 | 0 | DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), |
298 | 0 | DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), |
299 | 0 | d if d.is_numeric() => Ok(DataType::Float64), |
300 | 0 | DataType::Dictionary(_, v) => return coerced_type(func_name, v.as_ref()), |
301 | | _ => { |
302 | 0 | return plan_err!( |
303 | 0 | "The function {:?} does not support inputs of type {:?}.", |
304 | 0 | func_name, |
305 | 0 | data_type |
306 | 0 | ) |
307 | | } |
308 | | }; |
309 | 0 | } |
310 | 0 | Ok(vec![coerced_type(func_name, &arg_types[0])?]) |
311 | 0 | } |
#[cfg(test)]
mod tests {
    use super::*;

    /// Variance yields Float64 for floats and rejects decimals.
    #[test]
    fn test_variance_return_data_type() -> Result<()> {
        assert_eq!(DataType::Float64, variance_return_type(&DataType::Float64)?);
        assert!(variance_return_type(&DataType::Decimal128(36, 10)).is_err());
        Ok(())
    }

    /// Sum widens decimal precision by 10, capped at 38 for Decimal128.
    #[test]
    fn test_sum_return_data_type() -> Result<()> {
        let widened = sum_return_type(&DataType::Decimal128(10, 5))?;
        assert_eq!(DataType::Decimal128(20, 5), widened);

        let capped = sum_return_type(&DataType::Decimal128(36, 10))?;
        assert_eq!(DataType::Decimal128(38, 10), capped);
        Ok(())
    }

    /// Covariance yields Float64 for floats and rejects decimals.
    #[test]
    fn test_covariance_return_data_type() -> Result<()> {
        assert_eq!(
            DataType::Float64,
            covariance_return_type(&DataType::Float64)?
        );
        assert!(covariance_return_type(&DataType::Decimal128(36, 10)).is_err());
        Ok(())
    }

    /// Correlation yields Float64 for floats and rejects decimals.
    #[test]
    fn test_correlation_return_data_type() -> Result<()> {
        assert_eq!(
            DataType::Float64,
            correlation_return_type(&DataType::Float64)?
        );
        assert!(correlation_return_type(&DataType::Decimal128(36, 10)).is_err());
        Ok(())
    }
}